stft.py

import torch
import torch.nn.functional as F
import numpy as np
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from librosa.filters import mel as librosa_mel_fn

from audio.audio_processing import (
    dynamic_range_compression,
    dynamic_range_decompression,
    window_sumsquare,
)


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length

        # build the real/imaginary DFT bases used as conv1d filters
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        # STFT as a strided 1-D convolution against the Fourier basis
        forward_transform = F.conv1d(
            input_data.cuda(),
            torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
            stride=self.hop_length,
            padding=0,
        ).cpu()

        # split the stacked real/imaginary channels into magnitude and phase
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # trim the reflection padding added in transform()
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
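
# Usage sketch (added for illustration; not part of the original file). The
# 1024/256/1024 settings below are assumptions, not values this module requires,
# and a CUDA GPU is assumed because transform() moves its tensors to .cuda():
#
#     stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
#     audio = torch.randn(2, 16000).clamp(-1.0, 1.0)  # (B, T) waveform in [-1, 1]
#     magnitude, phase = stft.transform(audio)        # each of shape (B, 513, n_frames)
#     audio_hat = stft.inverse(magnitude, phase)      # approximate (B, 1, T) reconstruction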


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        energy: torch.FloatTensor of shape (B, T), per-frame L2 norm of the
            linear magnitude spectrogram
        """
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        energy = torch.norm(magnitudes, dim=1)

        return mel_output, energy
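

# Minimal smoke test (added for illustration; not part of the original module).
# The hyperparameters below -- 22050 Hz, 1024-point FFT, hop 256, 80 mel bins,
# fmax 8 kHz -- are common Tacotron 2-style defaults chosen here as assumptions,
# not values this file mandates. STFT.transform() calls .cuda(), so this block
# needs a CUDA-capable GPU (and the audio.audio_processing module on the path).
if __name__ == "__main__":
    taco_stft = TacotronSTFT(
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        sampling_rate=22050,
        mel_fmin=0.0,
        mel_fmax=8000.0,
    )
    wav = torch.randn(2, 22050).clamp(-1.0, 1.0)  # one second of fake audio, shape (B, T)
    mel, energy = taco_stft.mel_spectrogram(wav)
    print(mel.shape, energy.shape)  # expected: (2, 80, n_frames) and (2, n_frames)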