fastspeech2.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. import json
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from transformer import Encoder, Decoder, PostNet
  7. from .modules import VarianceAdaptor
  8. from utils.tools import get_mask_from_lengths
  9. class FastSpeech2(nn.Module):
  10. """ FastSpeech2 """
  11. def __init__(self, preprocess_config, model_config):
  12. super(FastSpeech2, self).__init__()
  13. self.model_config = model_config
  14. self.encoder = Encoder(model_config)
  15. self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
  16. self.decoder = Decoder(model_config)
  17. self.mel_linear = nn.Linear(
  18. model_config["transformer"]["decoder_hidden"],
  19. preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
  20. )
  21. self.postnet = PostNet()
  22. self.speaker_emb = None
  23. if model_config["multi_speaker"]:
  24. with open(
  25. os.path.join(
  26. preprocess_config["path"]["preprocessed_path"], "speakers.json"
  27. ),
  28. "r",
  29. ) as f:
  30. n_speaker = len(json.load(f))
  31. self.speaker_emb = nn.Embedding(
  32. n_speaker,
  33. model_config["transformer"]["encoder_hidden"],
  34. )
  35. def forward(
  36. self,
  37. speakers,
  38. texts,
  39. src_lens,
  40. max_src_len,
  41. mels=None,
  42. mel_lens=None,
  43. max_mel_len=None,
  44. p_targets=None,
  45. e_targets=None,
  46. d_targets=None,
  47. p_control=1.0,
  48. e_control=1.0,
  49. d_control=1.0,
  50. ):
  51. src_masks = get_mask_from_lengths(src_lens, max_src_len)
  52. mel_masks = (
  53. get_mask_from_lengths(mel_lens, max_mel_len)
  54. if mel_lens is not None
  55. else None
  56. )
  57. output = self.encoder(texts, src_masks)
  58. if self.speaker_emb is not None:
  59. output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
  60. -1, max_src_len, -1
  61. )
  62. (
  63. output,
  64. p_predictions,
  65. e_predictions,
  66. log_d_predictions,
  67. d_rounded,
  68. mel_lens,
  69. mel_masks,
  70. ) = self.variance_adaptor(
  71. output,
  72. src_masks,
  73. mel_masks,
  74. max_mel_len,
  75. p_targets,
  76. e_targets,
  77. d_targets,
  78. p_control,
  79. e_control,
  80. d_control,
  81. )
  82. output, mel_masks = self.decoder(output, mel_masks)
  83. output = self.mel_linear(output)
  84. postnet_output = self.postnet(output) + output
  85. return (
  86. output,
  87. postnet_output,
  88. p_predictions,
  89. e_predictions,
  90. log_d_predictions,
  91. d_rounded,
  92. src_masks,
  93. mel_masks,
  94. src_lens,
  95. mel_lens,
  96. )