Skip to content

[Perf][Qwen3-next]: torch.compile GDN attn #27152

@vadiklyutiy

Description

@vadiklyutiy

Proposal to improve performance

Currently, GDN attention (`Qwen3NextGatedDeltaNet`, used in Qwen3-Next) is not covered by torch.compile.
Unlike full attention, GDN consists of many operators, a large share of which are elementwise. An illustration is below:

Image

GDN attention is currently implemented as a custom op, so torch.compile does not trace into it.

I wrote the following script, which calls vLLM's GDN attention directly, to measure its performance.

Benchmark script
#!/usr/bin/env python3
"""
Standalone benchmark for Qwen3NextGatedDeltaNet._forward method
"""

import os
import sys
import time
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional, Dict, Any

# Add the current directory to Python path to import vllm modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Import necessary vLLM components
from vllm.config import ModelConfig, CacheConfig, VllmConfig
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.forward_context import ForwardContext, set_forward_context
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.platforms import current_platform
from vllm.attention.backends.utils import PAD_SLOT_ID


@dataclass
class BenchmarkConfig:
    """Settings for one benchmark run of the GDN attention layer.

    Fields are declared in a fixed order, so instances may also be built
    positionally: (batch_size, hidden_size, num_warmup_iters,
    num_bench_iters, device, dtype).
    """

    # Tokens pushed through the layer per iteration.
    batch_size: int = 8192
    # Width of each token's hidden vector; expected to match the model config.
    hidden_size: int = 2048
    # Iterations executed before timing starts.
    num_warmup_iters: int = 10
    # Iterations whose wall-clock time is recorded.
    num_bench_iters: int = 100
    # Torch device string used for inputs, weights and caches.
    device: str = "cuda"
    # Dtype for activations and (most) weights.
    dtype: torch.dtype = torch.bfloat16


def build_layer_types(num_hidden_layers: int, full_attention_interval: int) -> list[str]:
    """Return the attention kind of each layer in a Qwen3-Next style stack.

    Every ``full_attention_interval``-th layer (counting from 1) is
    ``"full_attention"``; all other layers are ``"linear_attention"``.

    Args:
        num_hidden_layers: Total number of decoder layers.
        full_attention_interval: Period of the full-attention layers.

    Returns:
        A list of length ``num_hidden_layers``.

    Raises:
        ValueError: If ``full_attention_interval`` is not positive.
    """
    if full_attention_interval <= 0:
        raise ValueError("full_attention_interval must be positive")
    return [
        "full_attention" if (i + 1) % full_attention_interval == 0 else "linear_attention"
        for i in range(num_hidden_layers)
    ]


def create_qwen3_config():
    """Create the Qwen3NextConfig used by the benchmark.

    The hyper-parameters are intended to match the Qwen3-Next-80B-A3B
    checkpoint. ``layer_types`` is derived from ``num_hidden_layers`` and
    ``full_attention_interval`` so the three values can never disagree. A
    hand-written list previously held 44 entries for a 48-layer model.

    Returns:
        A ``Qwen3NextConfig`` instance.
    """
    num_hidden_layers = 48
    full_attention_interval = 4

    config_dict = {
        "architectures": ["Qwen3NextForCausalLM"],
        "attention_bias": False,
        "attention_dropout": 0.0,
        "bos_token_id": 151643,
        "decoder_sparse_step": 1,
        "dtype": "bfloat16",
        "eos_token_id": 151645,
        "full_attention_interval": full_attention_interval,
        "head_dim": 256,
        "hidden_act": "silu",
        "hidden_size": 2048,
        "initializer_range": 0.02,
        "intermediate_size": 5120,
        # Three linear (GDN) layers followed by one full-attention layer, repeated.
        "layer_types": build_layer_types(num_hidden_layers, full_attention_interval),
        "linear_conv_kernel_dim": 4,
        "linear_key_head_dim": 128,
        "linear_num_key_heads": 16,
        "linear_num_value_heads": 32,
        "linear_value_head_dim": 128,
        "max_position_embeddings": 262144,
        "mlp_only_layers": [],
        "model_type": "qwen3_next",
        "moe_intermediate_size": 512,
        "norm_topk_prob": True,
        "num_attention_heads": 16,
        "num_experts": 512,
        "num_experts_per_tok": 10,
        "num_hidden_layers": num_hidden_layers,
        "num_key_value_heads": 2,
        "output_router_logits": False,
        "partial_rotary_factor": 0.25,
        "rms_norm_eps": 1e-06,
        "rope_scaling": None,
        "rope_theta": 10000000,
        "router_aux_loss_coef": 0.001,
        "shared_expert_intermediate_size": 512,
        "tie_word_embeddings": False,
        "transformers_version": "4.56.1",
        "use_cache": True,
        "use_sliding_window": False,
        "vocab_size": 151936,
        # NOTE(review): redundant with "dtype" above; presumably kept for
        # transformers versions that only read torch_dtype. Confirm before removing.
        "torch_dtype": torch.bfloat16,
    }

    return Qwen3NextConfig(**config_dict)


