modules.py

import os
import json
from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np

from utils.tools import get_mask_from_lengths, pad

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class VarianceAdaptor(nn.Module):
    """Variance Adaptor"""

    def __init__(self, preprocess_config, model_config):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(model_config)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(model_config)
        self.energy_predictor = VariancePredictor(model_config)

        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
            "feature"
        ]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
            "feature"
        ]
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            pitch_min, pitch_max = stats["pitch"][:2]
            energy_min, energy_max = stats["energy"][:2]

        if pitch_quantization == "log":
            self.pitch_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.pitch_bins = nn.Parameter(
                torch.linspace(pitch_min, pitch_max, n_bins - 1),
                requires_grad=False,
            )
        if energy_quantization == "log":
            self.energy_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.energy_bins = nn.Parameter(
                torch.linspace(energy_min, energy_max, n_bins - 1),
                requires_grad=False,
            )

        self.pitch_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )
        self.energy_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )
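
    # The bin tensors hold n_bins - 1 boundary values, so torch.bucketize maps
    # a continuous pitch/energy value to an index in [0, n_bins - 1], exactly
    # the index range of the two embedding tables above.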

    def get_pitch_embedding(self, x, target, mask, control):
        prediction = self.pitch_predictor(x, mask)
        if target is not None:
            embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
        else:
            prediction = prediction * control
            embedding = self.pitch_embedding(
                torch.bucketize(prediction, self.pitch_bins)
            )
        return prediction, embedding

    def get_energy_embedding(self, x, target, mask, control):
        prediction = self.energy_predictor(x, mask)
        if target is not None:
            embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
        else:
            prediction = prediction * control
            embedding = self.energy_embedding(
                torch.bucketize(prediction, self.energy_bins)
            )
        return prediction, embedding
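
    # Both helpers always return the prediction so the variance loss can be
    # computed against the target; the embedding is looked up from the
    # ground-truth target when one is given (teacher forcing) and from the
    # control-scaled prediction otherwise.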

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if self.pitch_feature_level == "phoneme_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, src_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "phoneme_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, src_mask, e_control
            )
            x = x + energy_embedding

        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            duration_rounded = torch.clamp(
                (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        if self.pitch_feature_level == "frame_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, mel_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "frame_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, mel_mask, e_control
            )
            x = x + energy_embedding

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )
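

# Usage sketch (illustrative; the call below assumes a `variance_adaptor`
# instance and encoder outputs `x` with their `src_mask`). During training the
# ground-truth pitch/energy/duration targets are supplied, so the predictor
# outputs are used only for the variance losses; at inference the targets are
# None and the (control-scaled) predictions drive the embeddings and the
# length regulator:
#
#     # inference: roughly 20% slower speech, pitch scaled down by 20%
#     outputs = variance_adaptor(x, src_mask, p_control=0.8, d_control=1.2)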


class LengthRegulator(nn.Module):
    """Length Regulator"""

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        output = list()
        mel_len = list()
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            output.append(expanded)
            mel_len.append(expanded.shape[0])

        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)

        return output, torch.LongTensor(mel_len).to(device)

    def expand(self, batch, predicted):
        out = list()
        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(max(int(expand_size), 0), -1))
        out = torch.cat(out, 0)

        return out

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
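

# Worked example (hypothetical values): for an utterance with three phonemes
# and durations [2, 0, 3], expand() repeats the phoneme vectors 2, 0, and 3
# times respectively, yielding a 5-frame sequence, and mel_len records 5 for
# that utterance before batch padding.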


class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor"""

    def __init__(self, model_config):
        super(VariancePredictor, self).__init__()
        self.input_size = model_config["transformer"]["encoder_hidden"]
        self.filter_size = model_config["variance_predictor"]["filter_size"]
        self.kernel = model_config["variance_predictor"]["kernel_size"]
        self.conv_output_size = model_config["variance_predictor"]["filter_size"]
        self.dropout = model_config["variance_predictor"]["dropout"]

        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = out.masked_fill(mask, 0.0)

        return out
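

# Shape walkthrough (B = batch, T = sequence length, H = encoder_hidden):
# encoder_output is (B, T, H); each Conv block preserves (B, T, filter_size)
# because Conv transposes to channels-first internally and the padding keeps
# the sequence length; the linear layer plus squeeze(-1) then gives a (B, T)
# map of scalar predictions, with padded positions zeroed by the mask.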


class Conv(nn.Module):
    """Convolution Module"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: if True, a learnable bias is added to the output
        :param w_init: unused; kept for interface compatibility
        """
        super(Conv, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        # nn.Conv1d expects (batch, channels, length), so transpose on the way
        # in and out to keep a (batch, length, channels) interface.
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)
        return x
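

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module); the config values
    # below are illustrative stand-ins for a FastSpeech 2-style model.yaml.
    model_config = {
        "transformer": {"encoder_hidden": 256},
        "variance_predictor": {
            "filter_size": 256,
            "kernel_size": 3,
            "dropout": 0.5,
        },
    }
    predictor = VariancePredictor(model_config)
    x = torch.randn(2, 7, 256)  # (batch, phonemes, hidden)
    mask = torch.zeros(2, 7, dtype=torch.bool)  # no padded positions
    print(predictor(x, mask).shape)  # torch.Size([2, 7])

    regulator = LengthRegulator()
    duration = torch.tensor([[2, 1, 3, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1]])
    out, mel_len = regulator(x, duration, None)
    print(out.shape, mel_len)  # (2, 7, 256) after padding; lengths [6, 7]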