# Layers.py
  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. from torch.nn import functional as F
  6. from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
  7. class FFTBlock(torch.nn.Module):
  8. """FFT Block"""
  9. def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
  10. super(FFTBlock, self).__init__()
  11. self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  12. self.pos_ffn = PositionwiseFeedForward(
  13. d_model, d_inner, kernel_size, dropout=dropout
  14. )
  15. def forward(self, enc_input, mask=None, slf_attn_mask=None):
  16. enc_output, enc_slf_attn = self.slf_attn(
  17. enc_input, enc_input, enc_input, mask=slf_attn_mask
  18. )
  19. enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
  20. enc_output = self.pos_ffn(enc_output)
  21. enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
  22. return enc_output, enc_slf_attn
  23. class ConvNorm(torch.nn.Module):
  24. def __init__(
  25. self,
  26. in_channels,
  27. out_channels,
  28. kernel_size=1,
  29. stride=1,
  30. padding=None,
  31. dilation=1,
  32. bias=True,
  33. w_init_gain="linear",
  34. ):
  35. super(ConvNorm, self).__init__()
  36. if padding is None:
  37. assert kernel_size % 2 == 1
  38. padding = int(dilation * (kernel_size - 1) / 2)
  39. self.conv = torch.nn.Conv1d(
  40. in_channels,
  41. out_channels,
  42. kernel_size=kernel_size,
  43. stride=stride,
  44. padding=padding,
  45. dilation=dilation,
  46. bias=bias,
  47. )
  48. def forward(self, signal):
  49. conv_signal = self.conv(signal)
  50. return conv_signal
  51. class PostNet(nn.Module):
  52. """
  53. PostNet: Five 1-d convolution with 512 channels and kernel size 5
  54. """
  55. def __init__(
  56. self,
  57. n_mel_channels=80,
  58. postnet_embedding_dim=512,
  59. postnet_kernel_size=5,
  60. postnet_n_convolutions=5,
  61. ):
  62. super(PostNet, self).__init__()
  63. self.convolutions = nn.ModuleList()
  64. self.convolutions.append(
  65. nn.Sequential(
  66. ConvNorm(
  67. n_mel_channels,
  68. postnet_embedding_dim,
  69. kernel_size=postnet_kernel_size,
  70. stride=1,
  71. padding=int((postnet_kernel_size - 1) / 2),
  72. dilation=1,
  73. w_init_gain="tanh",
  74. ),
  75. nn.BatchNorm1d(postnet_embedding_dim),
  76. )
  77. )
  78. for i in range(1, postnet_n_convolutions - 1):
  79. self.convolutions.append(
  80. nn.Sequential(
  81. ConvNorm(
  82. postnet_embedding_dim,
  83. postnet_embedding_dim,
  84. kernel_size=postnet_kernel_size,
  85. stride=1,
  86. padding=int((postnet_kernel_size - 1) / 2),
  87. dilation=1,
  88. w_init_gain="tanh",
  89. ),
  90. nn.BatchNorm1d(postnet_embedding_dim),
  91. )
  92. )
  93. self.convolutions.append(
  94. nn.Sequential(
  95. ConvNorm(
  96. postnet_embedding_dim,
  97. n_mel_channels,
  98. kernel_size=postnet_kernel_size,
  99. stride=1,
  100. padding=int((postnet_kernel_size - 1) / 2),
  101. dilation=1,
  102. w_init_gain="linear",
  103. ),
  104. nn.BatchNorm1d(n_mel_channels),
  105. )
  106. )
  107. def forward(self, x):
  108. x = x.contiguous().transpose(1, 2)
  109. for i in range(len(self.convolutions) - 1):
  110. x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
  111. x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
  112. x = x.contiguous().transpose(1, 2)
  113. return x