NVfp4 #2408

Merged
merged 1 commit on Jun 24, 2025

138 changes: 137 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -9,6 +9,7 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
@@ -25,7 +26,11 @@
MXInferenceLinear,
MXLinear,
)
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats.mx_subclass import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
@@ -404,6 +409,7 @@ def test_inference_print_str():
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
) # TODO(future): deploy gfx950 in ROCM CI
@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
"""
Smoke test for inference compile
@@ -441,3 +447,133 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
assert sqnr >= SQNR_THRESHOLD, (
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize(
"mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
)
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@torch.no_grad()
@skip_if_rocm("ROCm float4 gemm requires gfx950")
def test_inference_subclass_nvfp4(
bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype
):
"""
Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
"""
# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

if bias and inpt_dtype == torch.float32:
pytest.xfail("Bias is not supported when module weight is in fp32")

if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda")
m_mx = copy.deepcopy(m)

config = NVFP4InferenceConfig(mm_config=mm_config)
quantize_(m_mx, config=config)

if compile:
m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")

x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype)
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)

if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
SQNR_THRESHOLD = 18.0
else:
SQNR_THRESHOLD = 15.0

assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
assert sqnr >= SQNR_THRESHOLD, (
f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("use_gelu", [True, False])
@pytest.mark.parametrize(
"mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
)
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@torch.no_grad()
@skip_if_rocm("ROCm float4 gemm requires gfx950")
def test_nvfp4_matmul_with_amax(
use_gelu: bool,
mm_config: NVFP4MMConfig,
compile: bool,
bias: bool,
inpt_dtype: torch.dtype,
):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
per_tensor_amax_to_scale,
)

# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

if bias and inpt_dtype == torch.float32:
pytest.xfail("Bias is not supported when module weight is in fp32")

if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")

m, k, n = 64, 256, 128

# Create activation tensor
if use_gelu:
x = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
A = torch.nn.functional.gelu(x)
else:
A = torch.randn(m, k, dtype=inpt_dtype, device="cuda")

B = torch.randn(n, k, dtype=inpt_dtype, device="cuda")
bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None

# Compute reference
C_ref = F.linear(A, B, bias_tensor)

a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
A_nvfp4 = NVFP4Tensor.to_nvfp4(
A,
per_tensor_scale=a_scale,
mm_config=mm_config,
)
B_nvfp4 = NVFP4Tensor.to_nvfp4(
B,
per_tensor_scale=b_scale,
mm_config=mm_config,
)

func = torch.compile(F.linear, fullgraph=True) if compile else F.linear

C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor)
assert C_nvfp4.dtype == inpt_dtype, (
f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}"
)

sqnr = compute_error(C_ref, C_nvfp4)
SQNR_THRESHOLD = 16.0
assert sqnr >= SQNR_THRESHOLD, (
f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
)
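
For orientation, here is a minimal usage sketch of the quantize_ flow these tests exercise, condensed from test_inference_subclass_nvfp4 above. It is illustrative only and assumes a CUDA SM100+ GPU and PyTorch 2.8+ since it uses the DYNAMIC mm config:

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_

# Quantize a linear layer with the NVFP4 recipe (fp4 data, fp8 scales, block_size=16).
m = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(m, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
y = m(x)  # output dtype matches the input dtype, as asserted in the test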
66 changes: 66 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -14,6 +14,7 @@
from torchao.prototype.mx_formats.constants import (
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
F4_E2M1_MAX,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,68 @@ def to_f8(x):
torch.testing.assert_close(
data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
)


@pytest.mark.parametrize(
"dtype,shape,use_per_tensor_scale",
[
(torch.bfloat16, (32, 64), False),
(torch.float32, (64, 128), False),
(torch.bfloat16, (128, 256), False),
(torch.bfloat16, (64, 128), True),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
per_tensor_amax_to_scale,
)

x = torch.randn(shape, dtype=dtype, device="cuda")
if use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(x))
scale = per_tensor_amax_to_scale(tensor_amax)
else:
scale = None

x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
x_reconstructed = x_nvfp4.to_dtype(dtype)

def assert_sqnr_gt_threshold(orig, new, threshold):
sqnr = compute_error(orig, new)
if torch.all(torch.isnan(sqnr)):
# if both operands are full of zeroes, sqnr is nan and this is ok
# test for this explicitly
assert torch.all(orig == 0) and torch.all(new == 0)
else:
assert sqnr >= threshold

reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
max_abs = torch.amax(
torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
).unsqueeze(-1)

assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)

assert x.shape == x_reconstructed.shape, (
f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
)
assert x.dtype == x_reconstructed.dtype, (
f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
)

x_nvfp4_t = x_nvfp4.t()
x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)

assert x.t().shape == x_reconstructed_t.shape, (
f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
)
assert x.t().dtype == x_reconstructed_t.dtype, (
f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
)
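
The reconstruction test checks that each block's stored scale, multiplied by F4_E2M1_MAX, approximates that block's amax. A small standalone sketch of the same relationship, under the same assumptions as the test (CUDA available, PyTorch 2.8+); the printed error is bounded mainly by fp8 scale precision:

import torch

from torchao.prototype.mx_formats.constants import F4_E2M1_MAX
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
x_nvfp4 = NVFP4Tensor.to_nvfp4(x)

# amax implied by the per-block scales vs. the true per-block amax
implied_amax = x_nvfp4.get_hp_scales().view(32, -1, 1) * F4_E2M1_MAX
actual_amax = x.reshape(32, -1, x_nvfp4._block_size).abs().amax(dim=-1, keepdim=True)

diff = (implied_amax - actual_amax.to(implied_amax.dtype)).abs()
rel_err = (diff / actual_amax.to(implied_amax.dtype)).max()
print(f"max relative amax error: {rel_err.item():.4f}")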
8 changes: 7 additions & 1 deletion torchao/prototype/mx_formats/__init__.py
@@ -6,7 +6,11 @@
)

# Note: Prototype and subject to change
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats.mx_subclass import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)

# import mx_linear here to register the quantize_ transform logic
# ruff: noqa: I001
@@ -18,4 +22,6 @@
"MXLinearConfig",
"MXLinearRecipeName",
"MXFPInferenceConfig",
"NVFP4InferenceConfig",
"NVFP4MMConfig",
]
6 changes: 3 additions & 3 deletions torchao/prototype/mx_formats/config.py
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
)
elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
assert block_size == 32, (
f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
assert block_size in [16, 32], (
f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn]
valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
)
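
The relaxed validation above is what lets the NVFP4 shape through the existing MX plumbing: block_size 16 and packed fp4 elements are now accepted for the cuBLAS kernel choice. A hedged sketch of the check in isolation (the helper is private and PyTorch 2.8+ is assumed for torch.float4_e2m1fn_x2):

import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    _validate_gemm_kernel_choice,
)

# With this PR the NVFP4 combination passes; previously block_size had to be 32
# and only float8_e4m3fn was accepted for the cuBLAS kernel choice.
_validate_gemm_kernel_choice(
    MXGemmKernelChoice.CUBLAS,
    block_size=16,
    elem_dtype=torch.float4_e2m1fn_x2,
)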
73 changes: 64 additions & 9 deletions torchao/prototype/mx_formats/mx_subclass.py
@@ -10,7 +10,6 @@

import torch

import torchao
from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats import (
MXGemmKernelChoice,
@@ -20,13 +19,19 @@
_validate_gemm_kernel_choice,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
from torchao.quantization.quant_api import to_linear_activation_quantized
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_100,
)


# TODO The naming for these configs is a little weird, rename before moving to public API
# Note: This API is extra prototype and will change in the future
@dataclass
class MXFPInferenceConfig(AOBaseConfig):
@@ -63,16 +68,13 @@ class MXFPInferenceConfig(AOBaseConfig):

block_size: int = 32

# Dtypes for Input and Weights
# Dtypes for Input and Weights, supports Fp8 and Fp4 formats
activation_dtype: torch.dtype = torch.float8_e4m3fn
weight_dtype: torch.dtype = torch.float8_e4m3fn

# Which kernel to run for mm
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS

# Set some magic perf settings
set_inductor_config: bool = False

def __post_init__(self):
assert self.activation_dtype == self.weight_dtype, (
"For now - we only support matching input/weight dtypes."
@@ -115,8 +117,6 @@ def _mx_inference_linear_transform(
# TODO Sm120 has slightly more restrictive reqs
# TODO handle AMD
assert is_sm_at_least_100(), "MXFP is only supported on sm100 machines for now"
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

activation_dtype = config.activation_dtype
weight_dtype = config.weight_dtype
@@ -151,7 +151,62 @@ def _mx_inference_linear_transform(
return module


@dataclass
class NVFP4InferenceConfig(AOBaseConfig):
"""
NVIDIA FP4 (NVFP4) Inference Quantization Configuration

This is a specialized configuration for NVIDIA's FP4 format.
All parameters are fixed in the NVFP4 implementation except mm_config:
- mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision)
- Data: float4_e2m1fn_x2
- Scales: float8_e4m3fn
- Block size: 16 along the reduction dim
"""

mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC

def __post_init__(self):
# Validate PyTorch version
if not TORCH_VERSION_AT_LEAST_2_8:
raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")


@register_quantize_module_handler(NVFP4InferenceConfig)
def _nvfp4_inference_linear_transform(
module: torch.nn.Linear, config: NVFP4InferenceConfig
):
"""Quantization handler for NVFP4InferenceConfig"""
if config.mm_config == NVFP4MMConfig.DYNAMIC:
assert is_sm_at_least_100(), (
"NVFP4 DYNAMIC mode is only supported on sm100+ machines"
)

weight = module.weight

if module.bias is not None and weight.dtype == torch.float32:
raise RuntimeError(
"Bias is not supported when module weight is in fp32 (out_dtype=Float32). "
"Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
)

quantized_weight = NVFP4Tensor.to_nvfp4(
weight,
mm_config=config.mm_config,
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals(
[MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
[
MXTensor,
NVFP4Tensor,
NVFP4MMConfig,
MXGemmKernelChoice,
_input_activation_quant_func_mxfp,
]
)
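
For completeness, a sketch of the WEIGHT_ONLY path described in the NVFP4InferenceConfig docstring: the weight is stored as an NVFP4Tensor but the matmul is emulated in high precision, so per the tests it does not require an SM100+ GPU (PyTorch 2.8+ is still assumed):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_

# Weight-only NVFP4: storage savings on the weight, matmul in high precision.
m = nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
quantize_(m, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.WEIGHT_ONLY))

y = m(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))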