# index-tts/indextts/s2mel/modules/flow_matching.py
# storyicon 31e7e855e2 feat: optimize s2mel stage
# Signed-off-by: storyicon <storyicon@foxmail.com>
# 2025-10-24 07:30:20 +00:00
#
# 186 lines
# 7.2 KiB
# Python

from abc import ABC
import torch
import torch.nn.functional as F
from indextts.s2mel.modules.diffusion_transformer import DiT
from indextts.s2mel.modules.commons import sequence_mask
from tqdm import tqdm
class BASECFM(torch.nn.Module, ABC):
    """Base conditional flow-matching module: training loss plus Euler sampling.

    ``self.estimator`` is ``None`` here; a subclass must assign the velocity
    estimator before ``forward`` or ``inference`` can be used.
    """

    def __init__(
        self,
        args,
    ):
        """Read flow-matching settings from ``args``.

        Args:
            args: config object providing ``DiT.in_channels`` and
                ``reg_loss_type``. ``DiT.zero_prompt_speech_token`` is optional
                and defaults to ``False``.
        """
        super().__init__()
        self.sigma_min = 1e-6
        self.estimator = None
        self.in_channels = args.DiT.in_channels
        # "l2" selects MSE; any other value falls back to L1.
        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
        self.zero_prompt_speech_token = getattr(args.DiT, 'zero_prompt_speech_token', False)

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Generate a mel-spectrogram by integrating the flow from Gaussian noise.

        Args:
            mu (torch.Tensor): semantic info of reference audio and altered audio,
                shape (batch_size, mel_timesteps, 512) per the original notes;
                dims 0 and 1 are read as batch and time.
            x_lens (torch.Tensor): mel frame lengths, forwarded to the estimator.
            prompt (torch.Tensor): reference mel, shape (batch_size, 80, prompt_frames)
                per the original notes.
            style (torch.Tensor): reference global style, shape (batch_size, 192)
                per the original notes.
            f0: accepted and passed through, but not used by the solver.
            n_timesteps (int): number of Euler steps.
            temperature (float, optional): scale applied to the initial noise.
                Defaults to 1.0.
            inference_cfg_rate (float, optional): classifier-free guidance
                strength; values <= 0 disable guidance. Defaults to 0.5.

        Returns:
            torch.Tensor: generated mel-spectrogram,
            shape (batch_size, in_channels, mel_timesteps).
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """Integrate the estimated velocity field with a fixed-step Euler solver.

        ``x`` is modified in place (its prompt frames are zeroed), and so is
        ``mu`` when ``zero_prompt_speech_token`` is set.

        Args:
            x (torch.Tensor): initial noise, shape (batch_size, channels, mel_timesteps).
            x_lens (torch.Tensor): mel frame lengths, forwarded to the estimator.
            prompt (torch.Tensor): reference mel; its last dimension is the
                number of prompt frames.
            mu (torch.Tensor): semantic conditioning.
            style (torch.Tensor): reference global style.
            f0: unused; kept for interface compatibility.
            t_span (torch.Tensor): integration time points, shape (n_timesteps + 1,).
            inference_cfg_rate (float, optional): classifier-free guidance
                strength; values <= 0 disable guidance. Defaults to 0.5.

        Returns:
            torch.Tensor: the state after the last step, same shape as ``x``.
            If ``t_span`` holds fewer than two points there is nothing to
            integrate and the prompt-masked initial state is returned.
        """
        # apply prompt: the prompt frames are supplied through prompt_x and
        # blanked out in the evolving state
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            # NOTE(review): this slices the LAST axis of mu, while inference()
            # treats mu.size(1) as the time axis. If the intent is to zero the
            # prompt frames this may need to be mu[:, :prompt_len]. Left as-is
            # to match forward() and existing checkpoints -- confirm.
            mu[..., :prompt_len] = 0

        if len(t_span) < 2:
            # Zero steps requested: return the initial state instead of
            # failing on an out-of-range index.
            return x

        use_cfg = inference_cfg_rate > 0
        if use_cfg:
            # The conditional / null conditioning pairs do not change between
            # steps, so build them once instead of on every iteration.
            stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
            stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
            stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)

        t = t_span[0]
        for step in tqdm(range(1, len(t_span))):
            dt = t_span[step] - t_span[step - 1]
            if use_cfg:
                # Single forward pass for both the conditional and null inputs
                stacked_x = torch.cat([x, x], dim=0)
                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
                stacked_dphi_dt = self.estimator(
                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
                )
                # Split the output back into the conditional and null components
                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
                # Apply CFG formula
                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
            else:
                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
            x = x + dt * dphi_dt
            t = t + dt
            # keep the prompt region blank after every step
            x[:, :, :prompt_len] = 0
        return x

    def forward(self, x1, x_lens, prompt_lens, mu, style):
        """Compute the conditional flow-matching loss.

        ``mu`` is modified in place when ``zero_prompt_speech_token`` is set.

        Args:
            x1 (torch.Tensor): target mel, shape (batch_size, n_feats, mel_timesteps).
            x_lens: per-sample mel lengths; ``x_lens[b]`` ends the loss region.
            prompt_lens: per-sample prompt lengths; ``prompt_lens[b]`` starts
                the loss region.
            mu (torch.Tensor): semantic conditioning.
            style (torch.Tensor): reference global style.

        Returns:
            loss: criterion averaged over the batch, computed per sample on
                frames ``prompt_lens[b]:x_lens[b]`` only.
            y: estimator output plus ``(1 - sigma_min) * noise``,
                shape (batch_size, n_feats, mel_timesteps).
        """
        batch_size, _, _ = x1.shape

        # random timestep per sample, broadcast over channels and frames
        t = torch.rand([batch_size, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(batch_size):
            p_len = prompt_lens[bib]
            prompt[bib, :, :p_len] = x1[bib, :, :p_len]
            # range covered by prompt are set to 0
            y[bib, :, :p_len] = 0
            if self.zero_prompt_speech_token:
                # NOTE(review): mu[bib] is sliced on its last axis here; see
                # the matching note in solve_euler -- confirm the intended axis.
                mu[bib, :, :p_len] = 0

        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)

        # Score only the frames after the prompt and before each sample's end.
        loss = 0
        for bib in range(batch_size):
            start, end = prompt_lens[bib], x_lens[bib]
            loss += self.criterion(estimator_out[bib, :, start:end], u[bib, :, start:end])
        loss /= batch_size

        return loss, estimator_out + (1 - self.sigma_min) * z
class CFM(BASECFM):
    """Conditional flow-matching model that uses a ``DiT`` as its estimator."""

    def __init__(self, args):
        """Build the estimator selected by ``args.dit_type``.

        Args:
            args: config object; in addition to the fields read by the base
                class it must provide ``dit_type``.

        Raises:
            NotImplementedError: if ``args.dit_type`` is not ``"DiT"``.
        """
        super().__init__(
            args
        )
        if args.dit_type != "DiT":
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
        self.estimator = DiT(args)

    def enable_torch_compile(self):
        """Wrap the estimator with ``torch.compile``.

        The estimator is replaced by its compiled version
        (``fullgraph=True``, ``dynamic=True``). If a distributed process group
        is initialized, the inductor option
        ``reorder_for_compute_comm_overlap`` is switched on first.
        """
        # is_available() must be checked first: on builds without distributed
        # support the torch.distributed package exposes no other API, so
        # calling is_initialized() directly would raise AttributeError.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch._inductor.config.reorder_for_compute_comm_overlap = True
        self.estimator = torch.compile(
            self.estimator,
            fullgraph=True,
            dynamic=True,
        )