import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use("Agg")  # select a non-interactive backend before pyplot is imported
from scipy.io import wavfile
from matplotlib import pyplot as plt


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    # Full training batch: text inputs plus mel, pitch, energy and duration targets.
    if len(data) == 12:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)
        mels = torch.from_numpy(mels).float().to(device)
        mel_lens = torch.from_numpy(mel_lens).to(device)
        pitches = torch.from_numpy(pitches).float().to(device)
        energies = torch.from_numpy(energies).to(device)
        durations = torch.from_numpy(durations).long().to(device)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        )

    # Synthesis batch: only the text-side inputs are available.
    if len(data) == 6:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)


def log(
    logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
):
    if losses is not None:
        logger.add_scalar("Loss/total_loss", losses[0], step)
        logger.add_scalar("Loss/mel_loss", losses[1], step)
        logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
        logger.add_scalar("Loss/pitch_loss", losses[3], step)
        logger.add_scalar("Loss/energy_loss", losses[4], step)
        logger.add_scalar("Loss/duration_loss", losses[5], step)

    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is not None:
        logger.add_audio(
            tag,
            audio / max(abs(audio)),
            sample_rate=sampling_rate,
        )


def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask
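
# Illustrative sketch, not part of the original module: with
# lengths = torch.tensor([2, 3]) and max_len = 4, get_mask_from_lengths returns
#     [[False, False, True,  True],
#      [False, False, False, True]]
# where True marks padded time steps that downstream modules should ignore.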


def expand(values, durations):
    out = list()
    for value, d in zip(values, durations):
        out += [value] * max(0, int(d))
    return np.array(out)
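
# Illustrative sketch, not part of the original module:
# expand(np.array([5.0, 7.0]), np.array([2, 3])) repeats each phoneme-level
# value by its duration, giving array([5., 5., 7., 7., 7.]) at frame level.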


def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
    basename = targets[0][0]
    src_len = predictions[8][0].item()
    mel_len = predictions[9][0].item()
    mel_target = targets[6][0, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
    duration = targets[11][0, :src_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
        pitch = targets[9][0, :src_len].detach().cpu().numpy()
        pitch = expand(pitch, duration)
    else:
        pitch = targets[9][0, :mel_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
        energy = targets[10][0, :src_len].detach().cpu().numpy()
        energy = expand(energy, duration)
    else:
        energy = targets[10][0, :mel_len].detach().cpu().numpy()

    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]

    fig = plot_mel(
        [
            (mel_prediction.cpu().numpy(), pitch, energy),
            (mel_target.cpu().numpy(), pitch, energy),
        ],
        stats,
        ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename


def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    basenames = targets[0]
    for i in range(len(predictions[0])):
        basename = basenames[i]
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        duration = predictions[5][i, :src_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            stats = stats["pitch"] + stats["energy"][:2]

        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthesized Spectrogram"],
        )
        plt.savefig(os.path.join(path, "{}.png".format(basename)))
        plt.close()

    from .model import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)


def plot_mel(data, stats, titles):
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None for i in range(len(data))]
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean

    def add_axis(fig, old_ax):
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for i in range(len(data)):
        mel, pitch, energy = data[i]
        pitch = pitch * pitch_std + pitch_mean
        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color="tomato")
        ax1.set_xlim(0, mel.shape[1])
        ax1.set_ylim(0, pitch_max)
        ax1.set_ylabel("F0", color="tomato")
        ax1.tick_params(
            labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
        )

        ax2 = add_axis(fig, axes[i][0])
        ax2.plot(energy, color="darkviolet")
        ax2.set_xlim(0, mel.shape[1])
        ax2.set_ylim(energy_min, energy_max)
        ax2.set_ylabel("Energy", color="darkviolet")
        ax2.yaxis.set_label_position("right")
        ax2.tick_params(
            labelsize="x-small",
            colors="darkviolet",
            bottom=False,
            labelbottom=False,
            left=False,
            labelleft=False,
            right=True,
            labelright=True,
        )

    return fig


def pad_1D(inputs, PAD=0):
    def pad_data(x, length, PAD):
        x_padded = np.pad(
            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
        )
        return x_padded

    max_len = max((len(x) for x in inputs))
    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])

    return padded
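
# Illustrative sketch, not part of the original module:
# pad_1D([np.array([1, 2, 3]), np.array([4])]) right-pads with zeros and
# stacks to array([[1, 2, 3], [4, 0, 0]]).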


def pad_2D(inputs, maxlen=None):
    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")

        s = np.shape(x)[1]
        # np.pad with a single (before, after) tuple pads both axes; the extra
        # columns are trimmed back to the original width s below.
        x_padded = np.pad(
            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
        )
        return x_padded[:, :s]

    if maxlen:
        output = np.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
        output = np.stack([pad(x, max_len) for x in inputs])

    return output
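
# Illustrative sketch, not part of the original module:
# pad_2D([np.ones((2, 5)), np.ones((3, 5))]) pads each array along the first
# (time) axis to the batch maximum and stacks them into shape (2, 3, 5).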


def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])

    out_list = list()
    for i, batch in enumerate(input_ele):
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
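
# Illustrative sketch, not part of the original module:
# pad([torch.ones(2, 4), torch.ones(3, 4)]) zero-pads each tensor along its
# first dimension to length 3 and stacks them into a (2, 3, 4) tensor.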