train.py

import argparse
import os

import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num
from utils.tools import to_device, log, synth_one_sample
from model import FastSpeech2Loss
from dataset import Dataset
from evaluate import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main(args, configs):
    print("Prepare training ...")
    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    )
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )
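    # Each DataLoader batch holds batch_size * group_size samples; the
    # dataset's collate_fn splits it into group_size sub-batches
    # (length-sorted when sort=True), hence the nested
    # "for batch in batchs" loop below.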

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True)
    model = nn.DataParallel(model)
    num_param = get_param_num(model)
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
    print("Number of FastSpeech2 Parameters:", num_param)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)
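    # The vocoder is only used to render audible samples for logging and
    # evaluation; it is not part of the training graph.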

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)
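    # Separate writers for train/ and val/ let training and validation
    # metrics be compared side by side in TensorBoard.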

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()
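
    # Training is bounded by total_step rather than an epoch count:
    # loop indefinitely and exit once the step counter reaches total_step.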
    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

                # Forward
                output = model(*(batch[2:]))

                # Calculate loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
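                # Gradient accumulation: the loss above is pre-scaled by
                # 1 / grad_acc_step, and weights are updated only once
                # every grad_acc_step mini-batches.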
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        *losses
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"][
                        "sampling_rate"
                    ]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()
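
                # Checkpoints store the bare module (unwrapped from
                # DataParallel) plus the inner optimizer state, so
                # training can be resumed via --restore_step.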
                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)
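
# Example invocation (config paths are hypothetical; adjust to your setup):
#   python train.py -p config/LJSpeech/preprocess.yaml \
#       -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml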