def create_mock_metadata(batch_size: int, device: str):
    """Build GDNAttentionMetadata describing a single all-prefill request.

    The whole batch is modelled as one prefill sequence of ``batch_size``
    tokens, with no decodes, no speculative tokens and no initial state.
    The causal_conv1d bookkeeping (``nums_dict``, ``batch_ptr``,
    ``token_chunk_offset_ptr``) is derived from the sequence boundaries.

    Args:
        batch_size: Number of prefill tokens in the single sequence.
        device: Device on which the tensors consumed by the kernels live.

    Returns:
        A populated ``GDNAttentionMetadata``.
    """
    # Sequence boundaries: one sequence covering tokens [0, batch_size).
    query_start_loc = torch.tensor([0, batch_size], dtype=torch.int32, device="cpu")
    seqlens = query_start_loc.diff()

    nums_dict = {}
    batch_ptr = None
    token_chunk_offset_ptr = None
    # Only BLOCK_M == 8 is precomputed.
    # NOTE(review): presumably the only tile size the causal_conv1d kernel
    # requests on this path. Confirm before adding more values.
    for block_m in (8,):
        # Number of block_m-sized chunks per sequence (ceiling division).
        nums = -(-seqlens // block_m)
        # For each chunk: index of the sequence it belongs to ...
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        # ... and the chunk's index within that sequence.
        offsetlist = torch.tensor(
            [offset for num in nums.tolist() for offset in range(num)],
            dtype=torch.int32,
        )
        mlist_len = len(mlist)
        max_num_programs = max(1024, mlist_len) * 2

        # Fixed-size program tables. The first mlist_len entries hold real
        # values and the remainder stays PAD_SLOT_ID.
        batch_ptr = torch.full(
            (max_num_programs,), PAD_SLOT_ID, dtype=torch.int32, device=device
        )
        token_chunk_offset_ptr = torch.full(
            (max_num_programs,), PAD_SLOT_ID, dtype=torch.int32, device=device
        )
        batch_ptr[:mlist_len].copy_(mlist.to(device))
        token_chunk_offset_ptr[:mlist_len].copy_(offsetlist.to(device))

        nums_dict[block_m] = {
            "nums": nums,
            "tot": nums.sum().item(),
            "mlist": mlist,
            "mlist_len": mlist_len,
            "offsetlist": offsetlist,
            "batch_ptr": batch_ptr,
            "token_chunk_offset_ptr": token_chunk_offset_ptr,
        }

    return GDNAttentionMetadata(
        num_prefills=1,  # all tokens belong to one prefill request
        num_prefill_tokens=batch_size,
        num_decodes=0,
        num_decode_tokens=0,
        num_spec_decodes=0,
        num_spec_decode_tokens=0,
        num_actual_tokens=batch_size,
        # One entry per prefill request; False means "start from an empty state".
        has_initial_state=torch.zeros(1, dtype=torch.bool, device=device),
        spec_query_start_loc=None,
        non_spec_query_start_loc=query_start_loc.to(device),
        spec_state_indices_tensor=None,
        # The single request uses state cache line 0.
        non_spec_state_indices_tensor=torch.zeros(1, dtype=torch.int32, device=device),
        spec_sequence_masks=None,
        spec_token_indx=None,
        non_spec_token_indx=None,
        num_accepted_tokens=None,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
    )


def setup_model(config: BenchmarkConfig):
    """Construct a standalone Qwen3NextGatedDeltaNet layer ready for ``_forward``.

    Steps, in order:
      1. Single-process distributed / model-parallel initialisation
         (skipped if torch.distributed is already initialised).
      2. Construction of the vLLM configs.
      3. Construction of the layer while ``vllm_config`` is the current config.
      4. Device / dtype placement, keeping ``A_log`` in float32.
      5. Allocation of a conv/SSM state cache attached as ``model.kv_cache``.

    Args:
        config: Benchmark settings (device and dtype are used here).

    Returns:
        ``(model, prefix, vllm_config)``. ``prefix`` is the layer name, used as
        the key for per-layer attention metadata.
    """
    from vllm.config import (
        CompilationConfig,
        ParallelConfig,
        SchedulerConfig,
        set_current_vllm_config,
    )

    # 1. Single-GPU distributed environment.
    if not torch.distributed.is_initialized():
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_RANK"] = "0"

        init_distributed_environment()
        initialize_model_parallel(1, 1)

    # 2. Configurations.
    qwen_config = create_qwen3_config()

    model_config = ModelConfig(
        model='Qwen/Qwen3-Next-80B-A3B-Instruct',
        tokenizer='Qwen/Qwen3-Next-80B-A3B-Instruct',
        tokenizer_mode='auto',
        trust_remote_code=False,
        dtype=config.dtype,
        max_model_len=262144,
    )

    cache_config = CacheConfig(
        block_size=272,
        gpu_memory_utilization=0.9,
        swap_space=4.0,
        cache_dtype='auto',
        mamba_page_size_padded=278528,
        mamba_cache_dtype='auto',
        mamba_ssm_cache_dtype='auto',
    )

    compilation_config = CompilationConfig()
    compilation_config.static_forward_context = {}

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=ParallelConfig(
            tensor_parallel_size=1,
            pipeline_parallel_size=1,
        ),
        scheduler_config=SchedulerConfig(),
        compilation_config=compilation_config,
    )

    # 3. Build the layer. set_current_vllm_config is a context manager, so
    # calling it without `with` never activates the config. Anything the
    # layer reads from the current config during __init__ must therefore be
    # constructed inside this block.
    prefix = "model.layers.0.linear_attn"
    with set_current_vllm_config(vllm_config):
        model = Qwen3NextGatedDeltaNet(
            config=qwen_config,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=None,
            speculative_config=None,
            prefix=prefix,
        )

    # 4. Placement. Move to the device first and set dtypes per parameter.
    # A blanket model.to(dtype=...) would also downcast A_log, which is meant
    # to stay in float32.
    model = model.to(config.device)
    for name, param in model.named_parameters():
        param.data = param.data.to(torch.float32 if "A_log" in name else config.dtype)
    # Floating-point buffers (if any) follow the model dtype.
    for buf in model.buffers():
        if buf.is_floating_point():
            buf.data = buf.data.to(config.dtype)
    model.eval()

    # 5. State cache. get_state_shape() gives per-sequence shapes; a leading
    # cache-line dimension is prepended here.
    conv_state_shape, ssm_state_shape = model.get_state_shape()
    conv_state_dtype, ssm_state_dtype = model.get_state_dtype()

    print(f"Conv state shape (per sequence): {conv_state_shape}")
    print(f"SSM state shape (per sequence): {ssm_state_shape}")

    num_cache_lines = 128
    # Stored as (cache_line, conv_state_shape[0], conv_state_shape[1]).
    # NOTE(review): the layer is expected to read this through
    # .transpose(-1, -2), i.e. as (cache_line, dim, width - 1). Confirm
    # against _forward.
    conv_state_storage = torch.zeros(
        (num_cache_lines, conv_state_shape[0], conv_state_shape[1]),
        dtype=conv_state_dtype,
        device=config.device,
    )
    ssm_state = torch.zeros(
        (num_cache_lines,) + ssm_state_shape,
        dtype=ssm_state_dtype,
        device=config.device,
    )
    # Keyed by virtual engine; the benchmark only uses virtual engine 0.
    model.kv_cache = {0: [conv_state_storage, ssm_state]}

    return model, prefix, vllm_config


