# Copyright (c) 2023 Amphion.
#
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
# Licensed under Apache License 2.0

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from .modules.quantization import ResidualVectorQuantizer
from .modules.seanet import SEANetDecoder, SEANetEncoder

class SpeechTokenizer(nn.Module):
    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict
            Model configuration (typically parsed from a JSON file).
        """
        super().__init__()
        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Total hop size of the encoder: the product of all stride factors.
        self.downsample_rate = np.prod(config.get("strides"))
        # Project quantized features into the semantic space only when the
        # dimensions differ; otherwise pass them through unchanged.
        if config.get("dimension") != config.get("semantic_dimension"):
            self.transform = nn.Linear(
                config.get("dimension"), config.get("semantic_dimension")
            )
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(
            dimension=config.get("dimension"),
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            # The decoder LSTM stays unidirectional regardless of the config.
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """
        Parameters
        ----------
        config_path : str
            Path to the model configuration file (JSON).
        ckpt_path : str
            Path to the model checkpoint.

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model.
        """
        import json

        with open(config_path) as f:
            cfg = json.load(f)
        model = cls(cfg)
        # Load weights onto CPU; move the model to the target device afterwards.
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model
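
    # Example (illustrative; both paths are placeholders for a real
    # config/checkpoint pair):
    #
    #   model = SpeechTokenizer.load_from_checkpoint("config.json", "ckpt.pt")
    #   model.eval()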

    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = [0]):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. Defaults to all layers.
        layers : list[int], optional
            RVQ layers whose quantized output should be returned. Defaults to
            the first layer.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from the residual vector quantizers.
        feature : torch.Tensor
            Output of the first requested RVQ layer after the semantic
            projection. Shape: (batch, timesteps, semantic_dimension).
        """
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        # Project the first requested layer into the semantic space:
        # (batch, dimension, frames) -> (batch, frames, dimension) first.
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature
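
    # During training, `o` drives the waveform reconstruction objective,
    # `commit_loss` regularizes the quantizers, and `feature` can be matched
    # against a semantic teacher representation (the distillation setup used
    # by the original SpeechTokenizer recipe; those losses live outside this
    # module).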

    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        layers : list[int], optional
            RVQ layers whose quantized output should be returned. Defaults to
            all layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized outputs of the requested layers.
        """
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list
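
    # Example (a sketch): pull continuous quantized features for selected RVQ
    # layers, e.g. for probing or analysis. Each returned tensor is shaped
    # (batch, dimension, frames):
    #
    #   feats = model.forward_feature(wav, layers=[0, 1])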

    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of RVQ quantizers used to encode. Defaults to all layers.
        st : int, optional
            Index of the first RVQ quantizer to use. Defaults to 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps).
        """
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes
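
    # Example (a sketch): because the quantizers are residual, the code stack
    # can be split, e.g. the first layer as "semantic" tokens and later layers
    # as "acoustic" detail. The variable names are illustrative:
    #
    #   codes = model.encode(wav)                  # (n_q, batch, frames)
    #   semantic_tokens = codes[:1]                # RVQ layer 0 only
    #   acoustic_codes = model.encode(wav, st=1)   # layers 1 .. n_q-1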

    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Index of the first RVQ quantizer to decode from. Defaults to 0.

        Returns
        -------
        o : torch.Tensor
            Wavs reconstructed from the codes. Shape: (batch, channels, timesteps).
        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
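
if __name__ == "__main__":
    # Minimal smoke test: build a model from an illustrative config and run a
    # second of noise through it. Every value below is a plausible placeholder,
    # not the released configuration; real runs should load the JSON config
    # that ships with a checkpoint via `load_from_checkpoint`.
    cfg = {
        "n_filters": 64,
        "dimension": 1024,
        "semantic_dimension": 768,
        "strides": [8, 5, 4, 2],
        "lstm_layers": 2,
        "bidirectional": True,
        "dilation_base": 2,
        "residual_kernel_size": 3,
        "n_residual_layers": 1,
        "activation": "ELU",
        "sample_rate": 16000,
        "n_q": 8,
        "codebook_size": 1024,
    }
    model = SpeechTokenizer(cfg).eval()
    wav = torch.randn(1, 1, cfg["sample_rate"])  # (batch, channels, timesteps)
    with torch.no_grad():
        codes = model.encode(wav)    # (n_q, batch, frames)
        recon = model.decode(codes)  # (batch, channels, timesteps)
    print(codes.shape, recon.shape)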