# Models.py

import torch
import torch.nn as nn
import numpy as np

import transformer.Constants as Constants
from .Layers import FFTBlock
from text.symbols import symbols


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """ Sinusoid position encoding table """

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array(
        [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
    )

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.0

    return torch.FloatTensor(sinusoid_table)

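# The table above is the standard sinusoidal positional encoding:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))
# Illustrative shape check (added comment, not part of the original module):
#   get_sinusoid_encoding_table(1001, 256).shape == torch.Size([1001, 256])
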
class Encoder(nn.Module):
    """ Encoder """

    def __init__(self, config):
        super(Encoder, self).__init__()

        n_position = config["max_seq_len"] + 1
        n_src_vocab = len(symbols) + 1
        d_word_vec = config["transformer"]["encoder_hidden"]
        n_layers = config["transformer"]["encoder_layer"]
        n_head = config["transformer"]["encoder_head"]
        d_k = d_v = (
            config["transformer"]["encoder_hidden"]
            // config["transformer"]["encoder_head"]
        )
        d_model = config["transformer"]["encoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["encoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD
        )
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, src_seq, mask, return_attns=False):

        enc_slf_attn_list = []
        batch_size, max_len = src_seq.shape[0], src_seq.shape[1]

        # -- Prepare masks
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        # -- Forward
        if not self.training and src_seq.shape[1] > self.max_seq_len:
            enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table(
                src_seq.shape[1], self.d_model
            )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
                src_seq.device
            )
        else:
            enc_output = self.src_word_emb(src_seq) + self.position_enc[
                :, :max_len, :
            ].expand(batch_size, -1, -1)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output, mask=mask, slf_attn_mask=slf_attn_mask
            )
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        return enc_output

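# Shape notes for Encoder.forward (added comments; the padding convention is
# assumed from how the mask is consumed by FFTBlock in this codebase):
#   src_seq: LongTensor (batch, seq_len) of symbol IDs, padded with Constants.PAD
#   mask:    BoolTensor (batch, seq_len), True at padded positions
#   returns: FloatTensor (batch, seq_len, encoder_hidden)
# At inference, inputs longer than max_seq_len fall back to an on-the-fly
# sinusoid table instead of the pre-computed position_enc buffer.
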
class Decoder(nn.Module):
    """ Decoder """

    def __init__(self, config):
        super(Decoder, self).__init__()

        n_position = config["max_seq_len"] + 1
        d_word_vec = config["transformer"]["decoder_hidden"]
        n_layers = config["transformer"]["decoder_layer"]
        n_head = config["transformer"]["decoder_head"]
        d_k = d_v = (
            config["transformer"]["decoder_hidden"]
            // config["transformer"]["decoder_head"]
        )
        d_model = config["transformer"]["decoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["decoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, enc_seq, mask, return_attns=False):

        dec_slf_attn_list = []
        batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]

        # -- Forward
        if not self.training and enc_seq.shape[1] > self.max_seq_len:
            # -- Prepare masks
            slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
            dec_output = enc_seq + get_sinusoid_encoding_table(
                enc_seq.shape[1], self.d_model
            )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
                enc_seq.device
            )
        else:
            max_len = min(max_len, self.max_seq_len)

            # -- Prepare masks
            slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
            dec_output = enc_seq[:, :max_len, :] + self.position_enc[
                :, :max_len, :
            ].expand(batch_size, -1, -1)
            mask = mask[:, :max_len]
            slf_attn_mask = slf_attn_mask[:, :, :max_len]

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output, mask=mask, slf_attn_mask=slf_attn_mask
            )
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]

        return dec_output, mask

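# ----------------------------------------------------------------------------
# Minimal shape-check sketch (added for illustration, not part of the original
# module). The hyperparameter values below mirror the reference FastSpeech 2
# configuration and are assumptions; adjust them to your own model config.
# In the full model a variance adaptor sits between Encoder and Decoder; here
# the encoder output is fed straight to the decoder only to exercise shapes,
# which works because encoder_hidden == decoder_hidden in this sketch.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "max_seq_len": 1000,
        "transformer": {
            "encoder_hidden": 256,
            "encoder_layer": 4,
            "encoder_head": 2,
            "encoder_dropout": 0.2,
            "decoder_hidden": 256,
            "decoder_layer": 6,
            "decoder_head": 2,
            "decoder_dropout": 0.2,
            "conv_filter_size": 1024,
            "conv_kernel_size": [9, 1],  # assumed list form, as in the reference config
        },
    }

    encoder, decoder = Encoder(config), Decoder(config)
    encoder.eval()
    decoder.eval()

    batch_size, seq_len = 2, 16
    src_seq = torch.randint(1, len(symbols), (batch_size, seq_len))  # symbol IDs
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)        # no padding

    with torch.no_grad():
        enc_output = encoder(src_seq, mask)               # (2, 16, encoder_hidden)
        dec_output, dec_mask = decoder(enc_output, mask)  # (2, 16, decoder_hidden)
    print(enc_output.shape, dec_output.shape)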