def _print_benchmark_results(config: "BenchmarkConfig", times: list) -> None:
    """Print the run configuration plus latency/throughput statistics.

    Args:
        config: The benchmark settings that produced ``times``.
        times: Per-iteration wall-clock durations in seconds.
    """
    if not times:
        print("No benchmark iterations were run; nothing to report.")
        return

    samples = np.asarray(times)
    mean_time = np.mean(samples)
    std_time = np.std(samples)
    min_time = np.min(samples)
    max_time = np.max(samples)
    median_time = np.median(samples)

    print("\n" + "=" * 50)
    print("Benchmark Results")
    print("=" * 50)
    print("Configuration:")
    print(f"  - Batch size: {config.batch_size}")
    print(f"  - Hidden size: {config.hidden_size}")
    print(f"  - Device: {config.device}")
    print(f"  - Dtype: {config.dtype}")
    print(f"  - Iterations: {config.num_bench_iters}")
    print("\nTiming Statistics:")
    print(f"  - Mean time: {mean_time*1000:.3f} ms")
    print(f"  - Std dev: {std_time*1000:.3f} ms")
    print(f"  - Min time: {min_time*1000:.3f} ms")
    print(f"  - Max time: {max_time*1000:.3f} ms")
    print(f"  - Median time: {median_time*1000:.3f} ms")
    print("\nThroughput:")
    print(f"  - Tokens/sec: {config.batch_size / mean_time:.1f}")
    print("=" * 50)


