mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-28 18:30:25 +08:00
* indextts2 * update lfs for audio files --------- Co-authored-by: wangyining02 <wangyining02@bilibili.com>
460 lines
14 KiB
Python
460 lines
14 KiB
Python
# Copyright (c) 2023 Amphion.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
import torch
|
|
from torch import nn
|
|
from typing import Optional, Any
|
|
from torch import Tensor
|
|
import torch.nn.functional as F
|
|
import torchaudio
|
|
import torchaudio.functional as audio_F
|
|
|
|
import random
|
|
|
|
# Seed Python's global `random` module at import time. This is a module-level
# side effect: it resets the global RNG state for every other user of `random`
# in the process whenever this module is imported. The PhaseShuffle modules in
# this file draw from their own `random.Random(1)` instances instead.
random.seed(0)
|
|
|
|
|
|
def _get_activation_fn(activ):
    """Return a freshly constructed activation module for the name ``activ``.

    Args:
        activ: ``"relu"``, ``"lrelu"`` (LeakyReLU, negative slope 0.2) or
            ``"swish"`` (``x * sigmoid(x)``, i.e. SiLU).

    Returns:
        A new ``nn.Module`` instance, so the result can be placed directly in
        an ``nn.Sequential`` (as the conv blocks in this file do).

    Raises:
        RuntimeError: if ``activ`` is not one of the recognised names.
    """
    if activ == "relu":
        return nn.ReLU()
    elif activ == "lrelu":
        return nn.LeakyReLU(0.2)
    elif activ == "swish":
        # nn.SiLU computes x * sigmoid(x). A bare lambda (the previous
        # implementation) is not an nn.Module and is rejected by nn.Sequential.
        return nn.SiLU()
    else:
        raise RuntimeError(
            "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
        )
|
|
|
|
|
|
class LinearNorm(torch.nn.Module):
    """``nn.Linear`` whose weight gets Xavier-uniform init scaled for a nonlinearity.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether the underlying linear layer has a bias term.
        w_init_gain: nonlinearity name handed to ``calculate_gain`` to scale
            the Xavier initialisation of the weight.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        xavier_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=xavier_gain)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``in_dim`` to ``out_dim``."""
        projected = self.linear_layer(x)
        return projected
|
|
|
|
|
|
class ConvNorm(torch.nn.Module):
    """``nn.Conv1d`` with Xavier-uniform weight init scaled for a nonlinearity.

    When ``padding`` is ``None`` the kernel size must be odd and the padding
    ``dilation * (kernel_size - 1) / 2`` is used, which keeps the time length
    unchanged at ``stride=1``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, dilation, bias: as for ``nn.Conv1d``.
        padding: explicit padding, or ``None`` for the length-preserving
            padding described above.
        w_init_gain: nonlinearity name given to ``calculate_gain``.
        param: optional extra argument for ``calculate_gain`` (e.g. the
            negative slope when ``w_init_gain="leaky_relu"``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super().__init__()
        if padding is None:
            # Symmetric "same" padding only exists for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_kwargs)

        xavier_gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=xavier_gain)

    def forward(self, signal):
        """Convolve ``signal`` laid out as (batch, in_channels, time)."""
        return self.conv(signal)
|
|
|
|
|
|
class CausualConv(nn.Module):
    """Causal 1-D convolution built from a symmetrically padded ``nn.Conv1d``.

    ``nn.Conv1d`` pads both sides equally, so this module pads ``self.padding``
    frames on each side and then discards the last ``self.padding`` output
    frames. When ``self.padding == dilation * (kernel_size - 1)`` and
    ``stride == 1``, the output has the input's length and output step ``t``
    only depends on input steps ``<= t``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, dilation, bias: as for ``nn.Conv1d``.
        padding: *half* of the per-side padding actually applied
            (``self.padding = 2 * padding``). ``None`` derives it from
            ``kernel_size`` and ``dilation`` and requires an odd kernel size.
        w_init_gain: nonlinearity name given to ``calculate_gain``.
        param: optional extra argument for ``calculate_gain``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # Previously this branch assigned only a local variable, leaving
            # ``self.padding`` unset so the Conv1d below raised AttributeError.
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        """Convolve ``x`` (batch, in_channels, time) and drop the right-edge frames."""
        x = self.conv(x)
        # Guard: with zero padding ``x[:, :, :-0]`` would return an empty tensor.
        if self.padding > 0:
            x = x[:, :, : -self.padding]
        return x
