from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F

from .SubLayers import MultiHeadAttention, PositionwiseFeedForward

class FFTBlock(torch.nn.Module):
    """FFT (feed-forward Transformer) block: multi-head self-attention
    followed by a position-wise convolutional feed-forward sub-layer."""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        # Self-attention over the sequence; slf_attn_mask keeps queries from
        # attending to padded positions.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        if mask is not None:
            # Zero out padded time steps (mask is True on padding).
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        enc_output = self.pos_ffn(enc_output)
        if mask is not None:
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        return enc_output, enc_slf_attn
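
# Rough usage sketch for FFTBlock (values are illustrative, not tied to any
# particular config; the expected kernel_size format depends on the
# PositionwiseFeedForward implementation in SubLayers):
#
#   block = FFTBlock(d_model=256, n_head=2, d_k=64, d_v=64,
#                    d_inner=1024, kernel_size=[9, 1])
#   x = torch.zeros(8, 100, 256)                  # (batch, time, d_model)
#   mask = torch.zeros(8, 100, dtype=torch.bool)  # True marks padded steps
#   slf_attn_mask = mask.unsqueeze(1).expand(-1, 100, -1)
#   out, attn = block(x, mask=mask, slf_attn_mask=slf_attn_mask)
#   # out keeps the input shape: (8, 100, 256)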
class ConvNorm(torch.nn.Module):
    """1-D convolution with "same" padding derived from the kernel size."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",  # kept for interface compatibility; unused here
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            # "Same" padding is only well defined for odd kernel sizes.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal
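
# Usage sketch for ConvNorm (shapes are illustrative): like any nn.Conv1d it
# expects (batch, channels, time) input, and with the default padding the
# time dimension is preserved:
#
#   conv = ConvNorm(80, 512, kernel_size=5)
#   y = conv(torch.zeros(8, 80, 100))  # -> (8, 512, 100)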
class PostNet(nn.Module):
    """
    PostNet: five 1-D convolutions with 512 channels and kernel size 5,
    predicting a residual that refines the mel-spectrogram.
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        # First convolution: mel channels -> embedding dimension.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        # Intermediate convolutions keep the embedding dimension.
        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        # Final convolution: embedding dimension -> mel channels, linear output.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        # (batch, time, n_mel) -> (batch, n_mel, time) for Conv1d.
        x = x.contiguous().transpose(1, 2)

        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        # The last layer has no tanh non-linearity.
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        # Back to (batch, time, n_mel).
        x = x.contiguous().transpose(1, 2)
        return x
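
if __name__ == "__main__":
    # Minimal smoke test (illustrative shapes only). PostNet maps a
    # (batch, time, n_mel_channels) mel-spectrogram to a residual of the same
    # shape, which is typically added back onto the coarse prediction.
    # Note: because of the relative import at the top, run this as a module
    # from the package root (python -m <package>.Layers).
    postnet = PostNet()
    mel = torch.randn(2, 100, 80)  # (batch, time, n_mel_channels)
    residual = postnet(mel)
    refined_mel = mel + residual   # the usual residual refinement
    print(residual.shape, refined_mel.shape)  # both torch.Size([2, 100, 80])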