# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.in_dim = normalized_shape
        # Plain LayerNorm without learnable affine parameters; the scale and
        # shift come from the style condition instead.
        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
        self.style = nn.Linear(self.in_dim, self.in_dim * 2)
        # Initialize so conditioning starts as an identity transform: gamma = 1, beta = 0.
        self.style.bias.data[: self.in_dim] = 1
        self.style.bias.data[self.in_dim :] = 0

    def forward(self, x, condition):
        # x: (B, T, d); condition: (B, T, d)
        # Average the condition over time, then predict per-channel gamma/beta.
        style = self.style(torch.mean(condition, dim=1, keepdim=True))
        gamma, beta = style.chunk(2, dim=-1)
        out = self.norm(x)
        out = gamma * out + beta
        return out
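

# Minimal usage sketch for StyleAdaptiveLayerNorm (the sizes are illustrative):
# a (B, T, d) hidden sequence is normalized, then rescaled and shifted by the
# gamma/beta pair predicted from the time-averaged style condition.
def _example_style_adaptive_layer_norm():
    saln = StyleAdaptiveLayerNorm(normalized_shape=256)
    x = torch.randn(2, 50, 256)          # hidden states: (B, T, d)
    condition = torch.randn(2, 50, 256)  # style condition: (B, T, d)
    out = saln(x, condition)             # (B, T, d), same shape as x
    return out.shape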


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()

        self.dropout = dropout
        # Standard sinusoidal table of shape (max_len, 1, d_model).
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, T, d); slice the table along the time dimension and broadcast
        # the resulting (1, T, d) encoding over the batch.
        x = x + self.pe[: x.size(1)].transpose(0, 1)
        return F.dropout(x, self.dropout, training=self.training)
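

# Minimal sketch of the positional encoding applied to a batch-first sequence
# (the sizes below are illustrative).
def _example_positional_encoding():
    pos_emb = PositionalEncoding(d_model=256, dropout=0.1)
    x = torch.randn(2, 50, 256)  # (B, T, d)
    out = pos_emb(x)             # (B, T, d) with sinusoidal offsets added
    return out.shape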


class TransformerFFNLayer(nn.Module):
    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()

        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout

        # 1D convolution in place of the first linear projection; the padding
        # keeps the sequence length unchanged for odd kernel sizes.
        self.ffn_1 = nn.Conv1d(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            padding=self.conv_kernel_size // 2,
        )
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        # x: (B, T, d)
        x = self.ffn_1(x.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # (B, T, d) -> (B, d, T) -> (B, T, d)
        x = F.relu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x
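

# Minimal sketch of the convolutional feed-forward block (illustrative sizes,
# matching the defaults used by TransformerEncoder below).
def _example_transformer_ffn_layer():
    ffn = TransformerFFNLayer(
        encoder_hidden=256,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
    )
    x = torch.randn(2, 50, 256)  # (B, T, d)
    out = ffn(x)                 # (B, T, d); expanded to 1024 channels internally
    return out.shape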


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        encoder_hidden,
        encoder_head,
        conv_filter_size,
        conv_kernel_size,
        encoder_dropout,
        use_cln,
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.encoder_head = encoder_head
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.use_cln = use_cln

        # With use_cln, both layer norms are conditioned on the style embedding.
        if not self.use_cln:
            self.ln_1 = nn.LayerNorm(self.encoder_hidden)
            self.ln_2 = nn.LayerNorm(self.encoder_hidden)
        else:
            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)

        self.self_attn = nn.MultiheadAttention(
            self.encoder_hidden, self.encoder_head, batch_first=True
        )

        self.ffn = TransformerFFNLayer(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            self.encoder_dropout,
        )

    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T, d); key_padding_mask: (B, T), where 0 marks padded positions;
        # condition: (B, T, d)

        # self attention (pre-norm with residual connection)
        residual = x
        if self.use_cln:
            x = self.ln_1(x, condition)
        else:
            x = self.ln_1(x)

        if key_padding_mask is not None:
            # nn.MultiheadAttention expects True at positions to be ignored,
            # so invert the 0/1 validity mask.
            key_padding_mask_input = ~(key_padding_mask.bool())
        else:
            key_padding_mask_input = None
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
        )
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = residual + x

        # ffn (pre-norm with residual connection)
        residual = x
        if self.use_cln:
            x = self.ln_2(x, condition)
        else:
            x = self.ln_2(x)
        x = self.ffn(x)
        x = residual + x

        return x
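

# Minimal sketch of a single encoder layer with conditional layer norm and a
# padding mask (1 = valid frame, 0 = padded frame); sizes are illustrative.
def _example_transformer_encoder_layer():
    layer = TransformerEncoderLayer(
        encoder_hidden=256,
        encoder_head=4,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
        use_cln=True,
    )
    x = torch.randn(2, 50, 256)
    condition = torch.randn(2, 50, 256)
    key_padding_mask = torch.ones(2, 50)
    key_padding_mask[:, 40:] = 0  # the last 10 frames are padding
    out = layer(x, key_padding_mask, condition)  # (B, T, d)
    return out.shape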


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        enc_emb_tokens=None,
        encoder_layer=4,
        encoder_hidden=256,
        encoder_head=4,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
        use_cln=False,
        cfg=None,
    ):
        super().__init__()

        # Explicit arguments take precedence; fall back to cfg when an argument
        # is passed as None.
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln

        # Optional token embedding for integer inputs of shape (B, T).
        if enc_emb_tokens is not None:
            self.use_enc_emb = True
            self.enc_emb_tokens = enc_emb_tokens
        else:
            self.use_enc_emb = False

        self.position_emb = PositionalEncoding(
            self.encoder_hidden, self.encoder_dropout
        )

        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    self.encoder_hidden,
                    self.encoder_head,
                    self.conv_filter_size,
                    self.conv_kernel_size,
                    self.encoder_dropout,
                    self.use_cln,
                )
                for _ in range(self.encoder_layer)
            ]
        )

        if self.use_cln:
            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
        else:
            self.last_ln = nn.LayerNorm(self.encoder_hidden)

    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T) token ids or (B, T, d) features; key_padding_mask: (B, T),
        # where 0 marks padded positions; condition: (B, T, d)
        if len(x.shape) == 2 and self.use_enc_emb:
            x = self.enc_emb_tokens(x)
            x = self.position_emb(x)
        else:
            x = self.position_emb(x)  # (B, T, d)

        for layer in self.layers:
            x = layer(x, key_padding_mask, condition)

        if self.use_cln:
            x = self.last_ln(x, condition)
        else:
            x = self.last_ln(x)

        return x
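

# Quick shape-check sketch: the default configuration (4 layers, 256 hidden,
# 4 heads) run on random features with a padding mask; the sizes here are
# illustrative only.
if __name__ == "__main__":
    encoder = TransformerEncoder(use_cln=True)
    x = torch.randn(2, 50, 256)          # (B, T, d) input features
    condition = torch.randn(2, 50, 256)  # style condition, required when use_cln=True
    key_padding_mask = torch.ones(2, 50)
    key_padding_mask[:, 40:] = 0         # mark the last 10 frames as padding
    out = encoder(x, key_padding_mask, condition)
    print(out.shape)  # torch.Size([2, 50, 256])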