mirror of https://github.com/index-tts/index-tts.git (synced 2025-11-28 18:30:25 +08:00)
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

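# This module defines the convolutional CodecEncoder / CodecDecoder pair of the
# Amphion codec: Snake-activated, weight-normalized conv blocks for analysis and
# synthesis, a residual vector-quantization bottleneck, and an optional Vocos head.
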
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from indextts.utils.maskgct.models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)

from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


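# Snake activation: snake(x) = x + (1 / alpha) * sin^2(alpha * x), applied per
# channel with a learnable alpha (see Snake1d below); the 1e-9 term guards
# against division by zero when alpha is very small.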
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


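# Dilated residual unit (kernel 7, used with dilations 1/3/9 in the blocks
# below); forward() center-crops the input when the branch output is shorter,
# so the residual addition still lines up.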
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


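# Waveform encoder: a stack of strided EncoderBlocks that doubles the channel
# count at every stage, so the total downsampling factor is prod(up_ratios)
# (4 * 5 * 5 * 6 = 600 with the defaults) and enc_dim = d_model * 2 ** len(up_ratios).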
class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)


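# Mirror of EncoderBlock: a transposed convolution upsamples by `stride`
# (reducing the channel count) before the dilated residual units.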
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


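# Decoder side of the codec: a ResidualVQ bottleneck ("vq", "fvq", or "lfq"
# variants) followed by either the transposed-conv DecoderBlock stack or,
# when use_vocos is set, a Vocos vocoder head.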
class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        # cfg attributes, when present, take precedence over the keyword defaults.
        in_channels = (
            cfg.in_channels
            if cfg is not None and hasattr(cfg, "in_channels")
            else in_channels
        )
        upsample_initial_channel = (
            cfg.upsample_initial_channel
            if cfg is not None and hasattr(cfg, "upsample_initial_channel")
            else upsample_initial_channel
        )
        up_ratios = (
            cfg.up_ratios
            if cfg is not None and hasattr(cfg, "up_ratios")
            else up_ratios
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        quantizer_type = (
            cfg.quantizer_type
            if cfg is not None and hasattr(cfg, "quantizer_type")
            else quantizer_type
        )
        quantizer_dropout = (
            cfg.quantizer_dropout
            if cfg is not None and hasattr(cfg, "quantizer_dropout")
            else quantizer_dropout
        )
        commitment = (
            cfg.commitment
            if cfg is not None and hasattr(cfg, "commitment")
            else commitment
        )
        codebook_loss_weight = (
            cfg.codebook_loss_weight
            if cfg is not None and hasattr(cfg, "codebook_loss_weight")
            else codebook_loss_weight
        )
        use_l2_normlize = (
            cfg.use_l2_normlize
            if cfg is not None and hasattr(cfg, "use_l2_normlize")
            else use_l2_normlize
        )
        codebook_type = (
            cfg.codebook_type
            if cfg is not None and hasattr(cfg, "codebook_type")
            else codebook_type
        )
        kmeans_init = (
            cfg.kmeans_init
            if cfg is not None and hasattr(cfg, "kmeans_init")
            else kmeans_init
        )
        kmeans_iters = (
            cfg.kmeans_iters
            if cfg is not None and hasattr(cfg, "kmeans_iters")
            else kmeans_iters
        )
        decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
        eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
        threshold_ema_dead_code = (
            cfg.threshold_ema_dead_code
            if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
            else threshold_ema_dead_code
        )
        weight_init = (
            cfg.weight_init
            if cfg is not None and hasattr(cfg, "weight_init")
            else weight_init
        )
        use_vocos = (
            cfg.use_vocos
            if cfg is not None and hasattr(cfg, "use_vocos")
            else use_vocos
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = (
            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        )
        padding = (
            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
        )

        if quantizer_type == "vq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        elif quantizer_type == "fvq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        elif quantizer_type == "lfq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
            )
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        if not use_vocos:
            # Add first conv layer
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # Add upsampling + MRF blocks
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers += [DecoderBlock(input_dim, output_dim, stride)]

            # Add final conv layer
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        if use_vocos:
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """
        If vq is True, x is the encoder output and the quantization results are
        returned; otherwise, x is the quantized latent and the decoder output
        is returned.
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    # TODO: check consistency of vq2emb and quantize
    def vq2emb(self, vq, n_quantizers=None):
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        self.apply(init_weights)
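

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation, not part of the original
# module). It assumes the default configurations above and that ResidualVQ's
# forward accepts n_quantizers=None, as the quantize() helper implies.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = CodecEncoder()        # downsampling factor 4 * 5 * 5 * 6 = 600
    decoder = CodecDecoder()        # upsampling factor 5 * 5 * 4 * 2 = 200

    wav = torch.randn(1, 1, 24000)  # (batch, 1, samples)
    latent = encoder(wav)           # (1, 256, 40)

    # Quantize the encoder output with the residual VQ bottleneck.
    quantized, indices, commit_losses, codebook_losses, quantized_list = decoder(
        latent, vq=True, eval_vq=True
    )

    # Decode the quantized latent back to audio.
    recon = decoder(quantized)      # (1, 1, 8000) with the default up_ratios
    print(latent.shape, quantized.shape, recon.shape)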