synthesize.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import re
  2. import argparse
  3. from string import punctuation
  4. import torch
  5. import yaml
  6. import numpy as np
  7. from torch.utils.data import DataLoader
  8. from g2p_en import G2p
  9. from pypinyin import pinyin, Style
  10. from utils.model import get_model, get_vocoder
  11. from utils.tools import to_device, synth_samples
  12. from dataset import TextDataset
  13. from text import text_to_sequence
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. def read_lexicon(lex_path):
  16. lexicon = {}
  17. with open(lex_path) as f:
  18. for line in f:
  19. temp = re.split(r"\s+", line.strip("\n"))
  20. word = temp[0]
  21. phones = temp[1:]
  22. if word.lower() not in lexicon:
  23. lexicon[word.lower()] = phones
  24. return lexicon
  25. def preprocess_english(text, preprocess_config):
  26. text = text.rstrip(punctuation)
  27. lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
  28. g2p = G2p()
  29. phones = []
  30. words = re.split(r"([,;.\-\?\!\s+])", text)
  31. for w in words:
  32. if w.lower() in lexicon:
  33. phones += lexicon[w.lower()]
  34. else:
  35. phones += list(filter(lambda p: p != " ", g2p(w)))
  36. phones = "{" + "}{".join(phones) + "}"
  37. phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
  38. phones = phones.replace("}{", " ")
  39. print("Raw Text Sequence: {}".format(text))
  40. print("Phoneme Sequence: {}".format(phones))
  41. sequence = np.array(
  42. text_to_sequence(
  43. phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
  44. )
  45. )
  46. return np.array(sequence)
  47. def preprocess_mandarin(text, preprocess_config):
  48. lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
  49. phones = []
  50. pinyins = [
  51. p[0]
  52. for p in pinyin(
  53. text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
  54. )
  55. ]
  56. for p in pinyins:
  57. if p in lexicon:
  58. phones += lexicon[p]
  59. else:
  60. phones.append("sp")
  61. phones = "{" + " ".join(phones) + "}"
  62. print("Raw Text Sequence: {}".format(text))
  63. print("Phoneme Sequence: {}".format(phones))
  64. sequence = np.array(
  65. text_to_sequence(
  66. phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
  67. )
  68. )
  69. return np.array(sequence)
  70. def synthesize(model, step, configs, vocoder, batchs, control_values):
  71. preprocess_config, model_config, train_config = configs
  72. pitch_control, energy_control, duration_control = control_values
  73. for batch in batchs:
  74. batch = to_device(batch, device)
  75. with torch.no_grad():
  76. # Forward
  77. output = model(
  78. *(batch[2:]),
  79. p_control=pitch_control,
  80. e_control=energy_control,
  81. d_control=duration_control
  82. )
  83. synth_samples(
  84. batch,
  85. output,
  86. vocoder,
  87. model_config,
  88. preprocess_config,
  89. train_config["path"]["result_path"],
  90. )
  91. if __name__ == "__main__":
  92. parser = argparse.ArgumentParser()
  93. parser.add_argument("--restore_step", type=int, required=True)
  94. parser.add_argument(
  95. "--mode",
  96. type=str,
  97. choices=["batch", "single"],
  98. required=True,
  99. help="Synthesize a whole dataset or a single sentence",
  100. )
  101. parser.add_argument(
  102. "--source",
  103. type=str,
  104. default=None,
  105. help="path to a source file with format like train.txt and val.txt, for batch mode only",
  106. )
  107. parser.add_argument(
  108. "--text",
  109. type=str,
  110. default=None,
  111. help="raw text to synthesize, for single-sentence mode only",
  112. )
  113. parser.add_argument(
  114. "--speaker_id",
  115. type=int,
  116. default=0,
  117. help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
  118. )
  119. parser.add_argument(
  120. "-p",
  121. "--preprocess_config",
  122. type=str,
  123. required=True,
  124. help="path to preprocess.yaml",
  125. )
  126. parser.add_argument(
  127. "-m", "--model_config", type=str, required=True, help="path to model.yaml"
  128. )
  129. parser.add_argument(
  130. "-t", "--train_config", type=str, required=True, help="path to train.yaml"
  131. )
  132. parser.add_argument(
  133. "--pitch_control",
  134. type=float,
  135. default=1.0,
  136. help="control the pitch of the whole utterance, larger value for higher pitch",
  137. )
  138. parser.add_argument(
  139. "--energy_control",
  140. type=float,
  141. default=1.0,
  142. help="control the energy of the whole utterance, larger value for larger volume",
  143. )
  144. parser.add_argument(
  145. "--duration_control",
  146. type=float,
  147. default=1.0,
  148. help="control the speed of the whole utterance, larger value for slower speaking rate",
  149. )
  150. args = parser.parse_args()
  151. # Check source texts
  152. if args.mode == "batch":
  153. assert args.source is not None and args.text is None
  154. if args.mode == "single":
  155. assert args.source is None and args.text is not None
  156. # Read Config
  157. preprocess_config = yaml.load(
  158. open(args.preprocess_config, "r"), Loader=yaml.FullLoader
  159. )
  160. model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
  161. train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
  162. configs = (preprocess_config, model_config, train_config)
  163. # Get model
  164. model = get_model(args, configs, device, train=False)
  165. # Load vocoder
  166. vocoder = get_vocoder(model_config, device)
  167. # Preprocess texts
  168. if args.mode == "batch":
  169. # Get dataset
  170. dataset = TextDataset(args.source, preprocess_config)
  171. batchs = DataLoader(
  172. dataset,
  173. batch_size=8,
  174. collate_fn=dataset.collate_fn,
  175. )
  176. if args.mode == "single":
  177. ids = raw_texts = [args.text[:100]]
  178. speakers = np.array([args.speaker_id])
  179. if preprocess_config["preprocessing"]["text"]["language"] == "en":
  180. texts = np.array([preprocess_english(args.text, preprocess_config)])
  181. elif preprocess_config["preprocessing"]["text"]["language"] == "zh":
  182. texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
  183. text_lens = np.array([len(texts[0])])
  184. batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]
  185. control_values = args.pitch_control, args.energy_control, args.duration_control
  186. synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)