preprocessor.py

import os
import random
import json

import tgt
import librosa
import numpy as np
import pyworld as pw
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

import audio as Audio
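
# Note: "audio" is assumed to be the repository-local audio package (providing
# audio.stft.TacotronSTFT and audio.tools.get_mel_from_wav, as used below),
# not a PyPI distribution of the same name.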

class Preprocessor:
    def __init__(self, config):
        self.config = config
        self.in_dir = config["path"]["raw_path"]
        self.out_dir = config["path"]["preprocessed_path"]
        self.val_size = config["preprocessing"]["val_size"]
        self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
        self.hop_length = config["preprocessing"]["stft"]["hop_length"]

        assert config["preprocessing"]["pitch"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        assert config["preprocessing"]["energy"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        self.pitch_phoneme_averaging = (
            config["preprocessing"]["pitch"]["feature"] == "phoneme_level"
        )
        self.energy_phoneme_averaging = (
            config["preprocessing"]["energy"]["feature"] == "phoneme_level"
        )

        self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
        self.energy_normalization = config["preprocessing"]["energy"]["normalization"]

        self.STFT = Audio.stft.TacotronSTFT(
            config["preprocessing"]["stft"]["filter_length"],
            config["preprocessing"]["stft"]["hop_length"],
            config["preprocessing"]["stft"]["win_length"],
            config["preprocessing"]["mel"]["n_mel_channels"],
            config["preprocessing"]["audio"]["sampling_rate"],
            config["preprocessing"]["mel"]["mel_fmin"],
            config["preprocessing"]["mel"]["mel_fmax"],
        )
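
    # The config keys consumed above suggest the following layout; the values
    # shown are a sketch (illustrative, not read from this file):
    #
    #   path:
    #     raw_path: ./raw_data/LJSpeech
    #     preprocessed_path: ./preprocessed_data/LJSpeech
    #   preprocessing:
    #     val_size: 512
    #     audio: {sampling_rate: 22050}
    #     stft: {filter_length: 1024, hop_length: 256, win_length: 1024}
    #     mel: {n_mel_channels: 80, mel_fmin: 0, mel_fmax: 8000}
    #     pitch: {feature: phoneme_level, normalization: true}
    #     energy: {feature: phoneme_level, normalization: true}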

    def build_from_path(self):
        os.makedirs(os.path.join(self.out_dir, "mel"), exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, "pitch"), exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, "energy"), exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, "duration"), exist_ok=True)

        print("Processing Data ...")
        out = list()
        n_frames = 0
        pitch_scaler = StandardScaler()
        energy_scaler = StandardScaler()

        # Compute pitch, energy, duration, and mel-spectrogram for every
        # utterance that has a TextGrid alignment
        speakers = {}
        for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
            speakers[speaker] = i
            for wav_name in os.listdir(os.path.join(self.in_dir, speaker)):
                if ".wav" not in wav_name:
                    continue

                basename = wav_name.split(".")[0]
                tg_path = os.path.join(
                    self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
                )
                if os.path.exists(tg_path):
                    ret = self.process_utterance(speaker, basename)
                    if ret is None:
                        continue
                    info, pitch, energy, n = ret
                    out.append(info)

                    # Accumulate statistics incrementally so the whole corpus
                    # never has to be held in memory
                    if len(pitch) > 0:
                        pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
                    if len(energy) > 0:
                        energy_scaler.partial_fit(energy.reshape((-1, 1)))

                    n_frames += n

        print("Computing statistics ...")
        # Perform normalization if necessary
        if self.pitch_normalization:
            pitch_mean = pitch_scaler.mean_[0]
            pitch_std = pitch_scaler.scale_[0]
        else:
            # A numerical trick to skip normalization: z-scoring with
            # mean 0 and std 1 leaves the values unchanged
            pitch_mean = 0
            pitch_std = 1
        if self.energy_normalization:
            energy_mean = energy_scaler.mean_[0]
            energy_std = energy_scaler.scale_[0]
        else:
            energy_mean = 0
            energy_std = 1

        pitch_min, pitch_max = self.normalize(
            os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
        )
        energy_min, energy_max = self.normalize(
            os.path.join(self.out_dir, "energy"), energy_mean, energy_std
        )

        # Save the speaker map and normalization statistics
        with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
            f.write(json.dumps(speakers))

        with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
            stats = {
                "pitch": [
                    float(pitch_min),
                    float(pitch_max),
                    float(pitch_mean),
                    float(pitch_std),
                ],
                "energy": [
                    float(energy_min),
                    float(energy_max),
                    float(energy_mean),
                    float(energy_std),
                ],
            }
            f.write(json.dumps(stats))

        print(
            "Total time: {} hours".format(
                n_frames * self.hop_length / self.sampling_rate / 3600
            )
        )

        random.shuffle(out)
        out = [r for r in out if r is not None]

        # Write metadata
        with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
            for m in out[self.val_size :]:
                f.write(m + "\n")
        with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
            for m in out[: self.val_size]:
                f.write(m + "\n")

        return out
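
    # Each metadata line written above has the pipe-separated form
    #   basename|speaker|{phoneme sequence}|raw text
    # e.g. (an illustrative example, not taken from a real corpus):
    #   LJ001-0001|LJSpeech|{P R IH1 N T IH0 NG}|printing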

    def process_utterance(self, speaker, basename):
        wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))
        text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))
        tg_path = os.path.join(
            self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
        )

        # Get alignments
        textgrid = tgt.io.read_textgrid(tg_path)
        phone, duration, start, end = self.get_alignment(
            textgrid.get_tier_by_name("phones")
        )
        text = "{" + " ".join(phone) + "}"
        if start >= end:
            return None

        # Read and trim wav files; load at the configured sampling rate so the
        # sample indices below and the F0/STFT analyses stay consistent
        wav, _ = librosa.load(wav_path, sr=self.sampling_rate)
        wav = wav[
            int(self.sampling_rate * start) : int(self.sampling_rate * end)
        ].astype(np.float32)

        # Read raw text
        with open(text_path, "r") as f:
            raw_text = f.readline().strip("\n")

        # Compute fundamental frequency with pyworld (DIO, refined by StoneMask)
        pitch, t = pw.dio(
            wav.astype(np.float64),
            self.sampling_rate,
            frame_period=self.hop_length / self.sampling_rate * 1000,
        )
        pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)

        pitch = pitch[: sum(duration)]
        if np.sum(pitch != 0) <= 1:
            return None

        # Compute mel-scale spectrogram and energy
        mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT)
        mel_spectrogram = mel_spectrogram[:, : sum(duration)]
        energy = energy[: sum(duration)]

        if self.pitch_phoneme_averaging:
            # Perform linear interpolation over unvoiced (zero-F0) frames
            nonzero_ids = np.where(pitch != 0)[0]
            interp_fn = interp1d(
                nonzero_ids,
                pitch[nonzero_ids],
                fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
                bounds_error=False,
            )
            pitch = interp_fn(np.arange(0, len(pitch)))

            # Phoneme-level average
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    pitch[i] = np.mean(pitch[pos : pos + d])
                else:
                    pitch[i] = 0
                pos += d
            pitch = pitch[: len(duration)]

        if self.energy_phoneme_averaging:
            # Phoneme-level average
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    energy[i] = np.mean(energy[pos : pos + d])
                else:
                    energy[i] = 0
                pos += d
            energy = energy[: len(duration)]

        # Save files
        dur_filename = "{}-duration-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)

        pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)

        energy_filename = "{}-energy-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)

        mel_filename = "{}-mel-{}.npy".format(speaker, basename)
        np.save(
            os.path.join(self.out_dir, "mel", mel_filename),
            mel_spectrogram.T,
        )

        return (
            "|".join([basename, speaker, text, raw_text]),
            self.remove_outlier(pitch),
            self.remove_outlier(energy),
            mel_spectrogram.shape[1],
        )
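
    # Note on the frame period above: hop_length / sampling_rate * 1000 is one
    # STFT hop expressed in milliseconds, so pyworld produces one F0 value per
    # spectrogram frame. With a typical 22050 Hz rate and 256-sample hop
    # (illustrative values, not read from this file):
    #   256 / 22050 * 1000 ≈ 11.61 ms per frame.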

    def get_alignment(self, tier):
        sil_phones = ["sil", "sp", "spn"]

        phones = []
        durations = []
        start_time = 0
        end_time = 0
        end_idx = 0
        for t in tier._objects:
            s, e, p = t.start_time, t.end_time, t.text

            # Trim leading silences
            if phones == []:
                if p in sil_phones:
                    continue
                else:
                    start_time = s

            if p not in sil_phones:
                # For ordinary phones
                phones.append(p)
                end_time = e
                end_idx = len(phones)
            else:
                # For silent phones
                phones.append(p)

            durations.append(
                int(
                    np.round(e * self.sampling_rate / self.hop_length)
                    - np.round(s * self.sampling_rate / self.hop_length)
                )
            )

        # Trim trailing silences
        phones = phones[:end_idx]
        durations = durations[:end_idx]

        return phones, durations, start_time, end_time
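
    # Duration arithmetic above, with illustrative numbers (sr=22050, hop=256;
    # assumptions for the example, not read from this file): a phone spanning
    # 0.10 s to 0.25 s maps to
    #   round(0.25 * 22050 / 256) - round(0.10 * 22050 / 256) = 22 - 9 = 13 frames.
    # Rounding each boundary independently makes the durations telescope, so
    # their sum equals round(end) - round(start) of the trimmed utterance and
    # stays consistent with the mel-spectrogram trimming in process_utterance.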

    def remove_outlier(self, values):
        # Drop values outside Tukey's fences: more than 1.5 IQRs beyond the
        # 25th/75th percentiles
        values = np.array(values)
        p25 = np.percentile(values, 25)
        p75 = np.percentile(values, 75)
        lower = p25 - 1.5 * (p75 - p25)
        upper = p75 + 1.5 * (p75 - p25)
        normal_indices = np.logical_and(values > lower, values < upper)

        return values[normal_indices]
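
    # For example (illustrative numbers): with p25 = 100 and p75 = 200, the
    # IQR is 100, so values outside (-50, 350) are discarded before the
    # normalization statistics are fit.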

    def normalize(self, in_dir, mean, std):
        # Z-score every saved .npy file in place and track the global min/max
        # of the normalized values
        max_value = np.finfo(np.float64).min
        min_value = np.finfo(np.float64).max
        for filename in os.listdir(in_dir):
            filename = os.path.join(in_dir, filename)
            values = (np.load(filename) - mean) / std
            np.save(filename, values)

            max_value = max(max_value, max(values))
            min_value = min(min_value, min(values))

        return min_value, max_value
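

# A minimal driver sketch (an assumption about how this class is invoked, not
# part of the original file): load a YAML config whose keys match those
# consumed in __init__ and run the full pipeline.
if __name__ == "__main__":
    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to a preprocess .yaml config")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    Preprocessor(config).build_from_path()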