import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

# Negative slope shared by the LeakyReLU activations throughout the model.
LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    # Initialize the weights of every convolutional layer from N(mean, std).
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    # "Same" padding for a stride-1 dilated convolution: (kernel_size - 1) * dilation / 2.
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        # First stack: one dilated convolution per dilation factor, each with
        # "same" padding so the time dimension is preserved.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)
        # Second stack: undilated convolutions, one paired with each member of convs1.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        # Each pass applies leaky_relu -> dilated conv -> leaky_relu -> plain conv,
        # then adds the result back onto the input (residual connection).
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        # Inside this method, remove_weight_norm resolves to the module-level
        # import from torch.nn.utils, not to the method itself.
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # Project the 80-bin mel spectrogram up to the initial channel width.
        self.conv_pre = weight_norm(
            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock
        # Transposed convolutions: each stage upsamples time by rate u while
        # halving the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2 ** i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # One ResBlock per (kernel size, dilation) pair at every upsampling
        # stage; their outputs are averaged in forward().
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))
        # Final projection down to a single waveform channel.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of the num_kernels parallel ResBlocks
            # attached to this upsampling stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # tanh bounds the output waveform to [-1, 1].
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        # Fold weight norm into the plain weights before inference.
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
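

# A minimal smoke test, added for illustration. The hyperparameter values
# below are assumptions in the style of a typical HiFi-GAN v1 config, not
# values taken from this file; any object exposing these attribute names
# would work equally well.
if __name__ == "__main__":
    from types import SimpleNamespace

    h = SimpleNamespace(
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_initial_channel=512,
    )

    # ResBlock preserves the (batch, channels, time) shape.
    rb = ResBlock(h, channels=64)
    assert rb(torch.randn(2, 64, 100)).shape == (2, 64, 100)

    # The Generator upsamples by prod(upsample_rates) = 256 samples per frame.
    g = Generator(h)
    mel = torch.randn(1, 80, 32)  # (batch, mel_bins, frames)
    wav = g(mel)
    print(wav.shape)  # torch.Size([1, 1, 8192])
    g.remove_weight_norm()  # fold weight norm in place for inference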