models.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    # Initialize the weights of every Conv layer from N(mean, std).
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    # "Same" padding for a stride-1 dilated convolution with an odd kernel size.
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock(torch.nn.Module):
    # Residual block: pairs of dilated / plain convolutions with skip connections.
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        # First stack: dilated convolutions (dilations 1, 3, 5 by default).
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)
        # Second stack: plain (dilation=1) convolutions.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x  # residual connection
        return x

    def remove_weight_norm(self):
        # Calls the module-level remove_weight_norm imported from torch.nn.utils.
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    # Upsamples an 80-channel mel spectrogram to a waveform with transposed
    # convolutions, averaging the outputs of several ResBlocks at each stage.
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        # Transposed convolutions; each stage upsamples time by u and halves channels.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2 ** i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels ResBlocks per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
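

# ----------------------------------------------------------------------
# Minimal usage sketch. The hyperparameter values below are assumptions
# chosen only so the tensor shapes line up; the real values come from the
# project's configuration object `h`.
if __name__ == "__main__":
    from types import SimpleNamespace

    h = SimpleNamespace(
        resblock_kernel_sizes=[3, 7, 11],                            # assumed
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],   # assumed
        upsample_rates=[8, 8, 2, 2],                                 # assumed
        upsample_kernel_sizes=[16, 16, 4, 4],                        # assumed
        upsample_initial_channel=512,                                # assumed
    )
    generator = Generator(h)
    mel = torch.randn(1, 80, 100)      # (batch, mel bins, frames)
    wav = generator(mel)               # (1, 1, 100 * 8 * 8 * 2 * 2) samples
    print(wav.shape)
    generator.remove_weight_norm()     # fold weight norm before inference/export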