speed of light analysis #60

Closed · wants to merge 6 commits
141 changes: 136 additions & 5 deletions BackendBench/eval.py
@@ -1,9 +1,13 @@
import logging
import warnings

import torch
from torch.utils.flop_counter import FlopCounterMode

-import triton.testing
try:
    import triton.testing
except ImportError:
    triton = None

from BackendBench.utils import uses_cuda_stream

@@ -82,21 +86,148 @@ def cpu_bench(fn, num_runs=100):
    return (time.perf_counter() - start) / num_runs


# First value is the maximum theoretical FP16 FLOP/s
# Second value is the maximum theoretical memory bandwidth across all SKUs in that generation
# Sources:
# T4: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
# V100: https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf
# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
# H100: https://resources.nvidia.com/en-us-gpu-resources/h100-datasheet-24306
GPU_SPECS = {
    "t4": (65e12, 300e9),
    "v100": (112e12, 900e9),
    "a100": (312e12, 2039e9),
    "h100": (1979e12, 3350e9),
}

FALLBACK_GPU_SPECS = (500e12, 1000e9)
CPU_FALLBACK_SPECS = (10e12, 100e9)
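
# Illustrative sketch, not part of this PR: how these peaks turn into a
# "speed of light" lower bound for one op (a 1000x1000 fp16 matmul on an A100).
_flops = 2 * 1000**3                     # ~2e9 FLOPs for a 1000x1000 matmul
_bytes = 3 * 1000 * 1000 * 2             # two inputs plus the output, fp16
_peak, _bw = GPU_SPECS["a100"]
_min_time_s = max(_flops / _peak, _bytes / _bw)   # ~6.4e-6 s, compute-bound
# Any measured runtime below this bound implies >100% efficiency, i.e. a violation.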


def get_gpu_specs():
    if not torch.cuda.is_available():
        return CPU_FALLBACK_SPECS

    props = torch.cuda.get_device_properties(0)
    gpu_name = props.name.lower()

    for gpu_key, specs in GPU_SPECS.items():
        if gpu_key in gpu_name:
            compute_peak, memory_bw = specs
            logger.debug(
                f"Detected {gpu_name}, using {compute_peak / 1e12:.0f} TFLOP/s, {memory_bw / 1e9:.0f} GB/s"
            )
            return specs

    logger.debug(f"Unknown GPU {gpu_name}, using fallback 500 TFLOP/s, 1000 GB/s")
    return FALLBACK_GPU_SPECS


def calculate_tensor_memory_bytes(args, kwargs):
    total_bytes = 0
    all_values = list(args) + list(kwargs.values())
    for value in all_values:
        if isinstance(value, torch.Tensor):
            total_bytes += value.numel() * value.element_size()
    return total_bytes


def calculate_memory_bandwidth_limit(args, kwargs, runtime_ms):
    runtime_s = runtime_ms / 1000.0
    total_bytes = calculate_tensor_memory_bytes(args, kwargs)
    return total_bytes / runtime_s


def calculate_efficiency_metrics(op, args, kwargs, runtime_ms):
    flop_counter = FlopCounterMode()
    with flop_counter:
        op(*args, **kwargs)

    total_flops = flop_counter.get_total_flops()
    compute_peak, memory_bandwidth = get_gpu_specs()
    runtime_s = runtime_ms / 1000.0

    compute_efficiency = None
    if total_flops > 0:
        achieved_flops_per_s = total_flops / runtime_s
        compute_efficiency = achieved_flops_per_s / compute_peak

    achieved_bandwidth = calculate_memory_bandwidth_limit(args, kwargs, runtime_ms)
    memory_efficiency = achieved_bandwidth / memory_bandwidth

    return compute_efficiency, memory_efficiency


def calculate_speed_of_light(op, args, kwargs, runtime_ms):
    try:
        compute_efficiency, memory_efficiency = calculate_efficiency_metrics(
            op, args, kwargs, runtime_ms
        )

        violations = []
        if compute_efficiency is not None and compute_efficiency > 1.0:
            violations.append(f"compute: {compute_efficiency:.1%}")
        if memory_efficiency > 1.0:
            violations.append(f"memory: {memory_efficiency:.1%}")

        if violations:
            return f"VIOLATION: {', '.join(violations)}"

        return compute_efficiency if compute_efficiency is not None else memory_efficiency
    except Exception as e:
        logger.debug(f"Could not calculate speed of light: {e}")
        return None
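
# Illustrative usage, not part of this PR: calculate_speed_of_light returns a
# float efficiency in (0, 1], a "VIOLATION: ..." string, or None on failure.
_a, _b = torch.randn(256, 256), torch.randn(256, 256)
_result = calculate_speed_of_light(torch.ops.aten.mm.default, (_a, _b), {}, 5.0)
if isinstance(_result, str):
    print(_result)                         # e.g. "VIOLATION: compute: 250.0%"
elif _result is not None:
    print(f"speed of light efficiency: {_result:.1%}")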


def get_bench_function():
    return (
        triton.testing.do_bench if (torch.cuda.is_available() and triton is not None) else cpu_bench
    )


def benchmark_op(bench_fn, op, args, kwargs):
    return bench_fn(lambda: op(*args, **kwargs))


def log_speed_of_light_efficiency(op_name, efficiency):
    if efficiency is None:
        return

    if isinstance(efficiency, str) and "VIOLATION" in efficiency:
        warnings.warn(
            f"Speed of light violation for {op_name}: {efficiency}. "
            f"This indicates a measurement error - the kernel may not be computing the result, or the timing is wrong.",
            UserWarning,
        )
        logger.info(f"{op_name} speed of light: {efficiency}")
    else:
        logger.info(f"{op_name} speed of light efficiency: {efficiency:.1%}")


def eval_performance(op, impl, tests):
-    bench_fn = triton.testing.do_bench if torch.cuda.is_available() else cpu_bench
    bench_fn = get_bench_function()
    base_times = []
    test_times = []

    for test in tests:
        logging.debug(
            f"Benchmarking {op.__name__} with args {format_args(test.args)} and kwargs {format_kwargs(test.kwargs)}"
        )
-        base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
        base_time = benchmark_op(bench_fn, op, test.args, test.kwargs)
        base_times.append(base_time)

        try:
            allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
        except Exception:
            test_times.append(base_times[-1])
            continue
-        test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))

        test_time = benchmark_op(bench_fn, impl, test.args, test.kwargs)
        test_times.append(test_time)

        efficiency = calculate_speed_of_light(impl, test.args, test.kwargs, test_time)
        log_speed_of_light_efficiency(op.__name__, efficiency)

    speedups = torch.tensor(base_times) / torch.tensor(test_times)
    return speedups.log().mean().exp()
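
# Illustration, not part of this PR: the returned score is the geometric mean of
# the per-test speedups, so a 2x win on one test and a 2x loss on another cancel.
_speedups = torch.tensor([2.0, 0.5])
assert torch.isclose(_speedups.log().mean().exp(), torch.tensor(1.0))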

1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"anthropic>=0.34.0",
"pytest",
"requests",
"tabulate>=0.9.0",
]

[project.optional-dependencies]
134 changes: 134 additions & 0 deletions test/test_speed_of_light.py
@@ -0,0 +1,134 @@
import pytest
import torch
import warnings

from BackendBench.eval import calculate_speed_of_light, get_gpu_specs


class TestSpeedOfLight:
    def test_gpu_specs_detection(self):
        """Test that GPU specs are detected correctly."""
        compute_peak, memory_bw = get_gpu_specs()
        assert compute_peak > 0
        assert memory_bw > 0

    def test_speed_of_light_realistic_performance(self):
        """Test that realistic performance doesn't trigger violations."""
        # Test with matrix multiply - realistic timing
        op = torch.ops.aten.mm.default
        a = torch.randn(100, 100)
        b = torch.randn(100, 100)
        args = (a, b)
        kwargs = {}

        # 10ms is realistic for 100x100 matmul on CPU/GPU
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            efficiency = calculate_speed_of_light(op, args, kwargs, 10.0)

        # Should not trigger any warnings
        assert len(w) == 0

        # Should return reasonable efficiency (not None, not violation string)
        assert efficiency is not None
        assert isinstance(efficiency, float)
        assert 0 < efficiency < 1.0  # Should be reasonable percentage

    def test_speed_of_light_compute_violation(self):
        """Test that impossible compute performance triggers violation."""
        # Test with matrix multiply - impossibly fast timing
        op = torch.ops.aten.mm.default
        a = torch.randn(1000, 1000)  # Larger matrix for more FLOPs
        b = torch.randn(1000, 1000)
        args = (a, b)
        kwargs = {}

        # 0.001ms is impossibly fast for 1000x1000 matmul (2B FLOPs)
        efficiency = calculate_speed_of_light(op, args, kwargs, 0.001)

        # Should trigger compute violation
        assert isinstance(efficiency, str)
        assert "VIOLATION" in efficiency
        assert "compute" in efficiency

    def test_speed_of_light_memory_violation(self):
        """Test that impossible memory bandwidth triggers violation."""
        # Use ReLU which naturally has no FLOPs registered in PyTorch
        large_tensor = torch.randn(10_000_000)  # 40MB tensor
        args = (large_tensor,)
        kwargs = {}

        # 0.001ms is impossibly fast for moving 40MB of data
        efficiency = calculate_speed_of_light(torch.ops.aten.relu.default, args, kwargs, 0.001)

        # Should trigger memory violation
        assert isinstance(efficiency, str)
        assert "VIOLATION" in efficiency
        assert "memory" in efficiency

    def test_speed_of_light_no_flops_realistic(self):
        """Test memory-bound operation with realistic timing."""
        # Use ReLU which naturally has no FLOPs registered
        small_tensor = torch.randn(1000)  # 4KB tensor
        args = (small_tensor,)
        kwargs = {}

        # 1ms is reasonable for small memory operations
        efficiency = calculate_speed_of_light(torch.ops.aten.relu.default, args, kwargs, 1.0)

        # Should return reasonable memory efficiency
        assert isinstance(efficiency, float)
        assert 0 < efficiency < 1.0

    def test_speed_of_light_exception_handling(self):
        """Test that the function handles exceptions gracefully."""
        # Invalid arguments that will cause an exception
        args = (5,)  # scalar argument to relu (which expects a tensor)
        kwargs = {}

        efficiency = calculate_speed_of_light(torch.ops.aten.relu.default, args, kwargs, 1.0)

        # Should return None when the operation fails
        assert efficiency is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU tests require CUDA")
class TestSpeedOfLightGPU:
    def test_t4_detection(self):
        """Test that T4 GPU is detected correctly in CI."""
        props = torch.cuda.get_device_properties(0)
        gpu_name = props.name.lower()

        compute_peak, memory_bw = get_gpu_specs()

        if "t4" in gpu_name:
            # Should detect T4 specs
            assert compute_peak == 65e12  # 65 TFLOPS
            assert memory_bw == 300e9  # 300 GB/s
        else:
            # Unknown GPU should use fallback
            assert compute_peak == 500e12
            assert memory_bw == 1000e9

    def test_gpu_realistic_matmul(self):
        """Test realistic GPU matrix multiply performance."""
        # Move tensors to GPU
        a = torch.randn(512, 512, device="cuda")
        b = torch.randn(512, 512, device="cuda")
        args = (a, b)
        kwargs = {}

        # Warm up
        for _ in range(5):
            torch.mm(a, b)
        torch.cuda.synchronize()

        # 1ms should be reasonable for 512x512 matmul on GPU
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            efficiency = calculate_speed_of_light(torch.ops.aten.mm.default, args, kwargs, 1.0)

        # Should not trigger violations on real GPU timing
        assert len(w) == 0
        assert isinstance(efficiency, float)
        assert 0 < efficiency < 1.0