123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- import os
- import json
- import copy
- import math
- from collections import OrderedDict
- import torch
- import torch.nn as nn
- import numpy as np
- import torch.nn.functional as F
- from utils.tools import get_mask_from_lengths, pad
# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VarianceAdaptor(nn.Module):
    """Variance Adaptor.

    Predicts durations, pitch and energy from the input hidden sequence, expands
    the sequence with a ``LengthRegulator``, and adds quantized pitch/energy
    embeddings to it. Pitch and energy are each applied either before the
    expansion (``"phoneme_level"``) or after it (``"frame_level"``), as set in
    ``preprocess_config["preprocessing"][<pitch|energy>]["feature"]``.
    """

    def __init__(self, preprocess_config, model_config):
        """
        :param preprocess_config: dict; reads ``preprocessing.pitch.feature``,
            ``preprocessing.energy.feature`` and ``path.preprocessed_path``
            (the directory that must contain ``stats.json``).
        :param model_config: dict; reads ``variance_embedding.*`` and
            ``transformer.encoder_hidden``, and is forwarded to the predictors.
        """
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(model_config)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(model_config)
        self.energy_predictor = VariancePredictor(model_config)

        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        # The first two entries of stats["pitch"] / stats["energy"] are used as
        # the (min, max) range spanned by the quantization boundaries.
        stats_path = os.path.join(
            preprocess_config["path"]["preprocessed_path"], "stats.json"
        )
        with open(stats_path, encoding="utf-8") as f:
            stats = json.load(f)
        pitch_min, pitch_max = stats["pitch"][:2]
        energy_min, energy_max = stats["energy"][:2]

        # Frozen Parameters (requires_grad=False) so the boundaries are stored in
        # the state_dict and move with the module, but are never trained.
        self.pitch_bins = nn.Parameter(
            self._make_bins(pitch_quantization, pitch_min, pitch_max, n_bins),
            requires_grad=False,
        )
        self.energy_bins = nn.Parameter(
            self._make_bins(energy_quantization, energy_min, energy_max, n_bins),
            requires_grad=False,
        )

        # n_bins - 1 boundaries produce n_bins possible bucket indices.
        encoder_hidden = model_config["transformer"]["encoder_hidden"]
        self.pitch_embedding = nn.Embedding(n_bins, encoder_hidden)
        self.energy_embedding = nn.Embedding(n_bins, encoder_hidden)

    @staticmethod
    def _make_bins(quantization, v_min, v_max, n_bins):
        """Return ``n_bins - 1`` bucket boundaries spanning ``[v_min, v_max]``.

        ``"log"`` spaces the boundaries evenly in log space; any other value
        (``"linear"``) spaces them evenly in the original scale.
        """
        if quantization == "log":
            return torch.exp(torch.linspace(np.log(v_min), np.log(v_max), n_bins - 1))
        return torch.linspace(v_min, v_max, n_bins - 1)

    @staticmethod
    def _variance_embedding(predictor, bins, embedding, x, target, mask, control):
        """Predict a variance value and embed its quantized bucket.

        When ``target`` is given, the target is bucketized and embedded, and the
        raw (unscaled) prediction is returned. Otherwise the prediction is scaled
        by ``control`` and that scaled value is bucketized and embedded.
        """
        prediction = predictor(x, mask)
        if target is not None:
            buckets = torch.bucketize(target, bins)
        else:
            prediction = prediction * control
            buckets = torch.bucketize(prediction, bins)
        return prediction, embedding(buckets)

    def get_pitch_embedding(self, x, target, mask, control):
        """Return ``(pitch_prediction, pitch_embedding)`` for ``x``.

        ``control`` scales the prediction only when ``target`` is None.
        """
        return self._variance_embedding(
            self.pitch_predictor,
            self.pitch_bins,
            self.pitch_embedding,
            x,
            target,
            mask,
            control,
        )

    def get_energy_embedding(self, x, target, mask, control):
        """Return ``(energy_prediction, energy_embedding)`` for ``x``.

        ``control`` scales the prediction only when ``target`` is None.
        """
        return self._variance_embedding(
            self.energy_predictor,
            self.energy_bins,
            self.energy_embedding,
            x,
            target,
            mask,
            control,
        )

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Apply duration, pitch and energy modelling to ``x``.

        :param x: input hidden sequence (passed to the predictors).
        :param src_mask: mask for the un-expanded sequence.
        :param mel_mask: mask for the expanded sequence; only used as given when
            ``duration_target`` is provided, otherwise recomputed from the
            predicted lengths.
        :param max_len: optional target length forwarded to the length regulator.
        :param pitch_target, energy_target, duration_target: optional ground
            truth; when given it is used instead of the corresponding prediction.
        :param p_control, e_control, d_control: scale factors for predicted
            pitch, energy and duration (ignored for values that have targets).
        :return: ``(x, pitch_prediction, energy_prediction,
            log_duration_prediction, duration_rounded, mel_len, mel_mask)``.
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)

        if self.pitch_feature_level == "phoneme_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, src_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "phoneme_level":
            # Energy is scaled by e_control (previously p_control by mistake).
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, src_mask, e_control
            )
            x = x + energy_embedding

        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            # The predictor output is inverted with exp(.) - 1. Scale first and
            # round afterwards so the durations handed to the length regulator
            # (and returned) are whole numbers for any d_control.
            duration_rounded = torch.clamp(
                torch.round((torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        if self.pitch_feature_level == "frame_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, mel_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "frame_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, mel_mask, e_control
            )
            x = x + energy_embedding

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )
class LengthRegulator(nn.Module):
    """Length Regulator.

    Repeats every position of each sequence in a batch by its duration and pads
    the expanded sequences to a common length with ``pad``.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand each sequence of ``x`` by ``duration`` and pad the batch.

        :param x: batch tensor, iterated along dim 0 in lockstep with ``duration``.
        :param duration: per-position repeat counts, one row per batch item.
        :param max_len: if not None, forwarded to ``pad`` as the target length.
        :return: ``(padded_output, mel_len)``, where ``mel_len`` is a long tensor
            of expanded lengths placed on the same device as ``x``.
        """
        output = [self.expand(seq, dur) for seq, dur in zip(x, duration)]
        mel_len = [seq.shape[0] for seq in output]
        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)
        # Follow the input's device rather than a module-global default so the
        # lengths live alongside the data they describe.
        return output, torch.tensor(mel_len, dtype=torch.long, device=x.device)

    def expand(self, batch, predicted):
        """Repeat row ``i`` of ``batch`` ``predicted[i]`` times, stacked along dim 0.

        Counts are truncated toward zero and negative counts are treated as 0
        (same as ``max(int(d), 0)``). Entries of ``predicted`` beyond
        ``len(batch)`` are ignored. An empty ``batch`` yields an empty result.
        """
        # One vectorized op instead of a per-position .item() call, which forces a
        # host sync for every element when the tensors live on an accelerator.
        repeats = (
            predicted[: batch.shape[0]].long().clamp(min=0).to(batch.device)
        )
        return torch.repeat_interleave(batch, repeats, dim=0)

    def forward(self, x, duration, max_len):
        """Return ``(padded_output, mel_len)``; see ``LR``."""
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor.

    Two (Conv -> ReLU -> LayerNorm -> Dropout) stages followed by a linear layer
    that maps each time step to a single value.
    """

    def __init__(self, model_config):
        """
        :param model_config: dict; reads ``transformer.encoder_hidden`` and
            ``variance_predictor.{filter_size, kernel_size, dropout}``.
        """
        super(VariancePredictor, self).__init__()
        self.input_size = model_config["transformer"]["encoder_hidden"]
        self.filter_size = model_config["variance_predictor"]["filter_size"]
        self.kernel = model_config["variance_predictor"]["kernel_size"]
        self.conv_output_size = model_config["variance_predictor"]["filter_size"]
        self.dropout = model_config["variance_predictor"]["dropout"]

        # Both convolutions must share the same padding so that the output keeps
        # the input's time length (forward masks it with a mask of that length).
        # The second conv previously hard-coded padding=1, which only matched for
        # kernel_size == 3. This gives "same" length for odd kernel sizes only.
        padding = (self.kernel - 1) // 2

        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=padding,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=padding,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )
        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        """Predict one value per time step of ``encoder_output``.

        :param encoder_output: input sequence fed to the conv stack (the Conv
            wrapper swaps dims 1 and 2 around ``nn.Conv1d``, so presumably
            (batch, time, encoder_hidden) — confirm against callers).
        :param mask: optional boolean tensor; positions where it is True are set
            to 0.0 in the output.
        :return: the linear-layer output with its trailing size-1 dim squeezed.
        """
        out = self.linear_layer(self.conv_layer(encoder_output)).squeeze(-1)
        if mask is not None:
            out = out.masked_fill(mask, 0.0)
        return out
class Conv(nn.Module):
    """
    Convolution Module

    Wrapper around ``nn.Conv1d`` for channels-last input: dims 1 and 2 are
    swapped before and after the convolution, so ``x`` is laid out as
    (batch, length, channels) and the output as (batch, length', out_channels).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. Currently unused and kept only for interface
            compatibility; no custom weight initialization is applied, so the
            default ``nn.Conv1d`` initialization is used.
        """
        super(Conv, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` along its second dimension and restore the layout."""
        # nn.Conv1d expects (batch, channels, length); it accepts the
        # non-contiguous transposed view directly, so no .contiguous() is needed.
        return self.conv(x.transpose(1, 2)).transpose(1, 2)
|