Skip to content

vllm.v1.attention.backends.mla.triton_mla_sparse

Pure-Triton sparse MLA backend for SM80 (A100) / SM121 (GB10).

TritonMLASparseImpl

Bases: XPUMLASparseImpl

Triton sparse-MLA impl with split-KV decode (3-7× faster than the single-pass XPU base for single-query decode on SM80 / SM121).

Source code in vllm/v1/attention/backends/mla/triton_mla_sparse.py
class TritonMLASparseImpl(XPUMLASparseImpl):
    """Triton sparse-MLA impl with split-KV decode (3-7× faster than the
    single-pass XPU base for single-query decode on SM80 / SM121).

    On construction the instance caches the compute-unit count of the device
    holding ``topk_indices_buffer`` (when that buffer exists) and runs a
    one-off warm-up of the Triton kernels it depends on.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base impl, cache the SM count and warm up kernels.

        All positional and keyword arguments are forwarded unchanged to the
        base class. The optional ``indexer`` keyword is additionally used to
        size the indexer warm-up kernels.
        """
        super().__init__(*args, **kwargs)
        # Stays None when there is no top-k buffer to derive a device from.
        self._sm_count: int | None = None
        if self.topk_indices_buffer is not None:
            self._sm_count = num_compute_units(self.topk_indices_buffer.device.index)
        # Use .get(): a missing `indexer` keyword must not abort construction.
        # The warm-up reads the indexer through getattr-with-default, so None
        # simply selects the module-level default head count / head dim.
        self._warmup_autotune(kwargs.get("indexer"))

    def _warmup_autotune(self, indexer) -> None:
        """Prime `@triton.autotune` caches at init so the first request
        doesn't pay the inline config-sweep cost.

        Does nothing when ``topk_indices_buffer`` is None. Otherwise it calls
        the sparse attention kernel once per entry of
        ``KV_SPLITS_CANDIDATES`` on dummy tensors, then the FP8 MQA-logits
        warm-up, and - only if a vLLM config is currently available - the
        paged FP8 MQA-logits warm-up.

        Args:
            indexer: Object optionally exposing ``n_head`` and ``head_dim``;
                missing attributes (or ``None``) fall back to
                ``_INDEXER_NUM_HEADS`` / ``_INDEXER_HEAD_DIM``.
        """
        if self.topk_indices_buffer is None:
            return
        device = self.topk_indices_buffer.device
        topk = self.topk_indices_buffer.shape[-1]
        # Dummy inputs: contents are irrelevant for q/kv (hence torch.empty);
        # indices are zeroed so every gathered row is row 0 of `kv`.
        q = torch.empty(1, self.num_heads, _DIM_QK, dtype=torch.bfloat16, device=device)
        kv = torch.empty(64, 1, _DIM_QK, dtype=torch.bfloat16, device=device)
        indices = torch.zeros(1, 1, topk, dtype=torch.int32, device=device)
        for splits in KV_SPLITS_CANDIDATES:
            triton_mla_sparse_attention(
                q,
                kv,
                indices,
                sm_scale=self.softmax_scale,
                num_kv_splits=splits,
                sm_count=self._sm_count,
            )
        indexer_num_heads = getattr(indexer, "n_head", _INDEXER_NUM_HEADS)
        indexer_head_dim = getattr(indexer, "head_dim", _INDEXER_HEAD_DIM)
        warmup_fp8_mqa_logits_triton(
            num_heads=indexer_num_heads, head_dim=indexer_head_dim, device=device
        )
        # The paged variant needs the KV-cache block size, which is only
        # known when a vLLM config is active.
        cfg = get_current_vllm_config_or_none()
        if cfg is not None:
            warmup_fp8_paged_mqa_logits_triton(
                num_heads=indexer_num_heads,
                head_dim=indexer_head_dim,
                block_size=cfg.cache_config.block_size,
                device=device,
            )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: XPUMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse attention over a BF16 KV cache with the Triton kernel.

        The cache is viewed as ``[-1, 1, last_dim]`` and the indices as
        ``[num_tokens, 1, -1]`` before the kernel call. ``attn_metadata`` is
        accepted for interface compatibility and is not read here.

        Returns:
            The kernel output restricted to the first ``self.num_heads``
            entries along dim 1.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )
        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = triton_mla_sparse_attention(
            q,
            kv_c_and_k_pe_cache,
            topk_indices,
            sm_scale=self.softmax_scale,
            sm_count=self._sm_count,
        )
        # NOTE(review): presumably the kernel may return more heads than
        # requested (padding) - confirm against triton_mla_sparse_attention.
        return output[:, : self.num_heads, :]

_warmup_autotune

_warmup_autotune(indexer) -> None

Prime @triton.autotune caches at init so the first request doesn't pay the inline config-sweep cost.

Source code in vllm/v1/attention/backends/mla/triton_mla_sparse.py
def _warmup_autotune(self, indexer) -> None:
    """Run each autotuned Triton kernel once at init time.

    This primes the `@triton.autotune` caches so the first real request
    doesn't pay the inline config-sweep cost. Without a
    ``topk_indices_buffer`` there is nothing to warm up and the call
    returns immediately.

    Args:
        indexer: Object optionally exposing ``n_head`` and ``head_dim``;
            absent attributes fall back to ``_INDEXER_NUM_HEADS`` /
            ``_INDEXER_HEAD_DIM``.
    """
    buf = self.topk_indices_buffer
    if buf is None:
        return

    dev = buf.device
    num_topk = buf.shape[-1]

    # Dummy operands for the attention kernel; index 0 is always in range
    # for the 64-row dummy KV tensor.
    bf16_on_dev = {"dtype": torch.bfloat16, "device": dev}
    dummy_q = torch.empty(1, self.num_heads, _DIM_QK, **bf16_on_dev)
    dummy_kv = torch.empty(64, 1, _DIM_QK, **bf16_on_dev)
    dummy_indices = torch.zeros(1, 1, num_topk, dtype=torch.int32, device=dev)

    # One launch per candidate split count.
    for num_splits in KV_SPLITS_CANDIDATES:
        triton_mla_sparse_attention(
            dummy_q,
            dummy_kv,
            dummy_indices,
            sm_scale=self.softmax_scale,
            num_kv_splits=num_splits,
            sm_count=self._sm_count,
        )

    # Indexer geometry is shared by both logits warm-ups below.
    indexer_dims = {
        "num_heads": getattr(indexer, "n_head", _INDEXER_NUM_HEADS),
        "head_dim": getattr(indexer, "head_dim", _INDEXER_HEAD_DIM),
    }
    warmup_fp8_mqa_logits_triton(device=dev, **indexer_dims)

    # The paged warm-up needs the cache block size from the active config.
    vllm_config = get_current_vllm_config_or_none()
    if vllm_config is None:
        return
    warmup_fp8_paged_mqa_logits_triton(
        block_size=vllm_config.cache_config.block_size,
        device=dev,
        **indexer_dims,
    )