# dataset.py
import json
import math
import os

import numpy as np
from torch.utils.data import Dataset

from text import text_to_sequence
from utils.tools import pad_1D, pad_2D
  8. class Dataset(Dataset):
  9. def __init__(
  10. self, filename, preprocess_config, train_config, sort=False, drop_last=False
  11. ):
  12. self.dataset_name = preprocess_config["dataset"]
  13. self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
  14. self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
  15. self.batch_size = train_config["optimizer"]["batch_size"]
  16. self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
  17. filename
  18. )
  19. with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
  20. self.speaker_map = json.load(f)
  21. self.sort = sort
  22. self.drop_last = drop_last
  23. def __len__(self):
  24. return len(self.text)
  25. def __getitem__(self, idx):
  26. basename = self.basename[idx]
  27. speaker = self.speaker[idx]
  28. speaker_id = self.speaker_map[speaker]
  29. raw_text = self.raw_text[idx]
  30. phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
  31. mel_path = os.path.join(
  32. self.preprocessed_path,
  33. "mel",
  34. "{}-mel-{}.npy".format(speaker, basename),
  35. )
  36. mel = np.load(mel_path)
  37. pitch_path = os.path.join(
  38. self.preprocessed_path,
  39. "pitch",
  40. "{}-pitch-{}.npy".format(speaker, basename),
  41. )
  42. pitch = np.load(pitch_path)
  43. energy_path = os.path.join(
  44. self.preprocessed_path,
  45. "energy",
  46. "{}-energy-{}.npy".format(speaker, basename),
  47. )
  48. energy = np.load(energy_path)
  49. duration_path = os.path.join(
  50. self.preprocessed_path,
  51. "duration",
  52. "{}-duration-{}.npy".format(speaker, basename),
  53. )
  54. duration = np.load(duration_path)
  55. sample = {
  56. "id": basename,
  57. "speaker": speaker_id,
  58. "text": phone,
  59. "raw_text": raw_text,
  60. "mel": mel,
  61. "pitch": pitch,
  62. "energy": energy,
  63. "duration": duration,
  64. }
  65. return sample
  66. def process_meta(self, filename):
  67. with open(
  68. os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
  69. ) as f:
  70. name = []
  71. speaker = []
  72. text = []
  73. raw_text = []
  74. for line in f.readlines():
  75. n, s, t, r = line.strip("\n").split("|")
  76. name.append(n)
  77. speaker.append(s)
  78. text.append(t)
  79. raw_text.append(r)
  80. return name, speaker, text, raw_text
  81. def reprocess(self, data, idxs):
  82. ids = [data[idx]["id"] for idx in idxs]
  83. speakers = [data[idx]["speaker"] for idx in idxs]
  84. texts = [data[idx]["text"] for idx in idxs]
  85. raw_texts = [data[idx]["raw_text"] for idx in idxs]
  86. mels = [data[idx]["mel"] for idx in idxs]
  87. pitches = [data[idx]["pitch"] for idx in idxs]
  88. energies = [data[idx]["energy"] for idx in idxs]
  89. durations = [data[idx]["duration"] for idx in idxs]
  90. text_lens = np.array([text.shape[0] for text in texts])
  91. mel_lens = np.array([mel.shape[0] for mel in mels])
  92. speakers = np.array(speakers)
  93. texts = pad_1D(texts)
  94. mels = pad_2D(mels)
  95. pitches = pad_1D(pitches)
  96. energies = pad_1D(energies)
  97. durations = pad_1D(durations)
  98. return (
  99. ids,
  100. raw_texts,
  101. speakers,
  102. texts,
  103. text_lens,
  104. max(text_lens),
  105. mels,
  106. mel_lens,
  107. max(mel_lens),
  108. pitches,
  109. energies,
  110. durations,
  111. )
  112. def collate_fn(self, data):
  113. data_size = len(data)
  114. if self.sort:
  115. len_arr = np.array([d["text"].shape[0] for d in data])
  116. idx_arr = np.argsort(-len_arr)
  117. else:
  118. idx_arr = np.arange(data_size)
  119. tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
  120. idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
  121. idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
  122. if not self.drop_last and len(tail) > 0:
  123. idx_arr += [tail.tolist()]
  124. output = list()
  125. for idx in idx_arr:
  126. output.append(self.reprocess(data, idx))
  127. return output
  128. class TextDataset(Dataset):
  129. def __init__(self, filepath, preprocess_config):
  130. self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
  131. self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
  132. filepath
  133. )
  134. with open(
  135. os.path.join(
  136. preprocess_config["path"]["preprocessed_path"], "speakers.json"
  137. )
  138. ) as f:
  139. self.speaker_map = json.load(f)
  140. def __len__(self):
  141. return len(self.text)
  142. def __getitem__(self, idx):
  143. basename = self.basename[idx]
  144. speaker = self.speaker[idx]
  145. speaker_id = self.speaker_map[speaker]
  146. raw_text = self.raw_text[idx]
  147. phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
  148. return (basename, speaker_id, phone, raw_text)
  149. def process_meta(self, filename):
  150. with open(filename, "r", encoding="utf-8") as f:
  151. name = []
  152. speaker = []
  153. text = []
  154. raw_text = []
  155. for line in f.readlines():
  156. n, s, t, r = line.strip("\n").split("|")
  157. name.append(n)
  158. speaker.append(s)
  159. text.append(t)
  160. raw_text.append(r)
  161. return name, speaker, text, raw_text
  162. def collate_fn(self, data):
  163. ids = [d[0] for d in data]
  164. speakers = np.array([d[1] for d in data])
  165. texts = [d[2] for d in data]
  166. raw_texts = [d[3] for d in data]
  167. text_lens = np.array([text.shape[0] for text in texts])
  168. texts = pad_1D(texts)
  169. return ids, raw_texts, speakers, texts, text_lens, max(text_lens)
  170. if __name__ == "__main__":
  171. # Test
  172. import torch
  173. import yaml
  174. from torch.utils.data import DataLoader
  175. from utils.utils import to_device
  176. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  177. preprocess_config = yaml.load(
  178. open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader
  179. )
  180. train_config = yaml.load(
  181. open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader
  182. )
  183. train_dataset = Dataset(
  184. "train.txt", preprocess_config, train_config, sort=True, drop_last=True
  185. )
  186. val_dataset = Dataset(
  187. "val.txt", preprocess_config, train_config, sort=False, drop_last=False
  188. )
  189. train_loader = DataLoader(
  190. train_dataset,
  191. batch_size=train_config["optimizer"]["batch_size"] * 4,
  192. shuffle=True,
  193. collate_fn=train_dataset.collate_fn,
  194. )
  195. val_loader = DataLoader(
  196. val_dataset,
  197. batch_size=train_config["optimizer"]["batch_size"],
  198. shuffle=False,
  199. collate_fn=val_dataset.collate_fn,
  200. )
  201. n_batch = 0
  202. for batchs in train_loader:
  203. for batch in batchs:
  204. to_device(batch, device)
  205. n_batch += 1
  206. print(
  207. "Training set with size {} is composed of {} batches.".format(
  208. len(train_dataset), n_batch
  209. )
  210. )
  211. n_batch = 0
  212. for batchs in val_loader:
  213. for batch in batchs:
  214. to_device(batch, device)
  215. n_batch += 1
  216. print(
  217. "Validation set with size {} is composed of {} batches.".format(
  218. len(val_dataset), n_batch
  219. )
  220. )