|
|
|
|
|
|
class CausualBlock(nn.Module):
    """Residual stack of dilated causal 1-D convolutions.

    Unit ``i`` (of ``n_conv``) is: CausualConv(k=3, dilation=3**i) ->
    activation -> BatchNorm1d -> Dropout -> CausualConv(k=3, dilation=1) ->
    activation -> Dropout, and the unit's output is added to its input.

    Args:
        hidden_dim: channel count of both input and output.
        n_conv: number of residual units.
        dropout_p: dropout probability inside each unit.
        activ: activation name understood by ``_get_activation_fn``.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Apply each residual unit in turn; ``x`` is (batch, hidden_dim, time)."""
        for block in self.blocks:
            # Out-of-place residual add. Dropout is an identity in eval mode
            # (and at p=0), so the tensor reaching this add can be the
            # activation's own output. An in-place ``+=`` would then modify a
            # tensor autograd saved for backward (ReLU saves its output), and
            # ``.backward()`` would fail.
            x = block(x) + x
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        """Build one residual unit; its causal convs keep the time length."""
        layers = [
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
|
|
|
|
|
|
class ConvBlock(nn.Module):
    """Residual stack of dilated, length-preserving 1-D convolutions.

    Unit ``i`` (of ``n_conv``) is: ConvNorm(k=3, dilation=3**i,
    padding=3**i) -> activation -> GroupNorm(8 groups) -> Dropout ->
    ConvNorm(k=3, padding=1) -> activation -> Dropout, and the unit's output
    is added to its input.

    Args:
        hidden_dim: channel count of both input and output. The GroupNorm uses
            8 groups, so this presumably must be divisible by 8.
        n_conv: number of residual units.
        dropout_p: dropout probability inside each unit.
        activ: activation name understood by ``_get_activation_fn``.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Apply each residual unit in turn; ``x`` is (batch, hidden_dim, time)."""
        for block in self.blocks:
            # Out-of-place residual add. Dropout is an identity in eval mode
            # (and at p=0), so the tensor reaching this add can be the ReLU's
            # own output, which autograd saves for backward. An in-place
            # ``+=`` would invalidate it and make ``.backward()`` fail.
            x = block(x) + x
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        """Build one residual unit; padding == dilation keeps the time length."""
        layers = [
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
|
|
|
|
|
|
class LocationLayer(nn.Module):
    """Turn stacked (previous, cumulative) attention weights into location features.

    A 2-channel ``ConvNorm`` (no bias) is followed by a tanh-initialised,
    bias-free ``LinearNorm`` projection to ``attention_dim``.

    Args:
        attention_n_filters: output channels of the location convolution.
        attention_kernel_size: kernel size of the location convolution
            (padding ``(k - 1) / 2`` keeps the length for odd ``k``).
        attention_dim: size of the produced feature vectors.
    """

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super().__init__()
        half_width = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=half_width,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        """Map (B, 2, max_time) weights to (B, max_time, attention_dim) features."""
        conv_features = self.location_conv(attention_weights_cat)
        # Move channels last so the linear layer acts per time step.
        time_major = conv_features.transpose(1, 2)
        return self.location_dense(time_major)
|
|
|
|
|
|
class Attention(nn.Module):
    """Additive attention with location features over encoder outputs.

    Each encoder step is scored with
    ``v(tanh(query_layer(query) + location_layer(weights) + processed_memory))``.
    Padded steps are masked with ``-inf`` before a softmax, and the resulting
    weights average ``memory``.

    Args:
        attention_rnn_dim: size of the query vector.
        embedding_dim: input size of ``memory_layer``.
        attention_dim: size of the shared scoring space.
        attention_location_n_filters: channels of the location convolution.
        attention_location_kernel_size: kernel size of the location convolution.
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super().__init__()
        # Sub-modules are created in a fixed order; keep it so parameter
        # registration and RNG-driven initialisation stay the same.
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        # Not called in forward(); ``processed_memory`` is supplied by the
        # caller (presumably memory_layer(memory) computed once per utterance).
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """Score every encoder step for the current query.

        Args:
            query: decoder-side query, (batch, attention_rnn_dim).
            processed_memory: projected encoder outputs, (B, T_in, attention_dim).
            attention_weights_cat: previous and cumulative attention weights,
                (B, 2, max_time).

        Returns:
            Unnormalised energies of shape (batch, max_time).
        """
        query_term = self.query_layer(query.unsqueeze(1))  # (B, 1, attention_dim)
        location_term = self.location_layer(attention_weights_cat)
        scores = self.v(torch.tanh(query_term + location_term + processed_memory))
        return scores.squeeze(-1)

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """Compute one attention step.

        Args:
            attention_hidden_state: last output of the attention RNN (the query).
            memory: encoder outputs that get averaged, (B, T_in, D).
            processed_memory: projected encoder outputs used for scoring.
            attention_weights_cat: previous and cumulative attention weights.
            mask: boolean mask, True at padded encoder steps, or ``None``.

        Returns:
            ``(attention_context, attention_weights)`` with shapes (B, D) and
            (B, T_in).
        """
        energies = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # Filled through ``.data`` so autograd does not record the fill.
            # NOTE(review): a row that is fully masked would softmax to NaN.
            energies.data.masked_fill_(mask, self.score_mask_value)

        weights = F.softmax(energies, dim=1)
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)
        return context, weights
|
|
|
|
|
|
class ForwardAttentionV2(nn.Module):
    """Location-feature attention with a forward (monotonic) alignment recursion.

    Energies are computed like ``Attention``. Instead of a plain softmax over
    the energies, the new log-alignment at encoder position ``j`` combines
    ``log_alpha[j]`` and ``log_alpha[j - 1]`` (via logsumexp) with the
    position's energy. The result is then softmax-normalised.

    Args:
        attention_rnn_dim: size of the query vector.
        embedding_dim: input size of ``memory_layer``.
        attention_dim: size of the shared scoring space.
        attention_location_n_filters: channels of the location convolution.
        attention_location_kernel_size: kernel size of the location convolution.
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super().__init__()
        # Creation order kept fixed (parameter registration / RNG init order).
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        # Not called in forward(); ``processed_memory`` is supplied by the caller.
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # A large finite negative number (not -inf), used both for masking and
        # as the left pad of the shifted log_alpha.
        self.score_mask_value = -1e20

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """Score every encoder step for the current query.

        Args:
            query: decoder-side query, (batch, attention_rnn_dim).
            processed_memory: projected encoder outputs, (B, T_in, attention_dim).
            attention_weights_cat: previous and cumulative attention weights,
                (B, 2, max_time).

        Returns:
            Unnormalised energies of shape (batch, max_time).
        """
        query_term = self.query_layer(query.unsqueeze(1))  # (B, 1, attention_dim)
        location_term = self.location_layer(attention_weights_cat)
        scores = self.v(torch.tanh(query_term + location_term + processed_memory))
        return scores.squeeze(-1)

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """Compute one forward-attention step.

        Args:
            attention_hidden_state: last output of the attention RNN (the query).
            memory: encoder outputs that get averaged, (B, T_in, D).
            processed_memory: projected encoder outputs used for scoring.
            attention_weights_cat: previous and cumulative attention weights.
            mask: boolean mask, True at padded encoder steps, or ``None``.
            log_alpha: log-alignment from the previous step (the third value
                this method returns); presumably (B, max_time).

        Returns:
            ``(attention_context, attention_weights, log_alpha_new)``.
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # Filled through ``.data`` so autograd does not record the fill.
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        n_steps = log_energy.size(1)
        # Stack log_alpha shifted right by 0 and by 1 (left-padded with the
        # mask value) so that slot j holds log_alpha[j] and log_alpha[j - 1].
        predecessors = torch.stack(
            [
                F.pad(
                    log_alpha[:, : n_steps - shift],
                    (shift, 0),
                    "constant",
                    self.score_mask_value,
                )
                for shift in range(2)
            ],
            dim=2,
        )
        prior = torch.logsumexp(predecessors, 2)
        log_alpha_new = prior + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory).squeeze(1)
        return attention_context, attention_weights, log_alpha_new
|
|
|
|
|
|
class PhaseShuffle2d(nn.Module):
    """Phase-shuffle a 4-D tensor by swapping the two halves of dim 3 at ``move``.

    For ``|move| <= x.size(3)`` this equals ``torch.roll(x, -move, dims=3)``.
    When ``move`` is not given, it is drawn uniformly from ``[-n, n]`` with a
    module-private ``random.Random(1)``, so the draw sequence is reproducible.

    Args:
        n: maximum absolute shift drawn when ``move`` is omitted.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x: (B, C, M, L); the shuffle is applied along dim 3.
        shift = self.random.randint(-self.n, self.n) if move is None else move
        if shift == 0:
            return x
        head = x[:, :, :, :shift]
        tail = x[:, :, :, shift:]
        return torch.cat([tail, head], dim=3)
|
|
|
|
|
|
class PhaseShuffle1d(nn.Module):
    """Phase-shuffle a tensor by swapping the two halves of dim 2 at ``move``.

    For ``|move| <= x.size(2)`` this equals ``torch.roll(x, -move, dims=2)``.
    When ``move`` is not given, it is drawn uniformly from ``[-n, n]`` with a
    module-private ``random.Random(1)``, so the draw sequence is reproducible.

    Args:
        n: maximum absolute shift drawn when ``move`` is omitted.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # Slicing and concatenation act on dim 2 (presumably x is (B, C, L)).
        shift = move
        if shift is None:
            shift = self.random.randint(-self.n, self.n)
        if shift == 0:
            return x
        return torch.cat([x[:, :, shift:], x[:, :, :shift]], dim=2)
|
|
|
|
|
|
class MFCC(nn.Module):
    """Mel spectrogram -> MFCC via a fixed orthonormal DCT matrix.

    The DCT matrix comes from ``torchaudio.functional.create_dct`` with
    ``norm="ortho"``. It is stored as the buffer ``dct_mat`` and
    right-multiplied along the mel axis.

    Args:
        n_mfcc: number of cepstral coefficients produced.
        n_mels: number of mel bins expected in the input.
    """

    def __init__(self, n_mfcc=40, n_mels=80):
        super().__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        self.register_buffer(
            "dct_mat", audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        )

    def forward(self, mel_specgram):
        """Map (channel, n_mels, time) or (n_mels, time) to the matching MFCC shape."""
        was_unbatched = mel_specgram.dim() == 2
        if was_unbatched:
            mel_specgram = mel_specgram.unsqueeze(0)
        # (channel, time, n_mels) @ (n_mels, n_mfcc) -> (channel, time, n_mfcc),
        # then swap back to (channel, n_mfcc, time).
        mfcc = (mel_specgram.transpose(1, 2) @ self.dct_mat).transpose(1, 2)
        return mfcc.squeeze(0) if was_unbatched else mfcc
|