# Mirror of https://github.com/index-tts/index-tts.git
"""Attention layer with a paged KV cache.

A per-step ForwardContext carries the batch metadata (sequence lengths,
cache slots, block tables) that the attention layers need; a Triton kernel
scatters new K/V vectors into the cache, and flash-attn does the attention.
"""
from dataclasses import dataclass

import torch
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from torch import nn


@dataclass
class ForwardContext:
    """Per-forward-pass metadata shared with every Attention layer."""

    is_prefill: bool = False
    # Cumulative sequence lengths for the packed varlen prefill batch.
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # Flat KV-cache slot index for each token; -1 means "do not cache".
    slot_mapping: torch.Tensor | None = None
    # Current length of each sequence in the batch (decode only).
    context_lens: torch.Tensor | None = None
    # Per-sequence list of cache block indices (paged attention).
    block_tables: torch.Tensor | None = None


# Module-level singleton: set once per model step, read by every layer.
_FORWARD_CONTEXT = ForwardContext()


def get_forward_context():
    return _FORWARD_CONTEXT


def set_forward_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
):
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
    )


def reset_forward_context():
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext()


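# Illustrative sketch (not part of the original file): how a prefill step
# might publish its metadata before running the model. Two sequences of
# lengths 3 and 5 are packed into one varlen batch; every tensor value here
# is a made-up assumption for the example.
def _example_prefill_context() -> None:
    cu = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    set_forward_context(
        is_prefill=True,
        cu_seqlens_q=cu,
        cu_seqlens_k=cu,
        max_seqlen_q=5,
        max_seqlen_k=5,
        slot_mapping=torch.arange(8, dtype=torch.int32, device="cuda"),
    )
    try:
        pass  # the model forward would run here and read the context
    finally:
        reset_forward_context()

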
# Copies one token's flattened K/V vector (length D = num_heads * head_dim)
# into the cache row selected by slot_mapping. One program per token.
@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 2048
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        # Slot -1 marks tokens that should not be written to the cache.
        return
    # D may exceed one block, so walk the feature dimension in chunks.
    d_offset = 0
    while d_offset < D:
        cur_block_size = min(BLOCK_SIZE, D - d_offset)
        key_offsets = idx * key_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        value_offsets = idx * value_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        cache_offsets = slot * D + d_offset + tl.arange(0, BLOCK_SIZE)

        # Mask off the tail when D is not a multiple of BLOCK_SIZE.
        mask = tl.arange(0, BLOCK_SIZE) < cur_block_size
        key = tl.load(key_ptr + key_offsets, mask=mask, other=0.0)
        value = tl.load(value_ptr + value_offsets, mask=mask, other=0.0)
        tl.store(k_cache_ptr + cache_offsets, key, mask=mask)
        tl.store(v_cache_ptr + cache_offsets, value, mask=mask)

        d_offset += BLOCK_SIZE
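

# A minimal pure-PyTorch sketch of what the kernel above computes, for
# reference only (this helper is not part of the original file). It assumes
# contiguous caches whose flat rows of length D each hold one token.
def store_kvcache_reference(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    valid = slot_mapping != -1  # slot -1 means "do not cache this token"
    slots = slot_mapping[valid].long()
    k_cache.view(-1, D)[slots] = key.reshape(N, D)[valid]
    v_cache.view(-1, D)[slots] = value.reshape(N, D)[valid]

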
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel addresses K/V and the caches as flat rows of length D, so
    # the head/feature dims must be contiguous and each cache row must hold
    # exactly one token (stride(1) == D).
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    # Launch one Triton program per token.
    store_kvcache_kernel[(N,)](
        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
    )
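

# Illustrative usage (not in the original file; all shapes are assumptions).
# With 4 KV heads of dim 64, a (num_blocks=8, block_size=16) paged cache has
# rows of D = 256, and token i of `key` lands at flat row slot_mapping[i].
def _example_store_kvcache() -> None:
    key = torch.randn(3, 4, 64, dtype=torch.float16, device="cuda")
    value = torch.randn_like(key)
    k_cache = torch.zeros(8, 16, 4, 64, dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    # Tokens 0 and 1 go to slots 0 and 1 of block 0; slot 17 is block 1, offset 1.
    slot_mapping = torch.tensor([0, 1, 17], dtype=torch.int32, device="cuda")
    store_kvcache(key, value, k_cache, v_cache, slot_mapping)

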
class Attention(nn.Module):
    """flash-attn wrapper that reads batch metadata from the forward context."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        num_kv_heads: int,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Placeholder caches; the runner is expected to swap in the real
        # paged cache tensors before inference.
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_forward_context()
        k_cache, v_cache = self.k_cache, self.v_cache

        # Scatter the new K/V vectors into the paged cache first.
        if k_cache.numel() and v_cache.numel() and context.slot_mapping is not None:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.block_tables is not None:
                # Prefix caching: attend over the cache instead of the fresh K/V.
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
                block_table=context.block_tables,
            )
        else:
            # Decode: one query token per sequence, keys/values read from cache.
            o = flash_attn_with_kvcache(
                q.unsqueeze(1),
                k_cache,
                v_cache,
                cache_seqlens=context.context_lens,
                block_table=context.block_tables,
                softmax_scale=self.scale,
                causal=True,
            )
        return o
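

# Hedged end-to-end sketch (not part of the original file; the shapes, block
# size, and slot numbers are all made-up assumptions): one decode step for a
# batch of two sequences with a paged cache of 16 blocks x 256 slots.
def _example_decode_step() -> torch.Tensor:
    num_heads, head_dim, num_blocks, block_size = 8, 64, 16, 256
    attn = Attention(num_heads, head_dim, head_dim**-0.5, num_heads)
    # Cache layout satisfying store_kvcache's stride asserts:
    # (num_blocks, block_size, num_kv_heads, head_dim), so stride(1) == D.
    attn.k_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_dim,
        dtype=torch.float16, device="cuda",
    )
    attn.v_cache = torch.zeros_like(attn.k_cache)
    set_forward_context(
        is_prefill=False,
        # Sequence 0 has 4 tokens in block 0 (new token at slot 3); sequence 1
        # has 5 tokens in block 1, which starts at flat slot 256.
        slot_mapping=torch.tensor([3, 260], dtype=torch.int32, device="cuda"),
        context_lens=torch.tensor([4, 5], dtype=torch.int32, device="cuda"),
        block_tables=torch.tensor([[0], [1]], dtype=torch.int32, device="cuda"),
    )
    q = torch.randn(2, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    try:
        return attn(q, k, v)  # (batch, 1, num_heads, head_dim)
    finally:
        reset_forward_context()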