mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-28 10:20:24 +08:00
520 lines
20 KiB
Python
520 lines
20 KiB
Python
|
||
from typing import Optional, Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
|
||
RelPositionMultiHeadedAttention)
|
||
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
|
||
PositionalEncoding,
|
||
RelPositionalEncoding)
|
||
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
|
||
Conv2dSubsampling4,
|
||
Conv2dSubsampling6,
|
||
Conv2dSubsampling8,
|
||
LinearNoSubsampling)
|
||
from indextts.utils.common import make_pad_mask
|
||
|
||
|
||
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward block.

    The same two-layer network is applied independently at every position
    of the sequence, so the output feature size equals the input feature
    size.

    Args:
        idim (int): Input (and output) feature dimension.
        hidden_units (int): Width of the hidden layer.
        dropout_rate (float): Dropout probability applied after the
            activation.
        activation (torch.nn.Module): Non-linearity placed between the two
            linear layers.
    """

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU()):
        """Create the two projections, the activation and the dropout."""
        super().__init__()
        # Registration order is kept (w_1, activation, dropout, w_2) so that
        # parameter initialisation and state_dict layout stay the same.
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activation -> dropout -> project.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
|
||
|
||
|
||
class ConvolutionModule(nn.Module):
    """Convolution module of a Conformer block.

    Pipeline: pointwise conv (channel expansion x2) -> GLU -> depthwise conv
    -> LayerNorm -> activation -> pointwise conv.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 bias: bool = True):
        """Construct a ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of the depthwise conv; must be
                odd so that symmetric padding keeps the time length.
            activation (nn.Module): Activation applied after the LayerNorm.
            bias (bool): Whether the conv layers use a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder distinguishes causal from symmetric convolution:
        #   lorder > 0: causal, the input is left-padded with lorder frames
        #               in forward;
        #   lorder == 0: symmetric convolution.
        # This implementation always builds the symmetric variant, which
        # requires an odd kernel size.
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        self.lorder = 0

        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        self.use_layer_norm = True
        self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        The input tensor is never modified; padded frames are zeroed on a
        copy.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left-context cache; an empty (0, 0, 0) tensor
                when self.lorder == 0.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        # NOTE: transpose() returns a view of the caller's tensor, so the
        # masking must be out-of-place; an in-place masked_fill_ here would
        # silently zero frames in the caller's input.
        if mask_pad.size(2) > 0:  # time > 0
            x = x.masked_fill(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        # LayerNorm normalises the last dimension, so channels are moved
        # there for the norm and moved back afterwards.
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x = x.masked_fill(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
|
||
|
||
|
||
class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Sub-blocks are applied in this order, each with a residual connection:
    optional macaron feed-forward, self-attention, optional convolution,
    optional feed-forward. A final LayerNorm is applied only when a
    convolution module is present.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            When None, the feed-forward sub-block is skipped.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
             instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            # With two feed-forward blocks each contributes half a step.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size,
                                          eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time), returned
                unchanged.
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2); an
                empty (0, 0, 0) tensor when there is no conv module.
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        x_att, new_att_cache = self.self_attn(
            x, x, x, mask, pos_emb, att_cache)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        # `feed_forward` is declared Optional in __init__, so it is skipped
        # when absent (like the macaron and conv sub-blocks) instead of
        # failing with "'NoneType' object is not callable".
        if self.feed_forward is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff(x)

            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
            if not self.normalize_before:
                x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
|
||
|
||
|
||
class BaseEncoder(torch.nn.Module):
    """Shared front end and forward pass for the encoder.

    This base class builds the input embedding / subsampling layer and the
    final LayerNorm. It does not create ``self.encoders``; ``forward``
    iterates over that attribute, so a subclass is expected to define it.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head
                attention (not used by this base class itself)
            linear_units (int): the hidden units number of position-wise feed
                forward (not used by this base class itself)
            num_blocks (int): the number of encoder blocks (not used by this
                base class itself)
            dropout_rate (float): dropout rate, passed to both the input
                layer and the positional encoding
            input_layer (str): input layer type.
                one of [linear, conv2d2, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                one of [abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            concat_after (bool): whether to concat attention layer's input
                and output (not used by this base class itself).
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)

        Raises:
            ValueError: if `pos_enc_layer_type` or `input_layer` is not one
                of the supported names.
        """
        super().__init__()
        self._output_size = output_size

        # Positional encoding is validated first, then the input layer.
        pos_enc_classes = {
            "abs_pos": PositionalEncoding,
            "rel_pos": RelPositionalEncoding,
            "no_pos": NoPositionalEncoding,
        }
        if pos_enc_layer_type not in pos_enc_classes:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        pos_enc_class = pos_enc_classes[pos_enc_layer_type]

        subsampling_classes = {
            "linear": LinearNoSubsampling,
            "conv2d2": Conv2dSubsampling2,
            "conv2d": Conv2dSubsampling4,
            "conv2d6": Conv2dSubsampling6,
            "conv2d8": Conv2dSubsampling8,
        }
        if input_layer not in subsampling_classes:
            raise ValueError("unknown input_layer: " + input_layer)
        subsampling_class = subsampling_classes[input_layer]

        positional_encoding = pos_enc_class(output_size, dropout_rate)
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            positional_encoding,
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)

    def output_size(self) -> int:
        """Return the encoder output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed the input and run it through every encoder layer.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        max_len = xs.size(1)
        padding = make_pad_mask(xs_lens, max_len)
        masks = ~padding.unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)

        # The same subsampled mask is used both as the attention mask and as
        # the padding mask for the conv module: (B, 1, T/subsample_rate).
        attn_masks = masks
        for block in self.encoders:
            xs, attn_masks = block(xs, attn_masks, pos_emb, masks)[:2]

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
|
||
|
||
|
||
class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        macaron_style: bool = False,
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to concat_after, see in BaseEncoder
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
        """
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=num_blocks,
            dropout_rate=dropout_rate,
            input_layer=input_layer,
            pos_enc_layer_type=pos_enc_layer_type,
            normalize_before=normalize_before,
            concat_after=concat_after,
        )

        # One SiLU instance is shared by every feed-forward and conv module.
        activation = torch.nn.SiLU()

        # Relative-position attention only when relative encodings are used.
        if pos_enc_layer_type == "rel_pos":
            attention_class = RelPositionMultiHeadedAttention
        else:
            attention_class = MultiHeadedAttention

        blocks = []
        for _ in range(num_blocks):
            # Sub-modules are built in the order attention, feed-forward,
            # macaron feed-forward, convolution so that parameter
            # initialisation consumes the RNG in a fixed sequence.
            self_attn = attention_class(attention_heads, output_size,
                                        dropout_rate)
            feed_forward = PositionwiseFeedForward(output_size, linear_units,
                                                   dropout_rate, activation)
            if macaron_style:
                feed_forward_macaron = PositionwiseFeedForward(
                    output_size, linear_units, dropout_rate, activation)
            else:
                feed_forward_macaron = None
            if use_cnn_module:
                conv_module = ConvolutionModule(output_size,
                                                cnn_module_kernel, activation)
            else:
                conv_module = None
            blocks.append(
                ConformerEncoderLayer(
                    output_size,
                    self_attn,
                    feed_forward,
                    feed_forward_macaron,
                    conv_module,
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ))

        self.encoders = torch.nn.ModuleList(blocks)
|