[Performance]: non-optimal performance of linear for small batches #27173

@vadiklyutiy

Proposal to improve performance

Originally found on Qwen3-next with small batch sizes, but this should be relevant for other models as well.

For batch sizes <= 32, torch's linear implementation isn't optimal. Below is a comparison with the implementation in FlashInfer. I used the input/output feature counts from an actual GEMM inside the GDN attention of Qwen3-next (in_features=4096, out_features=2048, bfloat16), measured on a B200 GPU.

batch=1
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         11.857920       0.017042     1414.85      1415.89      1.00x
2. torch.compile()                  11.852800       0.010392     1415.46      1416.50      1.00x
3. max-autotune ncg                 4.402560        0.007621     3810.79      3813.58      2.69x
4. TGV GEMM pdl=False               6.794880        0.009877     2469.10      2470.91      1.75x
5. TGV GEMM pdl=True                6.135360        0.014390     2734.51      2736.51      1.93x

batch=2
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         10.817280       0.012709     3101.93      1553.24      1.00x
2. torch.compile()                  10.800641       0.009125     3106.71      1555.63      1.00x
3. max-autotune ncg                 10.811200       0.012217     3103.67      1554.11      1.00x
4. TGV GEMM pdl=False               6.807360        0.004601     4929.14      2468.18      1.59x
5. TGV GEMM pdl=True                6.158080        0.008603     5448.85      2728.41      1.76x

batch=4
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         10.827520       0.014918     6197.99      1554.04      1.00x
2. torch.compile()                  10.795200       0.017265     6216.55      1558.69      1.00x
3. max-autotune ncg                 10.801920       0.013969     6212.68      1557.72      1.00x
4. TGV GEMM pdl=False               6.810560        0.008674     9853.65      2470.63      1.59x
5. TGV GEMM pdl=True                6.157760        0.005177     10898.26     2732.55      1.76x

batch=8
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         10.867200       0.018683     12350.72     1552.89      1.00x
2. torch.compile()                  10.858560       0.012591     12360.55     1554.12      1.00x
3. max-autotune ncg                 10.853760       0.012731     12366.01     1554.81      1.00x
4. TGV GEMM pdl=False               6.810880        0.009117     19706.37     2477.73      1.60x
5. TGV GEMM pdl=True                6.153600        0.007368     21811.25     2742.38      1.77x

batch=16
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.090240        0.012524     33180.16     2098.06      1.00x
2. torch.compile()                  8.095040        0.007869     33160.49     2096.82      1.00x
3. max-autotune ncg                 8.110400        0.009816     33097.68     2092.85      1.00x
4. TGV GEMM pdl=False               6.833920        0.012704     39279.87     2483.76      1.18x
5. TGV GEMM pdl=True                6.172800        0.008291     43486.82     2749.78      1.31x

batch=32
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.128000        0.007976     66052.03     2112.50      1.00x
2. torch.compile()                  8.121600        0.011911     66104.08     2114.17      1.00x
3. max-autotune ncg                 8.141440        0.020165     65942.99     2109.02      1.00x
4. TGV GEMM pdl=False               7.436160        0.005175     72197.33     2309.05      1.09x
5. TGV GEMM pdl=True                6.739840        0.016613     79656.33     2547.60      1.21x

batch=64
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.212160        0.013407     130750.23    2138.74      1.00x
2. torch.compile()                  8.222400        0.034535     130587.40    2136.07      1.00x
3. max-autotune ncg                 8.216000        0.016302     130689.12    2137.74      1.00x
4. TGV GEMM pdl=False               13.780160       0.017629     77919.40     1274.56      0.60x
5. TGV GEMM pdl=True                13.060160       0.030553     82215.06     1344.83      0.63x
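To make the crossover easier to see, the medians above can be reduced to the fastest variant per batch size. This is only a sketch that re-uses the numbers from the tables; the dictionary layout and the helper names `fastest_variant` and `summarize` are illustrative and not part of the benchmark script.

```python
# Median per-iteration times in microseconds, copied from the tables above.
VARIANTS = [
    "Original",
    "torch.compile()",
    "max-autotune ncg",
    "TGV GEMM pdl=False",
    "TGV GEMM pdl=True",
]

MEDIAN_US = {
    1:  [11.857920, 11.852800, 4.402560, 6.794880, 6.135360],
    2:  [10.817280, 10.800641, 10.811200, 6.807360, 6.158080],
    4:  [10.827520, 10.795200, 10.801920, 6.810560, 6.157760],
    8:  [10.867200, 10.858560, 10.853760, 6.810880, 6.153600],
    16: [8.090240, 8.095040, 8.110400, 6.833920, 6.172800],
    32: [8.128000, 8.121600, 8.141440, 7.436160, 6.739840],
    64: [8.212160, 8.222400, 8.216000, 13.780160, 13.060160],
}


def fastest_variant(times_us):
    """Return (variant name, speedup over Original) for the lowest median time."""
    best = min(range(len(times_us)), key=times_us.__getitem__)
    return VARIANTS[best], times_us[0] / times_us[best]


def summarize():
    """Map each batch size to its fastest variant and speedup."""
    return {batch: fastest_variant(times) for batch, times in MEDIAN_US.items()}


for batch, (name, speedup) in summarize().items():
    print(f"batch={batch:<3} fastest: {name:<20} {speedup:.2f}x vs Original")
```

With these numbers, max-autotune is fastest at batch=1, TGV GEMM with pdl=True is fastest for batch 2 through 32, and the original F.linear is fastest at batch=64.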

I used the following script to measure performance:

benchmark_linear.py
#!/usr/bin/env python3
"""
Benchmark for torch.nn.functional.linear on NVIDIA GPU
Compares five variants:
1. Original torch.nn.functional.linear
2. With torch.compile
3. With torch.compile(mode="max-autotune-no-cudagraphs")
4. TGV GEMM (FlashInfer SM100)
5. TGV GEMM (FlashInfer SM100) with pdl=True

Uses bfloat16 precision
Measures performance using CUDA events with 100 repetitions
Uses CUDA graphs to capture and replay the 100 iterations
Each benchmark runs 5 times and reports median and std
By default, weight buffers are cloned for each iteration to simulate different memory reads
(can be disabled with use_separate_weight_buffers=False)
"""

import statistics

import torch
import torch.nn.functional as F
from flashinfer import tgv_gemm_sm100


def run_benchmark(linear_func, name, x, weight, bias, num_iterations=100, warmup=10, use_separate_weight_buffers=True):
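    """Run benchmark for a given linear function.

    Args:
        linear_func: Callable with signature (x, weight, bias) to benchmark.
        name: Display label for the variant (not used inside this function).
        num_iterations: Number of calls captured into one CUDA graph.
        warmup: Number of eager warm-up calls before graph capture.
        use_separate_weight_buffers: If True, creates separate weight buffer for each iteration in CUDA graph.
                                      If False, uses the same weight buffer for all iterations.

    Returns:
        Tuple of (time per iteration in us, std per iteration in us, GFLOPS, bandwidth in GB/s).
    """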
    
    # Warm-up iterations
    for _ in range(warmup):
        _ = linear_func(x, weight, bias)
    
    torch.cuda.synchronize()
    
    if use_separate_weight_buffers:
        # Create multiple weight buffers - one for each iteration
        weight_buffers = []
        for _ in range(num_iterations):
            # Create a new weight buffer with the same data but different memory location
            weight_buffer = weight.clone().detach()
            weight_buffers.append(weight_buffer)
    else:
        # Use the same weight for all iterations
        weight_buffers = [weight] * num_iterations
    
    # Create a static buffer for output to ensure graph can be captured
    static_output = None
    
    # Capture the graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for i in range(num_iterations):
            # Use weight buffer for each iteration
            static_output = linear_func(x, weight_buffers[i], bias)
    
    torch.cuda.synchronize()
    
    # Warmup CUDA graph - replay 3 times
    for _ in range(3):
        graph.replay()
    torch.cuda.synchronize()
    
    # Benchmark with CUDA events - replay the graph 5 times
    num_benchmark_runs = 5
    elapsed_times_us = []
    
    for _ in range(num_benchmark_runs):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        
        start_event.record()
        graph.replay()
        end_event.record()
        
        torch.cuda.synchronize()
        
        elapsed_time_ms = start_event.elapsed_time(end_event)
        elapsed_times_us.append(elapsed_time_ms * 1000)  # Convert to microseconds
    
    output = static_output
    
    # Clean up weight buffers (only if we created separate ones)
    if use_separate_weight_buffers:
        del weight_buffers
    
    # Calculate median and std of elapsed times
    median_time_us = statistics.median(elapsed_times_us)
    std_time_us = statistics.stdev(elapsed_times_us) if len(elapsed_times_us) > 1 else 0.0
    avg_time_us = median_time_us / num_iterations
    
    # Calculate FLOPs
    batch_size = x.shape[0]
    in_features = x.shape[1]
    out_features = weight.shape[0]
    flops_per_iteration = 2 * batch_size * out_features * in_features
    gflops = (flops_per_iteration / (avg_time_us * 1e-6)) / 1e9
    
    # Calculate Memory Bandwidth
    # Bytes read/written per iteration:
    # - Read x: batch_size * in_features * dtype_size
    # - Read weight: out_features * in_features * dtype_size
    # - Write output: batch_size * out_features * dtype_size
    # - Read bias (if present): out_features * dtype_size
    dtype_size = x.element_size()  # bytes per element
    bytes_read_x = batch_size * in_features * dtype_size
    bytes_read_weight = out_features * in_features * dtype_size
    bytes_write_output = batch_size * out_features * dtype_size
    bytes_read_bias = out_features * dtype_size if bias is not None else 0
    
    total_bytes = bytes_read_x + bytes_read_weight + bytes_write_output + bytes_read_bias
    bandwidth_gb_s = (total_bytes / (avg_time_us * 1e-6)) / 1e9
    
    # Calculate std per iteration
    std_time_per_iter_us = std_time_us / num_iterations
    
    return avg_time_us, std_time_per_iter_us, gflops, bandwidth_gb_s


def benchmark_linear():
    """Benchmark torch.nn.functional.linear with different batch sizes"""
    
    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires an NVIDIA GPU.")
        return
    
    device = torch.device("cuda")
    in_features = 4096
    out_features = 2048
    dtype = torch.bfloat16
    
    # Create weight matrix (shared across all batch sizes)
    weight = torch.randn(out_features, in_features, device=device, dtype=dtype)
    bias = None
    
    # Define functions
    def linear_original(x, weight, bias):
        return F.linear(x, weight, bias)
    
    # Compile functions once
    linear_compiled = torch.compile(linear_original)
    linear_max_autotune = torch.compile(linear_original, mode="max-autotune-no-cudagraphs")
    
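    # Prepare TGV GEMM weight
    weight_tgv = weight.clone().contiguous().t()
    bias_tgv = None
    
    # Note: both TGV wrappers ignore their `weight`/`bias` arguments and always use
    # `weight_tgv`/`bias_tgv`, so the per-iteration weight buffers created in
    # run_benchmark() are not read by these two variants.
    def linear_tgv(x, weight, bias):
        return tgv_gemm_sm100(x, weight_tgv, bias_tgv, pdl=False)
    
    def linear_tgv_pdl(x, weight, bias):
        return tgv_gemm_sm100(x, weight_tgv, bias_tgv, pdl=True)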
    
    # Initial warmup
    x_warmup = torch.randn(8192, in_features, device=device, dtype=dtype)
    torch._dynamo.mark_dynamic(x_warmup, 0)

    for func in [linear_original, linear_compiled, linear_max_autotune, linear_tgv, linear_tgv_pdl]:
        _ = func(x_warmup, weight, bias)
    torch.cuda.synchronize()
    del x_warmup
    
    # Batch sizes to test (powers of 2 from 1 to 64)
    batch_sizes = [2**i for i in range(7)]  # 1, 2, 4, 8, 16, 32, 64
    
    for batch_size in batch_sizes:
        # Create input tensor for this batch size
        x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
        
        # Run benchmarks
        # Note: Add use_separate_weight_buffers=False to use same weight buffer for all iterations
        time1, std1, gflops1, bw1 = run_benchmark(
            linear_original, 
            "1. Original",
            x, weight, bias
        )
        
        time2, std2, gflops2, bw2 = run_benchmark(
            linear_compiled,
            "2. torch.compile()",
            x, weight, bias
        )
        
        time3, std3, gflops3, bw3 = run_benchmark(
            linear_max_autotune,
            "3. max-autotune ncg",
            x, weight, bias
        )
        
        time4, std4, gflops4, bw4 = run_benchmark(
            linear_tgv,
            "4. TGV GEMM pdl=False",
            x, weight, bias
        )
        
        time5, std5, gflops5, bw5 = run_benchmark(
            linear_tgv_pdl,
            "5. TGV GEMM pdl=True",
            x, weight, bias
        )
        
        # Print summary
        print(f"batch={batch_size}")
        print(f"{'='*110}")
        print("SUMMARY COMPARISON")
        print(f"{'='*110}")
        print(f"{'Variant':<35} {'Median (us)':<15} {'Std (us)':<12} {'GFLOPS':<12} {'BW (GB/s)':<12} {'Speedup'}")
        print(f"{'-'*110}")
        print(f"{'1. Original':<35} {time1:<15.6f} {std1:<12.6f} {gflops1:<12.2f} {bw1:<12.2f} {1.0:.2f}x")
        print(f"{'2. torch.compile()':<35} {time2:<15.6f} {std2:<12.6f} {gflops2:<12.2f} {bw2:<12.2f} {time1/time2:.2f}x")
        print(f"{'3. max-autotune ncg':<35} {time3:<15.6f} {std3:<12.6f} {gflops3:<12.2f} {bw3:<12.2f} {time1/time3:.2f}x")
        print(f"{'4. TGV GEMM pdl=False':<35} {time4:<15.6f} {std4:<12.6f} {gflops4:<12.2f} {bw4:<12.2f} {time1/time4:.2f}x")
        print(f"{'5. TGV GEMM pdl=True':<35} {time5:<15.6f} {std5:<12.6f} {gflops5:<12.2f} {bw5:<12.2f} {time1/time5:.2f}x")
        print()  # Empty line between batch sizes
        
        del x


if __name__ == "__main__":
    benchmark_linear()
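The GFLOPS and BW columns in the tables follow directly from the formulas in run_benchmark(). As a cross-check, here is a minimal standalone sketch of the same arithmetic for the bias-free case; the helper name `gemm_metrics` is illustrative, and `dtype_size=2` assumes bfloat16 as in the script.

```python
def gemm_metrics(batch_size, in_features, out_features, time_us, dtype_size=2):
    """Return (GFLOPS, GB/s) for one bias-free linear call taking time_us microseconds."""
    # One multiply and one add per (batch, out, in) triple.
    flops = 2 * batch_size * out_features * in_features
    # Same traffic model as the script: read x, read weight, write output.
    bytes_moved = dtype_size * (
        batch_size * in_features
        + out_features * in_features
        + batch_size * out_features
    )
    seconds = time_us * 1e-6
    return flops / seconds / 1e9, bytes_moved / seconds / 1e9


# batch=1, "1. Original" row: median 11.857920 us
gflops, bw = gemm_metrics(1, 4096, 2048, 11.857920)
print(f"{gflops:.2f} GFLOPS, {bw:.2f} GB/s")  # 1414.85 GFLOPS, 1415.89 GB/s
```

At batch=1 this works out to roughly 1 FLOP per byte moved, with the 16 MiB weight matrix accounting for almost all of the traffic, which suggests these shapes are limited by memory bandwidth rather than compute.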

Solution

I see two ways to solve this:

  1. Torch has a tuning function; this might be fixed there.
  2. Implement fp16 GEMM support in FlashInfer (right now there is no such function).
