import argparse
import os

import torch
import torch.nn as nn
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
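
# The imports below are project-local modules from the FastSpeech2 repo
# (utils, model, dataset, evaluate), not pip-installable packages.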
from utils.model import get_model, get_vocoder, get_param_num
from utils.tools import to_device, log, synth_one_sample
from model import FastSpeech2Loss
from dataset import Dataset
from evaluate import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args, configs):
    print("Prepare training ...")

    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    )
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )
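
    # Each loader item is a "group" of batch_size * group_size samples;
    # collate_fn is expected to sort them by length and split the group into
    # group_size sub-batches, which is what the inner loop below iterates over.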

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True)
    model = nn.DataParallel(model)
    num_param = get_param_num(model)
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
    print("Number of FastSpeech2 Parameters:", num_param)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)
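
    # Separate writers for train/ and val/ let TensorBoard overlay both runs
    # under the same scalar tags.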

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()
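
    # The outer bar tracks global steps across epochs and is fast-forwarded to
    # restore_step when resuming; the inner bar tracks progress within an epoch.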

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batches in loader:
            for batch in batches:
                batch = to_device(batch, device)

                # Forward pass (the first two batch entries are utterance ids
                # and raw text, which the model does not consume)
                output = model(*(batch[2:]))

                # Calculate loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
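
                # Gradient accumulation: the loss above is pre-divided by
                # grad_acc_step, so stepping every grad_acc_step iterations
                # averages gradients over the effective batch.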
                if step % grad_acc_step == 0:
                    # Clip gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights; step_and_update_lr also advances the
                    # learning-rate schedule wrapped around the optimizer
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = (
                        "Total Loss: {:.4f}, Mel Loss: {:.4f}, "
                        "Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, "
                        "Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(*losses)
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)
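
                # Periodically vocode one training sample both ways: the
                # ground-truth mel ("reconstructed") and the predicted mel
                # ("synthesized"), for A/B listening in TensorBoard.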
                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"][
                        "sampling_rate"
                    ]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()
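
                # Checkpoints store the bare weights: model.module unwraps the
                # DataParallel wrapper, and optimizer._optimizer is assumed to
                # be the raw torch optimizer inside the LR-schedule wrapper, so
                # a restored run can rebuild both wrappers itself.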
                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)
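
# Example invocation (config paths are illustrative and depend on your setup):
#   python train.py -p config/LJSpeech/preprocess.yaml \
#                   -m config/LJSpeech/model.yaml \
#                   -t config/LJSpeech/train.yaml
# Resume from a saved checkpoint with --restore_step <step>.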