import torch
import torch.nn as nn

# TTSConfig, MultiScaleDiscriminator, MultiPeriodDiscriminator and
# MultiReceptiveFieldFusion are assumed to be defined elsewhere in the article.


class HiFiGANVocoder(nn.Module):
    def __init__(self, config: TTSConfig):
        super().__init__()
        self.config = config
        # Generator: mel spectrogram -> raw waveform
        self.generator = HiFiGANGenerator(config)
        # Multi-scale discriminator
        self.msd = MultiScaleDiscriminator()
        # Multi-period discriminator
        self.mpd = MultiPeriodDiscriminator()

    def forward(self, mel_spectrogram: torch.Tensor) -> torch.Tensor:
        """Generate high-fidelity audio from a mel spectrogram."""
        return self.generator(mel_spectrogram)

    def train_step(self, mel: torch.Tensor, audio: torch.Tensor):
        """Single adversarial training step."""
        # Generate fake audio from the mel spectrogram
        audio_fake = self.generator(mel)
        # Discriminator loss (detach so the generator is not updated here);
        # discriminator_loss / generator_loss are assumed to be implemented elsewhere
        d_loss = self.discriminator_loss(audio, audio_fake.detach())
        # Generator loss
        g_loss = self.generator_loss(mel, audio, audio_fake)
        return g_loss, d_loss


class HiFiGANGenerator(nn.Module):
    def __init__(self, config: TTSConfig):
        super().__init__()
        # Input convolution: project n_mels channels to the hidden dimension
        self.conv_pre = nn.Conv1d(
            config.n_mels,
            config.hidden_dim,
            kernel_size=7,
            padding=3,
        )
        # Upsampling blocks: strides 8, 8, 2, 2 give 256x total upsampling
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip([8, 8, 2, 2], [16, 16, 4, 4])):
            self.ups.append(
                nn.ConvTranspose1d(
                    config.hidden_dim // (2 ** i),
                    config.hidden_dim // (2 ** (i + 1)),
                    kernel_size=k,
                    stride=u,
                    padding=(k - u) // 2,
                )
            )
        # Multi-receptive-field fusion blocks, one after each upsampling stage
        self.mrfs = nn.ModuleList([
            MultiReceptiveFieldFusion(
                config.hidden_dim // (2 ** (i + 1)),
                [3, 7, 11],  # kernel sizes
                [1, 3, 5],   # dilations
            )
            for i in range(4)
        ])
        # Output convolution: collapse channels to a single waveform channel
        self.conv_post = nn.Conv1d(
            config.hidden_dim // 16,
            1,
            kernel_size=7,
            padding=3,
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Generate a waveform from a mel spectrogram."""
        x = self.conv_pre(mel)
        for up, mrf in zip(self.ups, self.mrfs):
            x = torch.relu(up(x))
            x = mrf(x)
        # tanh keeps the output waveform in [-1, 1]
        audio = torch.tanh(self.conv_post(x))
        return audio
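A minimal usage sketch follows. It assumes that TTSConfig, MultiScaleDiscriminator, MultiPeriodDiscriminator and MultiReceptiveFieldFusion are defined elsewhere in the article, that MultiReceptiveFieldFusion preserves sequence length, and that TTSConfig exposes n_mels and hidden_dim attributes; the concrete values (80 mel bins, hidden dimension 512, 100 frames) are illustrative only, not prescribed by the code above.

import torch

# Illustrative configuration; the real values come from the article's TTSConfig
# (assumed constructor signature).
config = TTSConfig(n_mels=80, hidden_dim=512)
vocoder = HiFiGANVocoder(config)

# Batch of 2 mel spectrograms: (batch, n_mels, frames)
mel = torch.randn(2, 80, 100)

with torch.no_grad():
    audio = vocoder(mel)

# Each transposed convolution multiplies the length by its stride,
# so 100 frames -> 100 * (8 * 8 * 2 * 2) = 25600 samples.
print(audio.shape)  # torch.Size([2, 1, 25600])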