def run_benchmark(config: BenchmarkConfig):
    """Time ``Qwen3NextGatedDeltaNet._forward`` on one prefill batch and print stats.

    Runs ``config.num_warmup_iters`` untimed iterations, then
    ``config.num_bench_iters`` timed ones. Each timed iteration includes
    entering the forward context and is bracketed by device synchronisation
    on CUDA devices.

    Args:
        config: Benchmark settings.
    """
    print("Setting up model...")
    model, prefix, vllm_config = setup_model(config)

    hidden_states = torch.randn(
        config.batch_size, config.hidden_size,
        dtype=config.dtype, device=config.device,
    )
    output = torch.zeros_like(hidden_states)

    # Attention metadata is looked up by layer prefix.
    attn_metadata = {prefix: create_mock_metadata(config.batch_size, config.device)}
    virtual_engine = 0
    # torch.device(...).type also recognises indexed devices such as "cuda:1".
    is_cuda = torch.device(config.device).type == "cuda"

    def _sync():
        # Wait for queued kernels so timings cover the actual GPU work.
        if is_cuda:
            torch.cuda.synchronize()

    print(f"Running warmup iterations ({config.num_warmup_iters})...")
    with torch.no_grad():
        for _ in range(config.num_warmup_iters):
            with set_forward_context(attn_metadata, vllm_config,
                                     virtual_engine=virtual_engine):
                model._forward(hidden_states, output)

    _sync()

    print(f"Running benchmark iterations ({config.num_bench_iters})...")
    times = []
    with torch.no_grad():
        for _ in range(config.num_bench_iters):
            _sync()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            with set_forward_context(attn_metadata, vllm_config,
                                     virtual_engine=virtual_engine):
                model._forward(hidden_states, output)
            _sync()
            times.append(time.perf_counter() - start_time)

    _print_benchmark_results(config, times)


def main():
    """Parse CLI arguments, build a BenchmarkConfig and run the benchmark.

    Falls back to CPU, with a warning, when a non-CPU device is requested but
    CUDA is unavailable. Exits with status 1 if the benchmark raises, so
    scripted callers can detect failures.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description="Benchmark Qwen3NextGatedDeltaNet._forward")
    parser.add_argument("--batch-size", type=int, default=8192, help="Batch size")
    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations")
    parser.add_argument("--iterations", type=int, default=100, help="Number of benchmark iterations")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
    parser.add_argument("--test", action="store_true", help="Run a quick test with small batch size")
    args = parser.parse_args()

    if args.test:
        print("Running in test mode with reduced parameters...")
        batch_size, num_warmup_iters, num_bench_iters = 32, 2, 5
    else:
        batch_size, num_warmup_iters, num_bench_iters = args.batch_size, args.warmup, args.iterations

    # Decide the device once, so the fallback is reported rather than silent.
    device = args.device
    if device != "cpu" and not torch.cuda.is_available():
        print("WARNING: CUDA is not available, falling back to CPU")
        device = "cpu"

    config = BenchmarkConfig(
        batch_size=batch_size,
        hidden_size=2048,
        num_warmup_iters=num_warmup_iters,
        num_bench_iters=num_bench_iters,
        device=device,
        dtype=torch.bfloat16,
    )

    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    try:
        run_benchmark(config)
    except Exception as e:
        # Top-level boundary: report the failure, then exit non-zero.
        print(f"\nError occurred: {e}")
        traceback.print_exc()
        print("\nTip: Make sure you have activated the virtual environment and all dependencies are installed.")
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point; main() returns None, so a normal run exits with status 0.
    raise SystemExit(main())

Wrapping `Qwen3NextGatedDeltaNet._forward` in `torch.compile(mode="max-autotune-no-cudagraphs")` gives a 14% performance improvement for this layer, which corresponds to a 4–5% end-to-end improvement on Qwen3-Next.

Metadata

Metadata

Assignees

No one assigned

    Labels

    performance — Performance-related issues

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions