tools.py

import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
from scipy.io import wavfile
from matplotlib import pyplot as plt

# Use a non-interactive backend so figures can be rendered without a display.
matplotlib.use("Agg")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Convert a numpy batch to tensors and move them onto the target device."""
    if len(data) == 12:
        # Full training batch: includes mel, pitch, energy, and duration targets.
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)
        mels = torch.from_numpy(mels).float().to(device)
        mel_lens = torch.from_numpy(mel_lens).to(device)
        pitches = torch.from_numpy(pitches).float().to(device)
        energies = torch.from_numpy(energies).to(device)
        durations = torch.from_numpy(durations).long().to(device)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        )

    if len(data) == 6:
        # Inference batch: text-side inputs only.
        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)


def log(
    logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
):
    """Write scalar losses, a figure, and/or an audio clip to the summary writer."""
    if losses is not None:
        logger.add_scalar("Loss/total_loss", losses[0], step)
        logger.add_scalar("Loss/mel_loss", losses[1], step)
        logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
        logger.add_scalar("Loss/pitch_loss", losses[3], step)
        logger.add_scalar("Loss/energy_loss", losses[4], step)
        logger.add_scalar("Loss/duration_loss", losses[5], step)

    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is not None:
        logger.add_audio(
            tag,
            audio / max(abs(audio)),  # peak-normalize the waveform to [-1, 1]
            sample_rate=sampling_rate,
        )
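
# Example usage (assuming `logger` is a torch.utils.tensorboard.SummaryWriter,
# which provides the add_scalar/add_figure/add_audio methods used above):
#   writer = SummaryWriter("output/log")
#   log(writer, step=1000, losses=losses, tag="Training")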


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask (True at padded positions) from sequence lengths."""
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask
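
# Example: with lengths = torch.tensor([2, 4], device=device) and max_len = 4,
# get_mask_from_lengths returns
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])
# where True marks padded positions beyond each sequence's true length.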


def expand(values, durations):
    """Repeat each phoneme-level value by its duration to get frame-level values."""
    out = list()
    for value, d in zip(values, durations):
        out += [value] * max(0, int(d))
    return np.array(out)
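
# Example: expand(np.array([100.0, 200.0]), np.array([2, 3]))
# -> array([100., 100., 200., 200., 200.])
# Each phoneme-level value is repeated once per frame of its duration.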


def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
    """Plot and vocode a single sample for logging during training."""
    basename = targets[0][0]
    src_len = predictions[8][0].item()
    mel_len = predictions[9][0].item()
    mel_target = targets[6][0, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
    duration = targets[11][0, :src_len].detach().cpu().numpy()

    # Pitch/energy may be stored at phoneme level; expand them to frame level for plotting.
    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
        pitch = targets[9][0, :src_len].detach().cpu().numpy()
        pitch = expand(pitch, duration)
    else:
        pitch = targets[9][0, :mel_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
        energy = targets[10][0, :src_len].detach().cpu().numpy()
        energy = expand(energy, duration)
    else:
        energy = targets[10][0, :mel_len].detach().cpu().numpy()

    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]

    fig = plot_mel(
        [
            (mel_prediction.cpu().numpy(), pitch, energy),
            (mel_target.cpu().numpy(), pitch, energy),
        ],
        stats,
        ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
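
# Note: the integer indices into `targets` and `predictions` above assume a fixed
# tuple layout from the dataset and the model (e.g. predictions[1] = postnet mel,
# predictions[8] = source lengths, predictions[9] = mel lengths); if either tuple
# changes order, these indices must be updated to match.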


def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Plot spectrograms and write vocoded waveforms for a batch of predictions."""
    basenames = targets[0]
    for i in range(len(predictions[0])):
        basename = basenames[i]
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        duration = predictions[5][i, :src_len].detach().cpu().numpy()

        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            stats = stats["pitch"] + stats["energy"][:2]

        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthesized Spectrogram"],
        )
        plt.savefig(os.path.join(path, "{}.png".format(basename)))
        plt.close()

    from .model import vocoder_infer

    # Vocode the whole batch at once; lengths are in samples (frames * hop length).
    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)


def plot_mel(data, stats, titles):
    """Plot mel spectrograms with overlaid pitch (F0) and energy contours."""
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None for i in range(len(data))]
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    # De-normalize the pitch range for plotting.
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean

    def add_axis(fig, old_ax):
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for i in range(len(data)):
        mel, pitch, energy = data[i]
        pitch = pitch * pitch_std + pitch_mean
        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

        # Overlay the pitch contour on a transparent twin axis.
        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color="tomato")
        ax1.set_xlim(0, mel.shape[1])
        ax1.set_ylim(0, pitch_max)
        ax1.set_ylabel("F0", color="tomato")
        ax1.tick_params(
            labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
        )

        # Overlay the energy contour on a second transparent axis.
        ax2 = add_axis(fig, axes[i][0])
        ax2.plot(energy, color="darkviolet")
        ax2.set_xlim(0, mel.shape[1])
        ax2.set_ylim(energy_min, energy_max)
        ax2.set_ylabel("Energy", color="darkviolet")
        ax2.yaxis.set_label_position("right")
        ax2.tick_params(
            labelsize="x-small",
            colors="darkviolet",
            bottom=False,
            labelbottom=False,
            left=False,
            labelleft=False,
            right=True,
            labelright=True,
        )

    return fig
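
# Note: `stats` is expected to be the 6-element list
#   [pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max]
# assembled from stats.json by synth_one_sample/synth_samples; pitch is stored
# normalized, so it is de-normalized here before plotting.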


def pad_1D(inputs, PAD=0):
    """Right-pad a list of 1-D arrays to the length of the longest one and stack them."""
    def pad_data(x, length, PAD):
        x_padded = np.pad(
            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
        )
        return x_padded

    max_len = max((len(x) for x in inputs))
    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])

    return padded
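
# Example: pad_1D([np.array([1, 2, 3]), np.array([4])])
# -> array([[1, 2, 3],
#           [4, 0, 0]])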


def pad_2D(inputs, maxlen=None):
    """Pad a list of 2-D arrays along the first (time) axis and stack them."""
    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")

        s = np.shape(x)[1]
        # np.pad with a single (before, after) pair pads every axis; the extra
        # columns are trimmed back to the original width below.
        x_padded = np.pad(
            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
        )
        return x_padded[:, :s]

    if maxlen:
        output = np.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
        output = np.stack([pad(x, max_len) for x in inputs])

    return output
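
# Example: pad_2D([a, b]) for arrays of shape (3, 80) and (5, 80) returns a
# stacked array of shape (2, 5, 80), with the shorter array zero-padded along
# its first (time) axis.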


def pad(input_ele, mel_max_length=None):
    """Pad a list of 1-D or 2-D tensors along the first dimension and stack them."""
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])

    out_list = list()
    for i, batch in enumerate(input_ele):
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
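
# Example: pad([torch.zeros(50, 80), torch.zeros(70, 80)], mel_max_length=100)
# returns a tensor of shape (2, 100, 80), zero-padded along the first dimension.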