Proposal to improve performance
Right now GDN attention (Qwen3NextGatedDeltaNet, used in Qwen3-Next) is not covered by torch.compile.
Unlike full attention, GDN consists of many operators, a large fraction of them elementwise (see the illustration below).
Right now GDN attention is implemented as a custom op, so torch.compile does not trace into it; the toy sketch after this paragraph illustrates the opaque-op boundary.
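To make that boundary concrete, here is a minimal, hypothetical sketch (not vLLM code; the op name and body are made up) showing how a custom op hides its internal elementwise chain from torch.compile, which only sees a single opaque node:

# Hypothetical illustration only: a custom op whose internals torch.compile cannot fuse.
import torch

@torch.library.custom_op("demo::gdn_like_elementwise", mutates_args=())
def gdn_like_elementwise(x: torch.Tensor) -> torch.Tensor:
    # Many small elementwise ops, similar in spirit to the GDN gating math.
    return torch.sigmoid(x) * torch.tanh(x) + torch.exp(-x)

@gdn_like_elementwise.register_fake
def _(x):
    return torch.empty_like(x)

def wrapper(x):
    # torch.compile treats demo::gdn_like_elementwise as one opaque call,
    # so nothing inside the op participates in fusion.
    return gdn_like_elementwise(x) + 1.0

compiled = torch.compile(wrapper)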
I wrote the following script, which calls vLLM's GDN attention and measures its performance.
Benchmark script
#!/usr/bin/env python3
"""
Standalone benchmark for Qwen3NextGatedDeltaNet._forward method
"""
import os
import sys
import time
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional, Dict, Any
# Add the current directory to Python path to import vllm modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import necessary vLLM components
from vllm.config import ModelConfig, CacheConfig, VllmConfig
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
from vllm.forward_context import ForwardContext, set_forward_context
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.platforms import current_platform
from vllm.attention.backends.utils import PAD_SLOT_ID
@dataclass
class BenchmarkConfig:
"""Configuration for the benchmark"""
batch_size: int = 8192
hidden_size: int = 2048
num_warmup_iters: int = 10
num_bench_iters: int = 100
device: str = "cuda"
dtype: torch.dtype = torch.bfloat16
def create_qwen3_config():
"""Create Qwen3NextConfig with the exact parameters provided"""
config_dict = {
"architectures": ["Qwen3NextForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"decoder_sparse_step": 1,
"dtype": "bfloat16",
"eos_token_id": 151645,
"full_attention_interval": 4,
"head_dim": 256,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_types": [
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention",
"linear_attention", "linear_attention", "linear_attention", "full_attention"
],
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 128,
"linear_num_key_heads": 16,
"linear_num_value_heads": 32,
"linear_value_head_dim": 128,
"max_position_embeddings": 262144,
"mlp_only_layers": [],
"model_type": "qwen3_next",
"moe_intermediate_size": 512,
"norm_topk_prob": True,
"num_attention_heads": 16,
"num_experts": 512,
"num_experts_per_tok": 10,
"num_hidden_layers": 48,
"num_key_value_heads": 2,
"output_router_logits": False,
"partial_rotary_factor": 0.25,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.001,
"shared_expert_intermediate_size": 512,
"tie_word_embeddings": False,
"transformers_version": "4.56.1",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936,
"torch_dtype": torch.bfloat16
}
return Qwen3NextConfig(**config_dict)
def create_mock_metadata(batch_size: int, device: str):
"""Create mock GDNAttentionMetadata for testing"""
# Create metadata with all required fields
# Create nums_dict for causal_conv1d
query_start_loc = torch.tensor([0, batch_size], dtype=torch.int32, device='cpu')
seqlens = query_start_loc.diff() # Single sequence of length batch_size
nums_dict = {}
for BLOCK_M in [8]:  # only BLOCK_M=8 is exercised in this benchmark
nums = -(-seqlens // BLOCK_M) # Ceiling division
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = []
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
# Create batch_ptr and token_chunk_offset_ptr
batch_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID,
dtype=torch.int32, device=device)
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS,), PAD_SLOT_ID,
dtype=torch.int32, device=device)
batch_ptr[0:mlist_len].copy_(mlist.to(device))
token_chunk_offset_ptr[0:mlist_len].copy_(offsetlist.to(device))
nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = token_chunk_offset_ptr
metadata = GDNAttentionMetadata(
num_prefills=1, # Treating all as one prefill request
num_prefill_tokens=batch_size,
num_decodes=0,
num_decode_tokens=0,
num_spec_decodes=0,
num_spec_decode_tokens=0,
num_actual_tokens=batch_size,
has_initial_state=torch.zeros(1, dtype=torch.bool, device=device), # No initial state for prefill
spec_query_start_loc=None,
non_spec_query_start_loc=torch.tensor([0, batch_size], dtype=torch.int32, device=device),
spec_state_indices_tensor=None,
non_spec_state_indices_tensor=torch.zeros(1, dtype=torch.int32, device=device),
spec_sequence_masks=None,
spec_token_indx=None,
non_spec_token_indx=None,
num_accepted_tokens=None,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return metadata
def setup_model(config: BenchmarkConfig):
"""Initialize the model and required components"""
# Initialize distributed environment (single GPU)
if not torch.distributed.is_initialized():
# Set required environment variables for single GPU setup
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
init_distributed_environment()
initialize_model_parallel(1, 1)
# Create configurations
qwen_config = create_qwen3_config()
# Create model config
model_config = ModelConfig(
model='Qwen/Qwen3-Next-80B-A3B-Instruct',
tokenizer='Qwen/Qwen3-Next-80B-A3B-Instruct',
tokenizer_mode='auto',
trust_remote_code=False,
dtype=config.dtype,
max_model_len=262144,
)
# Create cache config
cache_config = CacheConfig(
block_size=272,
gpu_memory_utilization=0.9,
swap_space=4.0,
cache_dtype='auto',
mamba_page_size_padded=278528,
mamba_cache_dtype='auto',
mamba_ssm_cache_dtype='auto',
)
# Create VllmConfig with compilation config that has static_forward_context
from vllm.config import CompilationConfig, ParallelConfig, SchedulerConfig, VllmConfig as VllmConfigClass
compilation_config = CompilationConfig()
compilation_config.static_forward_context = {}
parallel_config = ParallelConfig(
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
scheduler_config = SchedulerConfig()
# Create VllmConfig properly
vllm_config = VllmConfigClass(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
compilation_config=compilation_config,
)
# Set the current vllm config
from vllm.config import set_current_vllm_config
set_current_vllm_config(vllm_config)
# Initialize the model
prefix = "model.layers.0.linear_attn"
model = Qwen3NextGatedDeltaNet(
config=qwen_config,
model_config=model_config,
cache_config=cache_config,
quant_config=None,
speculative_config=None,
prefix=prefix
)
# Move model to device and dtype
model = model.to(config.device, dtype=config.dtype)
model.eval()
# Ensure all parameters are in the correct dtype
# Note: A_log stays in float32 per design
for name, param in model.named_parameters():
if "A_log" not in name:
param.data = param.data.to(config.dtype)
# Initialize KV cache
conv_state_shape, ssm_state_shape = model.get_state_shape()
conv_state_dtype, ssm_state_dtype = model.get_state_dtype()
print(f"Conv state shape (per sequence): {conv_state_shape}")
print(f"SSM state shape (per sequence): {ssm_state_shape}")
# Assuming single virtual engine
model.kv_cache = {}
# Create cache with proper dimensions
# Need to add batch dimension and transpose conv_state as it's accessed as .transpose(-1, -2)
max_batch_size = 128 # Max cache lines - use a reasonable number
# Conv state storage shape: (batch, width-1, dim) because it's accessed as .transpose(-1, -2)
# When transposed it becomes (batch, dim, width-1) which is what causal_conv1d_fn expects
conv_state_storage = torch.zeros(
(max_batch_size, conv_state_shape[0], conv_state_shape[1]),
dtype=conv_state_dtype,
device=config.device
)
ssm_state = torch.zeros(
(max_batch_size,) + ssm_state_shape,
dtype=ssm_state_dtype,
device=config.device
)
model.kv_cache[0] = [conv_state_storage, ssm_state]
return model, prefix, vllm_config
def run_benchmark(config: BenchmarkConfig):
"""Run the benchmark"""
print(f"Setting up model...")
model, prefix, vllm_config = setup_model(config)
# Create input tensors
hidden_states = torch.randn(
config.batch_size, config.hidden_size,
dtype=config.dtype, device=config.device
)
output = torch.zeros_like(hidden_states)
# Create mock metadata
metadata = create_mock_metadata(config.batch_size, config.device)
# Create forward context with no_compile_layers
forward_context = ForwardContext(
attn_metadata={prefix: metadata},
virtual_engine=0,
no_compile_layers={prefix: model},
)
print(f"Running warmup iterations ({config.num_warmup_iters})...")
# Warmup
with torch.no_grad():
for _ in range(config.num_warmup_iters):
with set_forward_context(forward_context.attn_metadata, vllm_config,
virtual_engine=forward_context.virtual_engine):
model._forward(hidden_states, output)
# Synchronize before timing
if config.device == "cuda":
torch.cuda.synchronize()
print(f"Running benchmark iterations ({config.num_bench_iters})...")
# Benchmark
times = []
with torch.no_grad():
for _ in range(config.num_bench_iters):
if config.device == "cuda":
torch.cuda.synchronize()
start_time = time.time()
with set_forward_context(forward_context.attn_metadata, vllm_config,
virtual_engine=forward_context.virtual_engine):
model._forward(hidden_states, output)
if config.device == "cuda":
torch.cuda.synchronize()
end_time = time.time()
times.append(end_time - start_time)
# Calculate statistics
times = np.array(times)
mean_time = np.mean(times)
std_time = np.std(times)
min_time = np.min(times)
max_time = np.max(times)
median_time = np.median(times)
print("\n" + "="*50)
print("Benchmark Results")
print("="*50)
print(f"Configuration:")
print(f" - Batch size: {config.batch_size}")
print(f" - Hidden size: {config.hidden_size}")
print(f" - Device: {config.device}")
print(f" - Dtype: {config.dtype}")
print(f" - Iterations: {config.num_bench_iters}")
print(f"\nTiming Statistics:")
print(f" - Mean time: {mean_time*1000:.3f} ms")
print(f" - Std dev: {std_time*1000:.3f} ms")
print(f" - Min time: {min_time*1000:.3f} ms")
print(f" - Max time: {max_time*1000:.3f} ms")
print(f" - Median time: {median_time*1000:.3f} ms")
print(f"\nThroughput:")
print(f" - Tokens/sec: {config.batch_size / mean_time:.1f}")
print("="*50)
def main():
"""Main function"""
import argparse
parser = argparse.ArgumentParser(description="Benchmark Qwen3NextGatedDeltaNet._forward")
parser.add_argument("--batch-size", type=int, default=8192, help="Batch size")
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations")
parser.add_argument("--iterations", type=int, default=100, help="Number of benchmark iterations")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
parser.add_argument("--test", action="store_true", help="Run a quick test with small batch size")
args = parser.parse_args()
# Activate the virtual environment if needed
venv_path = "/home/vgimpelson/1/venv_qwen2/bin/activate"
if os.path.exists(venv_path):
print(f"Note: You may need to activate the virtual environment: source {venv_path}")
# Create benchmark configuration
if args.test:
print("Running in test mode with reduced parameters...")
config = BenchmarkConfig(
batch_size=32,
hidden_size=2048,
num_warmup_iters=2,
num_bench_iters=5,
device=args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu",
dtype=torch.bfloat16
)
else:
config = BenchmarkConfig(
batch_size=args.batch_size,
hidden_size=2048,
num_warmup_iters=args.warmup,
num_bench_iters=args.iterations,
device=args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu",
dtype=torch.bfloat16
)
# Check if CUDA is available
if not torch.cuda.is_available() and config.device == "cuda":
print("WARNING: CUDA is not available, falling back to CPU")
config.device = "cpu"
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
try:
# Run the benchmark
run_benchmark(config)
except Exception as e:
print(f"\nError occurred: {e}")
import traceback
traceback.print_exc()
print("\nTip: Make sure you have activated the virtual environment and all dependencies are installed.")
if __name__ == "__main__":
main()
Covering Qwen3NextGatedDeltaNet._forward with torch.compile(mode="max-autotune-no-cudagraphs") gives a 14% performance improvement on this microbenchmark. That corresponds to roughly 4-5% end-to-end performance on Qwen3-Next.
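For reference, the compiled variant can be measured with a small extension of run_benchmark() in the script above. This is only a sketch under the assumption that _forward is traceable when called directly; a real integration would go through vLLM's compilation config rather than wrapping the method ad hoc:

# Hypothetical addition to run_benchmark(): compare eager vs. compiled _forward.
compiled_forward = torch.compile(model._forward, mode="max-autotune-no-cudagraphs")

with torch.no_grad():
    # Warmup also triggers compilation.
    for _ in range(config.num_warmup_iters):
        with set_forward_context(forward_context.attn_metadata, vllm_config,
                                 virtual_engine=forward_context.virtual_engine):
            compiled_forward(hidden_states, output)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(config.num_bench_iters):
        with set_forward_context(forward_context.attn_metadata, vllm_config,
                                 virtual_engine=forward_context.virtual_engine):
            compiled_forward(hidden_states, output)
    torch.cuda.synchronize()
    mean_ms = (time.time() - start) / config.num_bench_iters * 1000
    print(f"Compiled mean time: {mean_ms:.3f} ms")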