class TritonMLASparseImpl(XPUMLASparseImpl):
    """Triton sparse-MLA impl with split-KV decode (3-7× faster than the
    single-pass XPU base for single-query decode on SM80 / SM121).

    Inherits everything from ``XPUMLASparseImpl`` except the bf16-KV forward
    path, which is routed to ``triton_mla_sparse_attention``. When a top-k
    indices buffer is available, the Triton autotune caches are primed at
    construction time so the first real request does not pay for tuning.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Compute-unit count forwarded to the Triton kernel. Stays None when
        # there is no top-k indices buffer to take the device from.
        self._sm_count: int | None = None
        if self.topk_indices_buffer is not None:
            self._sm_count = num_compute_units(
                self.topk_indices_buffer.device.index
            )
        # `indexer` may be absent from kwargs. `_warmup_autotune` returns early
        # without a top-k buffer and otherwise falls back to default indexer
        # dimensions when given None, so a missing key must not raise
        # KeyError here.
        self._warmup_autotune(kwargs.get("indexer"))

    def _warmup_autotune(self, indexer) -> None:
        """Prime `@triton.autotune` caches at init so the first request
        doesn't pay the inline config-sweep cost.

        Does nothing when ``self.topk_indices_buffer`` is None.

        Args:
            indexer: Object whose ``n_head`` / ``head_dim`` attributes size the
                FP8 MQA-logits warmups. May be None (or lack those
                attributes), in which case ``_INDEXER_NUM_HEADS`` /
                ``_INDEXER_HEAD_DIM`` are used instead.
        """
        if self.topk_indices_buffer is None:
            return
        device = self.topk_indices_buffer.device
        topk = self.topk_indices_buffer.shape[-1]

        # Dummy single-token inputs. q/kv are left uninitialized (torch.empty);
        # presumably only shapes/dtypes matter for tuning. All indices are 0,
        # which stays in bounds of the 64-row dummy KV tensor.
        q = torch.empty(
            1, self.num_heads, _DIM_QK, dtype=torch.bfloat16, device=device
        )
        kv = torch.empty(64, 1, _DIM_QK, dtype=torch.bfloat16, device=device)
        indices = torch.zeros(1, 1, topk, dtype=torch.int32, device=device)

        # Launch once per candidate split count so each variant is tuned
        # up front (assumes the split count participates in the autotune
        # key / kernel specialization — TODO confirm).
        for splits in KV_SPLITS_CANDIDATES:
            triton_mla_sparse_attention(
                q,
                kv,
                indices,
                sm_scale=self.softmax_scale,
                num_kv_splits=splits,
                sm_count=self._sm_count,
            )

        # getattr also covers indexer=None, returning the module defaults.
        indexer_num_heads = getattr(indexer, "n_head", _INDEXER_NUM_HEADS)
        indexer_head_dim = getattr(indexer, "head_dim", _INDEXER_HEAD_DIM)
        warmup_fp8_mqa_logits_triton(
            num_heads=indexer_num_heads, head_dim=indexer_head_dim, device=device
        )

        # The paged variant needs the KV-cache block size, which is only
        # available from an active vLLM config; skip it otherwise.
        cfg = get_current_vllm_config_or_none()
        if cfg is not None:
            warmup_fp8_paged_mqa_logits_triton(
                num_heads=indexer_num_heads,
                head_dim=indexer_head_dim,
                block_size=cfg.cache_config.block_size,
                device=device,
            )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        # presumably [blocks, block_size, d_qk] — TODO confirm middle axis;
        # only the last dim is relied on below.
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: XPUMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse MLA attention over a bf16 KV cache with the Triton kernel.

        Args:
            q: Query tensor; its first dim is the token count.
            kv_c_and_k_pe_cache: KV cache. It is flattened to
                ``[-1, 1, d_qk]``, so every row is an individually
                addressable KV slot.
            topk_indices: Per-token selected KV slots. Reshaped to
                ``[num_tokens, 1, topk]``; presumably indexes the flattened
                cache rows (verify against the indexer).
            attn_metadata: Unused here; kept for interface compatibility.

        Returns:
            Attention output trimmed to ``self.num_heads`` along dim 1.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )
        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = triton_mla_sparse_attention(
            q,
            kv_c_and_k_pe_cache,
            topk_indices,
            sm_scale=self.softmax_scale,
            sm_count=self._sm_count,
        )
        # The kernel output may carry extra (e.g. padded) heads; keep only the
        # heads this layer owns.
        return output[:, : self.num_heads, :]