mirror of https://github.com/index-tts/index-tts.git
synced 2025-11-28 18:30:25 +08:00
* indextts2 * update lfs for audio files --------- Co-authored-by: wangyining02 <wangyining02@bilibili.com>
210 lines
6.4 KiB
Python
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ResidualVQ
from indextts.utils.maskgct.models.codec.kmeans.vocos import VocosBackbone


def init_weights(m):
    # Truncated-normal init for conv/linear weights, zero init for biases.
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def compute_codebook_perplexity(indices, codebook_size):
    # Perplexity of the empirical code distribution, i.e. exp(entropy).
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp
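
# A minimal usage sketch (hypothetical values): perplexity near `codebook_size`
# indicates uniform codebook usage, while values near 1 indicate codebook collapse.
#
#     indices = torch.randint(0, 8192, (4, 100))
#     print(compute_codebook_perplexity(indices, codebook_size=8192))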


class RepCodec(nn.Module):
    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()
        # Values set on cfg take precedence over the keyword defaults.
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        downsample_scale = (
            cfg.downsample_scale
            if cfg is not None and hasattr(cfg, "downsample_scale")
            else downsample_scale
        )
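
        # For example (hypothetical): passing cfg=types.SimpleNamespace(codebook_size=4096)
        # overrides codebook_size above, while the remaining arguments keep
        # their keyword defaults.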

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        if self.downsample_scale is not None and self.downsample_scale > 1:
            # Strided conv halves the sequence length before encoding; `up`
            # refines features after nearest-neighbor upsampling in forward().
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        self.reset_parameters()

    def forward(self, x):
        # x: (batch, time, hidden_size)

        # downsample
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        # encoder
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # vq
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        # decoder
        x = self.decoder(quantized_out)

        # upsample back to the input frame rate
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x_rec = self.up(x).transpose(1, 2)
        else:
            x_rec = x

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        # Encode and quantize without decoding, e.g. for token extraction.
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        # With a single quantizer, drop the leading quantizer axis.
        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        self.apply(init_weights)
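

# A minimal extraction sketch (hypothetical checkpoint path and feature tensor),
# showing how `quantize` might be used to obtain discrete tokens:
#
#     model = RepCodec()
#     model.load_state_dict(torch.load("repcodec.pt", map_location="cpu"))
#     model.eval()
#     with torch.no_grad():
#         tokens, embeddings = model.quantize(features)  # features: (B, T, hidden_size)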
if __name__ == "__main__":
|
|
repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
|
|
print(repcodec)
|
|
print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
|
|
x = torch.randn(5, 10, 1024)
|
|
x_rec, codebook_loss, all_indices = repcodec(x)
|
|
print(x_rec.shape, codebook_loss, all_indices.shape)
|
|
vq_id, emb = repcodec.quantize(x)
|
|
print(vq_id.shape, emb.shape)
|
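    # Expected shapes (a sketch, assuming ResidualVQ stacks indices as
    # (num_quantizers, batch, time)): x_rec (5, 10, 1024), all_indices (1, 5, 5),
    # vq_id (5, 5), emb (5, 5, 1024).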