This repository contains the optimized CUDA kernel implementation for InfLLM V2's Two-Stage Sparse Attention Mechanism. Our implementation provides high-performance kernels for both Stage 1 (Top-K Context Selection) and Stage 2 (Sparse Attention Computation), enabling Large Language Models (LLMs) to efficiently process long contexts with trainable sparse patterns.
InfLLM V2 introduces a novel two-stage approach for efficient long-context processing:
- Stage 1 (Top-K Context Selection): Block scoring and aggregation using semantic kernels (the kernel computes and aggregates relevance scores; the Top-K selection itself is performed externally)
- Stage 2 (Sparse Attention Computation): Attention computation on the blocks selected in Stage 1
This CUDA kernel implementation includes both stages, providing:
- Optimized relevance score computation and aggregation for Stage 1 (Top-K selection performed externally)
- Efficient sparse attention on selected blocks for Stage 2
- Significant reduction in computational costs for both forward and backward phases
Built upon FlashAttention, our kernels leverage efficient memory access patterns and optimized implementations for both stages.
- Open-source Base Model and Training Data: Release base model and training datasets to enable the open-source community to reproduce the training process
- Optimize Stage 1 Inference Operator: Enhance the Stage 1 inference kernel by adding compressed LSE (LogSumExp) computation for improved efficiency
- Optimize Training Operator Speed: Further optimize the training kernel performance for faster model training
The Top-K selection stage involves three sequential steps:
- Relevance Score Computation: Computing scores between query tokens and each semantic kernel (compressed representations of key-value blocks), followed by softmax normalization
- Score Aggregation: Aggregating relevance scores for each semantic kernel across the query group dimension using dimension reduction (hdim16_reduce)
- Block Selection (Post-processing): Selecting the top-K context blocks for each query token based on the aggregated scores
Note: The `infllmv2_attn_stage1` kernel handles steps 1 and 2 (score computation and aggregation); only step 3 (Top-K selection) is performed outside the kernel.
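For intuition, here is a minimal dense PyTorch sketch of steps 1 and 2, i.e. roughly what `infllmv2_attn_stage1` computes in fused form. The tensor layout (queries grouped per KV head), the softmax scaling, and the mean used for the group reduction are illustrative assumptions, not the kernel's exact semantics:

```python
import torch

def stage1_reference(q, k_compressed):
    """Dense reference for Stage 1: relevance scores + query-group aggregation.

    Assumed (illustrative) shapes:
      q:            (n_kv_heads, group_size, seqlen_q, head_dim)  query heads grouped
                    under their KV head (GQA layout)
      k_compressed: (n_kv_heads, n_kernels, head_dim)             semantic kernels
    Returns:        (n_kv_heads, seqlen_q, n_kernels) aggregated scores
    """
    head_dim = q.shape[-1]
    # Step 1: relevance scores between every query token and every semantic kernel,
    # followed by softmax normalization (causal masking omitted for brevity)
    scores = torch.einsum("hgqd,hkd->hgqk", q, k_compressed) / head_dim ** 0.5
    probs = torch.softmax(scores, dim=-1)
    # Step 2: aggregate across the query-group dimension (what the fused kernel's
    # hdim16_reduce performs); a mean is used here as an illustrative reduction
    return probs.mean(dim=1)
```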
The sparse attention stage performs standard attention computation, but only on the blocks selected in Stage 1:
- Support for both forward and backward passes
- Efficient memory access through block-sparse patterns
- Token-level Query, Block-level Key-Value: Avoids training-inference inconsistency during decoding
- Trainable Context Selection: Semantic kernels updated indirectly through token-level key vector optimization
- Selective Block Attention: Performs attention only on blocks selected in Stage 1 (a dense reference sketch follows this list)
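As a mental model for this stage, the following dense PyTorch sketch gathers the key/value blocks selected for each query token and runs standard attention over just those tokens. The shapes are assumptions, causal masking is omitted, and the real kernel fuses this pattern without materializing the gathered blocks:

```python
import torch

def sparse_attn_reference(q, k, v, topk_idx, block_size):
    """Block-sparse attention reference (single head, illustrative only).

    Assumed shapes:
      q:        (seqlen_q, head_dim)
      k, v:     (seqlen_k, head_dim), with seqlen_k divisible by block_size
      topk_idx: (seqlen_q, n_selected) indices of the blocks chosen in Stage 1
    """
    head_dim = q.shape[-1]
    k_blocks = k.view(-1, block_size, head_dim)  # (n_blocks, block_size, head_dim)
    v_blocks = v.view(-1, block_size, head_dim)
    out = torch.empty_like(q)
    for i in range(q.shape[0]):
        # Attend only to the key/value blocks selected for query token i
        sel_k = k_blocks[topk_idx[i]].reshape(-1, head_dim)
        sel_v = v_blocks[topk_idx[i]].reshape(-1, head_dim)
        attn = torch.softmax(q[i] @ sel_k.T / head_dim ** 0.5, dim=-1)
        out[i] = attn @ sel_v
    return out
```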
`infllmv2_attn_stage1`: Calculates similarity scores between query tokens and compressed key representations
- Performs score aggregation across the query group dimension (hdim16_reduce)
- Returns aggregated attention scores for subsequent Top-K selection (selection is performed outside the kernel)
- Supports causal masking and variable sequence lengths
`infllmv2_sparse_attn_fwd`: Forward pass kernel for sparse attention
`infllmv2_sparse_attn_bwd`: Backward pass kernel for training
- PyTorch 1.12+
- CUDA 11.6+ (with CUDA development toolkit)
- Python 3.7+
- Linux operating system
- Sufficient GPU memory for kernel compilation
- Ninja build system (for faster compilation)
# Install with CUDA kernel compilation
pip install -e .
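After installation, a quick sanity check (an illustrative snippet, assuming the package is importable as `infllm_v2`, matching the usage examples below) is to confirm that CUDA is visible and the compiled extension imports cleanly:

```python
import torch
import infllm_v2  # assumed module name, matching the usage examples below

print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.is_available())  # should be True on a supported GPU (SM80/SM90)
print(infllm_v2.__file__)         # confirms the compiled extension package imports
```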
The InfLLM V2 CUDA kernel provides the following interfaces for the two-stage sparse attention:
from infllm_v2 import infllmv2_attn_stage1
# Stage 1: Compute and aggregate relevance scores between queries and semantic kernels
# This kernel performs:
# 1. LSE approximation using compressed keys
# 2. Full attention score computation
# 3. Score aggregation across query group dimension (hdim16_reduce)
# Top-K selection must be performed separately on the aggregated scores
#
# Inputs:
# - q: Query tensor (batch_size * n_heads, seqlen_q, head_dim)
# - k: Compressed key tensor representing semantic kernels
# - v: Placeholder tensor (not used in score computation)
# - cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths
# - max_seqlen_q, max_seqlen_k: Maximum sequence lengths
# Returns aggregated attention scores for subsequent Top-K selection
aggregated_scores = infllmv2_attn_stage1(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    causal=True,             # Apply causal masking
    return_attn_probs=True,  # Return attention scores
)
# Top-K selection should be performed on the returned aggregated scores
# (This step is not part of the kernel)
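# Illustrative sketch of step 3 using torch.topk; the layout of aggregated_scores
# (one score per query token and key/value block) is an assumption here and should
# be checked against the kernel's actual output.
num_selected_blocks = 64
topk_idx = aggregated_scores.topk(num_selected_blocks, dim=-1).indices.int()
# topk_idx can then be passed to the Stage 2 kernel shown next.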
from infllm_v2 import infllmv2_attn_varlen_func
# Stage 2: Sparse Attention Computation Kernel
# Inputs:
# - q_unpad: Queries tensor (token-level)
# - k_unpad, v_unpad: Keys and Values tensors (block-level)
# - cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths
# - topk_idx: Selected block indices from Stage 1
# - max_seqlen_q, max_seqlen_k: Maximum sequence lengths
out_unpad = infllmv2_attn_varlen_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    topk_idx,  # Block indices selected in Stage 1
    max_seqlen_q, max_seqlen_k,
)
- q: Query tensor with shape (batch_size * n_heads, seqlen_q, head_dim)
- k: Compressed key tensor representing semantic kernels
- causal: Whether to apply causal masking
- return_attn_probs: Whether to return attention scores (required for Top-K selection)
- Output: Aggregated attention scores matrix (reduced along query group dimension) for external Top-K selection
- q_unpad: Query tensor in unpadded format (bfloat16); a sketch for constructing these unpadded, variable-length inputs follows this list
- k_unpad, v_unpad: Key and Value tensors in unpadded format
- topk_idx: Integer tensor containing selected block indices from Stage 1
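For reference, the unpadded tensors and cumulative sequence lengths follow the FlashAttention-style varlen convention; the snippet below is one way to build them, with head counts, dtypes, and shapes chosen purely for illustration:

```python
import torch

# Illustrative construction of varlen (unpadded) inputs; head counts and shapes
# here are assumptions for the sketch, not requirements stated by the kernels.
seqlens = [1024, 4096, 2048]              # token count of each sequence in the batch
total_tokens = sum(seqlens)
n_q_heads, n_kv_heads, head_dim = 32, 2, 128

q_unpad = torch.randn(total_tokens, n_q_heads, head_dim,
                      dtype=torch.bfloat16, device="cuda")
k_unpad = torch.randn(total_tokens, n_kv_heads, head_dim,
                      dtype=torch.bfloat16, device="cuda")
v_unpad = torch.randn_like(k_unpad)

# Cumulative sequence lengths ([0, 1024, 5120, 7168] here), as int32 on the GPU
cu_seqlens_q = torch.tensor([0] + seqlens, device="cuda").cumsum(0).to(torch.int32)
cu_seqlens_k = cu_seqlens_q.clone()
max_seqlen_q = max_seqlen_k = max(seqlens)
```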
- The kernel automatically handles different GPU architectures (SM80/SM90)
- Optimized for batch processing with variable sequence lengths
- Memory efficient through unpadded tensor format and block-sparse patterns
- Supports bfloat16 precision for both stages
- SM 80: A100
- SM 90: H100
All benchmarks were conducted with the following configuration:
- GPU: NVIDIA H100
- Head Dimension: 128
- KV Heads: 2
- Query Heads: 32
- Block Size: 64
- Selected Blocks: 64
- Attention Type: Causal
| Sequence Length | Batch Size | Implementation | Forward (ms) | Backward (ms) | Combined (ms) | Speedup vs FlashAttention |
|---|---|---|---|---|---|---|
| 32,768 | 8 | FlashAttention | 201.46 | 526.62 | 728.08 | 1x |
| 32,768 | 8 | Triton NSA | 169.11 | 343.82 | 512.93 | 1.42x |
| 32,768 | 8 | InfLLMv2 | 133.60 | 330.04 | 463.64 | 1.57x |
| 65,536 | 4 | FlashAttention | 409.29 | 1037.46 | 1446.75 | 1x |
| 65,536 | 4 | Triton NSA | 181.88 | 469.00 | 650.88 | 2.22x |
| 65,536 | 4 | InfLLMv2 | 142.31 | 381.55 | 523.86 | 2.76x |
| 131,072 | 2 | FlashAttention | 831.77 | 2063.11 | 2894.88 | 1x |
| 131,072 | 2 | Triton NSA | 216.10 | 589.66 | 805.76 | 3.59x |
| 131,072 | 2 | InfLLMv2 | 158.42 | 468.90 | 627.32 | 4.61x |
If you use the InfLLM V2 CUDA kernels in your research, please cite:
@article{infllmv2,
title={InfLLM-V2: Dense-Sparse Switchable Attention for Seamless Short-to-Long Adaptation},
author={Zhao, Weilin and Zhou, Zihan and Su, Zhou and Xiao, Chaojun and Li, Yuxuan and Li, Yanghao and Zhang, Yudi and Zhao, Weilun and Li, Zhen and Huang, Yuxiang and Sun, Ao and Han, Xu and Liu, Zhiyuan},
journal={arXiv preprint arXiv:2509.24663},
year={2025}
}
@article{minicpm4,
title={MiniCPM4: Ultra-Efficient LLMs on End Devices},
author={MiniCPM Team},
year={2025}
}
- MiniCPM4: For model integration and testing
- FlashAttention: The foundational CUDA kernel architecture we built upon
- Block Sparse Attention: Inspiration for block-sparse kernel design
- This repository is released under the Apache-2.0 License.