mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-28 02:10:23 +08:00
120 lines
4 KiB
Python
120 lines
4 KiB
Python
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from indextts.utils.xtransformers import RelativePositionBias
|
|
|
|
|
|
def zero_module(module):
    """
    Overwrite every parameter of ``module`` with zeros, in place.

    :param module: the nn.Module whose parameters are cleared.
    :return: the very same ``module`` object, so the call can wrap a constructor.
    """
    # Run the in-place fill outside autograd so it is not recorded in the graph.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that always normalizes in float32 and returns the caller's dtype."""

    def forward(self, x):
        # Upcast for the statistics computation, then restore the input dtype.
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
|
|
|
|
|
|
def normalization(channels):
    """
    Build the standard GroupNorm32 normalization layer for ``channels`` inputs.

    The starting group count is 8 for <= 16 channels, 16 for <= 64 channels and
    32 otherwise; it is then halved until it divides ``channels`` evenly.

    :param channels: number of input channels.
    :return: a GroupNorm32 module.
    :raises AssertionError: if the resulting group count is 2 or fewer.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32

    # Halve the group count until it evenly divides the channel count.
    while channels % groups:
        groups //= 2

    assert groups > 2
    return GroupNorm32(groups, channels)
|
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.

    Heads are split before q/k/v: the channel axis of the input is read as
    ``n_heads`` consecutive chunks, each holding that head's [q, k, v].
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T] per-key multiplier, applied to the attention
            weights after the softmax (presumably 1 = keep, 0 = drop; confirm with callers).
        :param rel_pos: optional callable that receives the [N x H x T x T] attention
            logits and returns logits of the same size (e.g. with a position bias added).
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Row b * H + h of the reshaped tensor is head h of batch item b (batch-major).
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(weight.reshape(bs, self.n_heads, length, length)).reshape(
                bs * self.n_heads, length, length
            )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            # repeat_interleave matches the batch-major (b * H + h) layout of `weight`;
            # mask.repeat(H, 1) would tile as (h * N + b) and give heads the wrong
            # batch item's mask whenever N > 1 and H > 1.
            mask = mask.repeat_interleave(self.n_heads, dim=0).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input/output channels.
    :param num_heads: number of attention heads; used as the head count only
        when ``num_head_channels`` is -1.
    :param num_head_channels: if not -1, the head count becomes
        ``channels // num_head_channels`` (which must divide evenly).
    :param do_checkpoint: stored as ``self.do_checkpoint``; not read by this class.
    :param relative_pos_embeddings: if True, a RelativePositionBias is passed to
        the attention and applied to its logits.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads

        # Submodules are registered in a fixed order (norm, qkv, attention,
        # proj_out, relative_pos_embeddings); keep it for identical init/state_dict order.
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        # Zero-initialised output projection: the block starts out as an identity (x + 0).
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

        if not relative_pos_embeddings:
            self.relative_pos_embeddings = None
        else:
            head_dim = channels // self.num_heads
            # NOTE(review): the bias is sized with the constructor's `num_heads`, not
            # `self.num_heads`; these differ when num_head_channels != -1. Left as-is
            # because changing it would alter parameter shapes of saved checkpoints.
            self.relative_pos_embeddings = RelativePositionBias(
                scale=head_dim ** .5,
                causal=False,
                heads=num_heads,
                num_buckets=32,
                max_distance=64,
            )

    def forward(self, x, mask=None):
        # Flatten all trailing spatial dims into one sequence axis.
        batch, chans, *spatial = x.shape
        flat = x.reshape(batch, chans, -1)
        attended = self.attention(self.qkv(self.norm(flat)), mask, self.relative_pos_embeddings)
        # Residual connection, then restore the original spatial layout.
        residual = flat + self.proj_out(attended)
        return residual.reshape(batch, chans, *spatial)
|