loss.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import torch
  2. import torch.nn as nn
  3. class FastSpeech2Loss(nn.Module):
  4. """ FastSpeech2 Loss """
  5. def __init__(self, preprocess_config, model_config):
  6. super(FastSpeech2Loss, self).__init__()
  7. self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
  8. "feature"
  9. ]
  10. self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
  11. "feature"
  12. ]
  13. self.mse_loss = nn.MSELoss()
  14. self.mae_loss = nn.L1Loss()
  15. def forward(self, inputs, predictions):
  16. (
  17. mel_targets,
  18. _,
  19. _,
  20. pitch_targets,
  21. energy_targets,
  22. duration_targets,
  23. ) = inputs[6:]
  24. (
  25. mel_predictions,
  26. postnet_mel_predictions,
  27. pitch_predictions,
  28. energy_predictions,
  29. log_duration_predictions,
  30. _,
  31. src_masks,
  32. mel_masks,
  33. _,
  34. _,
  35. ) = predictions
  36. src_masks = ~src_masks
  37. mel_masks = ~mel_masks
  38. log_duration_targets = torch.log(duration_targets.float() + 1)
  39. mel_targets = mel_targets[:, : mel_masks.shape[1], :]
  40. mel_masks = mel_masks[:, :mel_masks.shape[1]]
  41. log_duration_targets.requires_grad = False
  42. pitch_targets.requires_grad = False
  43. energy_targets.requires_grad = False
  44. mel_targets.requires_grad = False
  45. if self.pitch_feature_level == "phoneme_level":
  46. pitch_predictions = pitch_predictions.masked_select(src_masks)
  47. pitch_targets = pitch_targets.masked_select(src_masks)
  48. elif self.pitch_feature_level == "frame_level":
  49. pitch_predictions = pitch_predictions.masked_select(mel_masks)
  50. pitch_targets = pitch_targets.masked_select(mel_masks)
  51. if self.energy_feature_level == "phoneme_level":
  52. energy_predictions = energy_predictions.masked_select(src_masks)
  53. energy_targets = energy_targets.masked_select(src_masks)
  54. if self.energy_feature_level == "frame_level":
  55. energy_predictions = energy_predictions.masked_select(mel_masks)
  56. energy_targets = energy_targets.masked_select(mel_masks)
  57. log_duration_predictions = log_duration_predictions.masked_select(src_masks)
  58. log_duration_targets = log_duration_targets.masked_select(src_masks)
  59. mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
  60. postnet_mel_predictions = postnet_mel_predictions.masked_select(
  61. mel_masks.unsqueeze(-1)
  62. )
  63. mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
  64. mel_loss = self.mae_loss(mel_predictions, mel_targets)
  65. postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
  66. pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
  67. energy_loss = self.mse_loss(energy_predictions, energy_targets)
  68. duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
  69. total_loss = (
  70. mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
  71. )
  72. return (
  73. total_loss,
  74. mel_loss,
  75. postnet_mel_loss,
  76. pitch_loss,
  77. energy_loss,
  78. duration_loss,
  79. )