# evaluate.py

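"""Standalone validation for FastSpeech2.

Runs the model over the validation split ("val.txt"), averages the six
training losses (total, mel, mel postnet, pitch, energy, duration), and,
when a logger is supplied, logs a synthesized sample alongside the losses.
"""
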
import argparse

import torch
import yaml
from torch.utils.data import DataLoader

from utils.model import get_model
from utils.tools import to_device, log, synth_one_sample
from model import FastSpeech2Loss
from dataset import Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def evaluate(model, step, configs, logger=None, vocoder=None):
    preprocess_config, model_config, train_config = configs

    # Build the validation set; sort=False and shuffle=False keep the
    # evaluation order deterministic across runs
    dataset = Dataset(
        "val.txt", preprocess_config, train_config, sort=False, drop_last=False
    )
    batch_size = train_config["optimizer"]["batch_size"]
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    # Get loss function
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)

    # Evaluation: running sums of the six loss terms (total, mel,
    # mel postnet, pitch, energy, duration)
    loss_sums = [0 for _ in range(6)]
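    # Dataset.collate_fn pre-groups samples, so each loader step yields a
    # list of batches rather than a single batch, hence the nested loop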
    for batches in loader:
        for batch in batches:
            batch = to_device(batch, device)
            with torch.no_grad():
                # Forward pass; the first two batch entries (ids, raw texts)
                # are not model inputs
                output = model(*(batch[2:]))

                # Accumulate each loss term, weighted by the batch size so
                # that a short final batch does not skew the mean
                losses = Loss(batch, output)
                for i in range(len(losses)):
                    loss_sums[i] += losses[i].item() * len(batch[0])

    # Per-sample means over the whole validation set
    loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]

    message = (
        "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, "
        "Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, "
        "Duration Loss: {:.4f}".format(step, *loss_means)
    )

    if logger is not None:
        # Synthesize one sample (from the last batch above) for qualitative
        # monitoring: ground-truth reconstruction vs. model prediction
        fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
            batch,
            output,
            vocoder,
            model_config,
            preprocess_config,
        )

        log(logger, step, losses=loss_means)
        log(
            logger,
            fig=fig,
            tag="Validation/step_{}_{}".format(step, tag),
        )
        sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
        log(
            logger,
            audio=wav_reconstruction,
            sampling_rate=sampling_rate,
            tag="Validation/step_{}_{}_reconstructed".format(step, tag),
        )
        log(
            logger,
            audio=wav_prediction,
            sampling_rate=sampling_rate,
            tag="Validation/step_{}_{}_synthesized".format(step, tag),
        )

    return message
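

# When imported, evaluate() is typically driven from the training loop, e.g.
# (argument names illustrative):
#
#     message = evaluate(model, step, configs, val_logger, vocoder)
#
# The entry point below instead evaluates a restored checkpoint standalone.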

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=30000)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()

    # Read configs
    with open(args.preprocess_config, "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)
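
    # Get model: get_model is expected to restore the checkpoint matching
    # --restore_step (see utils.model.get_model for the restore logic)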
    model = get_model(args, configs, device, train=False).to(device)

    message = evaluate(model, args.restore_step, configs)
    print(message)
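
# Example invocation (config paths are illustrative):
#   python evaluate.py --restore_step 30000 \
#       -p config/LJSpeech/preprocess.yaml \
#       -m config/LJSpeech/model.yaml \
#       -t config/LJSpeech/train.yaml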