Modules.py

import torch
import torch.nn as nn
import numpy as np


class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        # Attention scores: (batch, len_q, d_k) x (batch, d_k, len_k) -> (batch, len_q, len_k)
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        if mask is not None:
            # Masked positions get -inf so they receive zero weight after the softmax
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        # Weighted sum of values: (batch, len_q, len_k) x (batch, len_k, d_v) -> (batch, len_q, d_v)
        output = torch.bmm(attn, v)
        return output, attn
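
A minimal usage sketch, not part of the original file: the batch size, sequence lengths, and dimensions below are illustrative assumptions, and the temperature is set to sqrt(d_k) as in "Attention Is All You Need".

if __name__ == "__main__":
    # Hypothetical shapes for illustration only
    batch, len_q, len_k, d_k, d_v = 2, 4, 6, 8, 8
    q = torch.rand(batch, len_q, d_k)
    k = torch.rand(batch, len_k, d_k)
    v = torch.rand(batch, len_k, d_v)

    attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
    output, attn = attention(q, k, v)

    print(output.shape)  # torch.Size([2, 4, 8])
    print(attn.shape)    # torch.Size([2, 4, 6])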