From 4fe3daf6abf25ef3d811c5e2982dfca5b13fb5ab Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 18 Jun 2025 13:33:06 -0700 Subject: [PATCH] NVfp4 stack-info: PR: https://github.com/pytorch/ao/pull/2408, branch: drisspg/stack/78 --- test/prototype/mx_formats/test_mx_linear.py | 138 ++++- test/prototype/mx_formats/test_mx_tensor.py | 66 ++ torchao/prototype/mx_formats/__init__.py | 8 +- torchao/prototype/mx_formats/config.py | 6 +- torchao/prototype/mx_formats/mx_subclass.py | 73 ++- torchao/prototype/mx_formats/nvfp4_tensor.py | 617 +++++++++++++++++++ 6 files changed, 894 insertions(+), 14 deletions(-) create mode 100644 torchao/prototype/mx_formats/nvfp4_tensor.py diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bfb6742d14..0e39264742 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -9,6 +9,7 @@ import pytest import torch import torch.nn as nn +import torch.nn.functional as F from torchao.prototype.mx_formats.config import ( MXGemmKernelChoice, @@ -25,7 +26,11 @@ MXInferenceLinear, MXLinear, ) -from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig +from torchao.prototype.mx_formats.mx_subclass import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm @@ -404,6 +409,7 @@ def test_inference_print_str(): @skip_if_rocm( "ROCm float4 gemm require gfx950" ) # TODO(future): deploy gfx950 in ROCM CI +@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required") def test_inference_subclass(elem_dtype, bias: bool, compile: bool): """ Smoke test for inference compile @@ -441,3 +447,133 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool): assert sqnr >= SQNR_THRESHOLD, ( f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}" ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize( + "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] +) +@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) +@torch.no_grad() +@skip_if_rocm("ROCm float4 gemm require gfx950") +def test_inference_subclass_nvfp4( + bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype +): + """ + Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16 + Tests both DYNAMIC and WEIGHT_ONLY mm_config modes + """ + # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs + if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") + + if bias and inpt_dtype == torch.float32: + pytest.xfail("Bias is not supported when module weight is in fp32") + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: + pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda") + m_mx = copy.deepcopy(m) + + config = NVFP4InferenceConfig(mm_config=mm_config) + quantize_(m_mx, config=config) + + if compile: + m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager") + + x = torch.randn(128, 64, device="cuda", 
dtype=inpt_dtype) + y_ref = m(x) + y_mx = m_mx(x) + sqnr = compute_error(y_ref, y_mx) + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY: + SQNR_THRESHOLD = 18.0 + else: + SQNR_THRESHOLD = 15.0 + + assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}" + assert sqnr >= SQNR_THRESHOLD, ( + f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.parametrize("use_gelu", [True, False]) +@pytest.mark.parametrize( + "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] +) +@pytest.mark.parametrize("compile", [False]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) +@torch.no_grad() +@skip_if_rocm("ROCm float4 gemm require gfx950") +def test_nvfp4_matmul_with_amax( + use_gelu: bool, + mm_config: NVFP4MMConfig, + compile: bool, + bias: bool, + inpt_dtype: torch.dtype, +): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + ) + + # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs + if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") + + if bias and inpt_dtype == torch.float32: + pytest.xfail("Bias is not supported when module weight is in fp32") + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: + pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + + m, k, n = 64, 256, 128 + + # Create activation tensor + if use_gelu: + x = torch.randn(m, k, dtype=inpt_dtype, device="cuda") + A = torch.nn.functional.gelu(x) + else: + A = torch.randn(m, k, dtype=inpt_dtype, device="cuda") + + B = torch.randn(n, k, dtype=inpt_dtype, device="cuda") + bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None + + # Compute reference + C_ref = F.linear(A, B, bias_tensor) + + a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A))) + b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B))) + A_nvfp4 = NVFP4Tensor.to_nvfp4( + A, + per_tensor_scale=a_scale, + mm_config=mm_config, + ) + B_nvfp4 = NVFP4Tensor.to_nvfp4( + B, + per_tensor_scale=b_scale, + mm_config=mm_config, + ) + + func = torch.compile(F.linear, fullgraph=True) if compile else F.linear + + C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor) + assert C_nvfp4.dtype == inpt_dtype, ( + f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}" + ) + + sqnr = compute_error(C_ref, C_nvfp4) + SQNR_THRESHOLD = 16.0 + assert sqnr >= SQNR_THRESHOLD, ( + f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}" + ) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 6dfd33f9c7..0490b0b1ee 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -14,6 +14,7 @@ from torchao.prototype.mx_formats.constants import ( DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, + F4_E2M1_MAX, SUPPORTED_ELEM_DTYPES, ) from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6 @@ -591,3 +592,68 @@ def to_f8(x): torch.testing.assert_close( data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0 ) + + +@pytest.mark.parametrize( + "dtype,shape,use_per_tensor_scale", + [ + (torch.bfloat16, (32, 64), 
False), + (torch.float32, (64, 128), False), + (torch.bfloat16, (128, 256), False), + (torch.bfloat16, (64, 128), True), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" +) +def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + ) + + x = torch.randn(shape, dtype=dtype, device="cuda") + if use_per_tensor_scale: + tensor_amax = torch.max(torch.abs(x)) + scale = per_tensor_amax_to_scale(tensor_amax) + else: + scale = None + + x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale) + x_reconstructed = x_nvfp4.to_dtype(dtype) + + def assert_sqnr_gt_threshold(orig, new, threshold): + sqnr = compute_error(orig, new) + if torch.all(torch.isnan(sqnr)): + # if both operands are full of zeroes, sqnr is nan and this is ok + # test for this explicitly + assert torch.all(orig == 0) and torch.all(new == 0) + else: + assert sqnr >= threshold + + reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX + max_abs = torch.amax( + torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1 + ).unsqueeze(-1) + + assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0) + assert_sqnr_gt_threshold(x, x_reconstructed, 8.0) + + assert x.shape == x_reconstructed.shape, ( + f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}" + ) + assert x.dtype == x_reconstructed.dtype, ( + f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}" + ) + + x_nvfp4_t = x_nvfp4.t() + x_reconstructed_t = x_nvfp4_t.to_dtype(dtype) + assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0) + + assert x.t().shape == x_reconstructed_t.shape, ( + f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}" + ) + assert x.t().dtype == x_reconstructed_t.dtype, ( + f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}" + ) diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index 7c1f0ace55..5947d616be 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -6,7 +6,11 @@ ) # Note: Prototype and subject to change -from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig +from torchao.prototype.mx_formats.mx_subclass import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) # import mx_linear here to register the quantize_ transform logic # ruff: noqa: I001 @@ -18,4 +22,6 @@ "MXLinearConfig", "MXLinearRecipeName", "MXFPInferenceConfig", + "NVFP4InferenceConfig", + "NVFP4MMConfig", ] diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index eb1b15228d..525bf21fc6 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" ) elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS: - assert block_size == 32, ( - f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}" + assert block_size in [16, 32], ( + f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}" ) - valid_dtypes = [torch.float8_e4m3fn] + valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] assert elem_dtype in valid_dtypes, ( 
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" ) diff --git a/torchao/prototype/mx_formats/mx_subclass.py b/torchao/prototype/mx_formats/mx_subclass.py index 2173c97002..d1be8a04f4 100644 --- a/torchao/prototype/mx_formats/mx_subclass.py +++ b/torchao/prototype/mx_formats/mx_subclass.py @@ -10,7 +10,6 @@ import torch -import torchao from torchao.core.config import AOBaseConfig from torchao.prototype.mx_formats import ( MXGemmKernelChoice, @@ -20,13 +19,19 @@ _validate_gemm_kernel_choice, ) from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor from torchao.quantization.quant_api import to_linear_activation_quantized from torchao.quantization.transform_module import ( register_quantize_module_handler, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_8, + is_sm_at_least_100, +) +# TODO The naming for these configs is a little weird, rename before moving to public API # Note: This API is extra prototype and will change in the future @dataclass class MXFPInferenceConfig(AOBaseConfig): @@ -63,16 +68,13 @@ class MXFPInferenceConfig(AOBaseConfig): block_size: int = 32 - # Dtypes for Input and Weights + # Dtypes for Input and Weights, supports Fp8 and Fp4 formats activation_dtype: torch.dtype = torch.float8_e4m3fn weight_dtype: torch.dtype = torch.float8_e4m3fn # Which kernel to run for mm gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS - # Set some magic perf settings - set_inductor_config: bool = False - def __post_init__(self): assert self.activation_dtype == self.weight_dtype, ( "For now - we only support matching input/weight dtypes." @@ -115,8 +117,6 @@ def _mx_inference_linear_transform( # TODO Sm120 has slightly more restrictive reqs # TODO handle AMD assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now" - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -151,7 +151,62 @@ def _mx_inference_linear_transform( return module +@dataclass +class NVFP4InferenceConfig(AOBaseConfig): + """ + NVIDIA FP4 (NVFP4) Inference Quantization Configuration + + This is a specialized configuration for NVIDIA's FP4 format. + All parameters are fixed in the NVFP4 implementation except mm_config: + - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision) + - Data: float4_e2m1fn_x2 + - Scales: float8_e4m3fn + - Block size: 16 along the reduction dim + """ + + mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC + + def __post_init__(self): + # Validate PyTorch version + if not TORCH_VERSION_AT_LEAST_2_8: + raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later") + + +@register_quantize_module_handler(NVFP4InferenceConfig) +def _nvfp4_inference_linear_transform( + module: torch.nn.Linear, config: NVFP4InferenceConfig +): + """Quantization handler for NVFP4InferenceConfig""" + if config.mm_config == NVFP4MMConfig.DYNAMIC: + assert is_sm_at_least_100(), ( + "NVFP4 DYNAMIC mode is only supported on sm100+ machines" + ) + + weight = module.weight + + if module.bias is not None and weight.dtype == torch.float32: + raise RuntimeError( + "Bias is not supported when module weight is in fp32 (out_dtype=Float32). 
" + "Please use bfloat16 or float16 weights, or remove the bias from the linear layer." + ) + + quantized_weight = NVFP4Tensor.to_nvfp4( + weight, + mm_config=config.mm_config, + ) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals( - [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp] + [ + MXTensor, + NVFP4Tensor, + NVFP4MMConfig, + MXGemmKernelChoice, + _input_activation_quant_func_mxfp, + ] ) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py new file mode 100644 index 0000000000..ed1b5df1d0 --- /dev/null +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -0,0 +1,617 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Any, Callable, Dict, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.prototype.mx_formats.constants import F4_E2M1_MAX, F8E4M3_MAX +from torchao.prototype.mx_formats.kernels import ( + f4_unpacked_to_f32, + f32_to_f4_unpacked, + pack_uint4, + unpack_uint4, +) +from torchao.prototype.mx_formats.mx_tensor import ( + tensor_size_fp4x2_to_hp, + tensor_size_hp_to_fp4x2, +) +from torchao.prototype.mx_formats.utils import to_blocked +from torchao.utils import fill_defaults + +E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny + +aten = torch.ops.aten + +NVFP4_OPS_TABLE: Dict[Any, Any] = {} + + +class NVFP4MMConfig(Enum): + DYNAMIC = "dynamic" + WEIGHT_ONLY = "weight_only" + + +def implements(aten_ops): + """Register aten ops to the NVFP4 op table""" + + def decorator(func): + for op in aten_ops: + NVFP4_OPS_TABLE[op] = func + return func + + return decorator + + +class NVFP4Tensor(torch.Tensor): + """NVIDIA FP4 (NVFP4) Tensor subclass. + + This implements the NVIDIA variant of MX FP4 format, which uses a specific + quantization algorithm for FP4 data with UE4M3 scales. 
+ + Attributes: + _scale_e4m3: Blockwise scales in float8_e4m3fn format + _per_tensor_scale: Optional global per-tensor scale in float32 format + _data: Packed FP4 data (2 values per byte) + _block_size: Block size for quantization (fixed at 16) + _orig_dtype: Original tensor dtype before quantization + mm_config: Matrix multiplication configuration + """ + + _scale_e4m3: torch.Tensor + _per_tensor_scale: Optional[torch.Tensor] + _data: torch.Tensor + _block_size: int + _orig_dtype: torch.dtype + mm_config: NVFP4MMConfig + + def __new__( + cls, + blockwise_scales, + per_tensor_scale, + data_bits, + block_size, + orig_dtype, + mm_config=NVFP4MMConfig.DYNAMIC, + ): + # FP4 tensor size handling + new_size = data_bits.size() + new_size = tensor_size_fp4x2_to_hp( + new_size, + data_bits.is_contiguous(), + ) + + self = torch.Tensor._make_wrapper_subclass( + cls, + new_size, + dtype=orig_dtype, + device=data_bits.device, + requires_grad=False, + ) + + self._scale_e4m3 = blockwise_scales + self._per_tensor_scale = per_tensor_scale + self._data = data_bits + self._block_size = block_size + self._orig_dtype = orig_dtype + self.mm_config = mm_config + return self + + def __repr__(self): + return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}" + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Use NVFP4-specific ops table + if func in NVFP4_OPS_TABLE: + return NVFP4_OPS_TABLE[func](func, types, args, kwargs) + + raise NotImplementedError(f"{func} not implemented for NVFP4Tensor") + + @staticmethod + def to_nvfp4( + data_hp: torch.Tensor, + block_size: int = 16, + per_tensor_scale: Optional[torch.Tensor] = None, + mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC, + ): + """Convert high precision tensor to NVFP4 format. + + Args: + data_hp: High precision input tensor (bfloat16 or float32) + block_size: Block size for quantization (must be 16) + per_tensor_amax: Optional pre-computed absolute maximum for calibration. + If provided, uses per-tensor scaling. If None, uses block-wise scaling only. 
+ + Returns: + NVFP4Tensor: Quantized tensor in NVFP4 format + """ + blockwise_scales, data_lp = nvfp4_quantize( + data_hp, block_size, per_tensor_scale + ) + return NVFP4Tensor( + blockwise_scales, + per_tensor_scale, + data_lp, + block_size, + data_hp.dtype, + mm_config, + ) + + def __tensor_flatten__(self): + ctx = { + "_block_size": self._block_size, + "_orig_dtype": self._orig_dtype, + "mm_config": self.mm_config, + } + tensor_list = ["_scale_e4m3", "_data"] + if self._per_tensor_scale is not None: + tensor_list.append("_per_tensor_scale") + return tensor_list, ctx + + def _apply_fn_to_data(self, fn: Callable): + """Applies a fn to all tensor components stored on this class""" + tensor_names, ctx = self.__tensor_flatten__() + new_tensors = {} + for name in tensor_names: + new_tensors[name] = fn(getattr(self, name)) + if "_per_tensor_scale" not in tensor_names: + new_tensors["_per_tensor_scale"] = None + return self.__class__.__tensor_unflatten__( + new_tensors, + ctx, + None, + None, + ) + + @staticmethod + def __tensor_unflatten__( + inner_tensors, + metadata, + outer_size, + outer_stride, + ): + return NVFP4Tensor( + inner_tensors["_scale_e4m3"], + inner_tensors.get("_per_tensor_scale", None), + inner_tensors["_data"], + metadata["_block_size"], + metadata["_orig_dtype"], + metadata["mm_config"], + ) + + # Do not force the NVFP4Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl + + def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor: + """Convert NVFP4Tensor back to high precision dtype. + + Args: + target_dtype: Target dtype for dequantization (e.g., torch.float32, torch.bfloat16) + + Returns: + torch.Tensor: Dequantized tensor in the target dtype + """ + is_transposed = not self._data.is_contiguous() + if is_transposed: + M, K = self.shape[1], self.shape[0] + else: + M, K = self.shape[0], self.shape[1] + data = self._data.t() if is_transposed else self._data + data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8)) + data_f32 = f4_unpacked_to_f32(data_unpacked) + + data_f32 = data_f32.view(M, K // self._block_size, self._block_size) + scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1) + data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32) + result = data_scaled.view(M, K).to(target_dtype) + + if is_transposed: + result = result.t() + + return result + + def get_hp_scales(self) -> torch.Tensor: + """Get the scales of the NVFP4Tensor in original dtype. + + Returns: + torch.Tensor: Scales of the NVFP4Tensor + """ + return ( + self._scale_e4m3.to(self._orig_dtype) + if not self._per_tensor_scale + else self._per_tensor_scale * self._scale_e4m3.to(self._orig_dtype) + ) + + @classmethod + def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool: + """Check if two NVFP4Tensors have the same metadata. 
+ + Args: + self: First NVFP4Tensor to compare + src: Second NVFP4Tensor to compare + + Returns: + bool: True if both tensors have identical metadata, False otherwise + """ + # Check per_tensor_scale equality + per_tensor_scale_equal = ( + self._per_tensor_scale is None and src._per_tensor_scale is None + ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape) + + return ( + isinstance(self, NVFP4Tensor) + and isinstance(src, NVFP4Tensor) + and self._block_size == src._block_size + and self._orig_dtype == src._orig_dtype + and self._scale_e4m3.shape == src._scale_e4m3.shape + and per_tensor_scale_equal + and self._data.shape == src._data.shape + ) + + +@implements([aten.detach.default, aten.alias.default]) +def nvfp4_detach_alias(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(func) + ) + + +@implements([aten._to_copy.default]) +def nvfp4_to_copy(func, types, args, kwargs): + """Autocast + device movement""" + assert isinstance(args[0], NVFP4Tensor) + + # Handle dtype parameter + dtype = kwargs.pop("dtype", None) + if dtype is not None: + assert dtype in { + torch.float16, + torch.bfloat16, + torch.float32, + }, "Only support floating point conversion for autocast w/ NVFP4Tensor" + + # Handle device parameter + device = kwargs.pop("device", None) + if device is not None: + # Apply device change using _apply_fn_to_data + tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device)) + tensor = return_and_correct_aliasing(func, args, {}, tensor) + else: + tensor = args[0] + + if dtype is not None: + res = NVFP4Tensor( + tensor._scale_e4m3, + tensor._per_tensor_scale, + tensor._data, + tensor._block_size, + dtype, + tensor.mm_config, + ) + return res + + return tensor + + +@implements([aten.copy_.default]) +def nvfp4_copy_(func, types, args, kwargs): + self = args[0] + src = args[1] + if NVFP4Tensor._same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return self + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {self}, {src}" + ) + + +@implements([aten.clone.default]) +def nvfp4_clone(func, types, args, kwargs): + self = args[0] + memory_format = kwargs.get("memory_format", None) + + if memory_format is not None: + clone_fn = lambda x: x.clone(memory_format=memory_format) + else: + clone_fn = lambda x: x.clone() + + return self._apply_fn_to_data(clone_fn) + + +@implements([aten.slice.Tensor]) +def nvfp4_slice(func, types, args, kwargs): + x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + + if step != 1: + raise ValueError("Only support aten.slice with step=1") + + assert x._data.is_contiguous(), "Only support contiguous data for now" + + M, K = x.shape[0], x.shape[1] + scale_shaped = x._scale_e4m3.view(M, K // x._block_size) + + if dim == 0: + # Slicing along the first dimension (rows) + sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step).flatten() + sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) + elif dim == 1: + # Slicing along reduction dim - must align with block boundaries + if start is not None: + assert start % x._block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x._block_size}" + ) + + if end is not None: + assert end % x._block_size == 0, ( + f"End index {end} must be a multiple of block_size {x._block_size}" + ) + + sliced_data = aten.slice.Tensor(x._data, dim, start, 
end, step) + + # Calculate which scale blocks to keep + start_block = 0 if start is None else start // x._block_size + end_block = None if end is None else end // x._block_size + + # Slice the scale tensor accordingly + sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step) + else: + raise ValueError( + f"NVFP4Tensor only supports slicing along dimensions 0 and 1, got dim={dim}" + ) + + return NVFP4Tensor( + sliced_scale, + x._per_tensor_scale, # Unchanged per-tensor scale + sliced_data, + x._block_size, + x._orig_dtype, + x.mm_config, + ) + + +@implements([aten.t.default]) +def nvfp4_t(func, types, args, kwargs): + # For now, only transpose(input, 0, 1) is supported. + old = args[0] + new = NVFP4Tensor( + old._scale_e4m3, + old._per_tensor_scale, + old._data.t(), + old._block_size, + old._orig_dtype, + old.mm_config, + ) + return new + + +@implements([aten.view.default]) +def nvfp4_view_op(func, types, args, kwargs): + data = args[0]._data + new_size = args[1] + new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) + new_data = func(data, new_size, *args[2:], **kwargs) + return NVFP4Tensor( + args[0]._scale_e4m3, + args[0]._per_tensor_scale, + new_data, + args[0]._block_size, + args[0]._orig_dtype, + args[0].mm_config, + ) + + +def _addmm_nvfp4_dispatch( + a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + """ + Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear. + The only difference is whether bias is None or not. + """ + assert a._data.is_contiguous() + assert b._data.t().is_contiguous() + assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" + assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" + + M, K = a.shape[0], a.shape[1] + N = b.shape[1] + + # Swizzle Dizzle + a_scale = a._scale_e4m3.view(M, K // a._block_size) + b_scale = b._scale_e4m3.view(N, K // b._block_size) + a_scale_blocked = to_blocked(a_scale) + b_scale_blocked = to_blocked(b_scale) + + # Merge double quant scales into 1 scale for Scale_In^D + if a._per_tensor_scale is not None: + assert b._per_tensor_scale is not None + scale_result = a._per_tensor_scale * b._per_tensor_scale + else: + assert b._per_tensor_scale is None and a._per_tensor_scale is None + scale_result = None + + # THIS IS A WORKAROUND: + # RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling + # When we have per-tensor scaling, we need to apply it before bias + # since bias is not quantized + should_add_bias_separately = (scale_result is not None) and (bias is not None) + + result = torch._scaled_mm( + a._data.view(torch.float4_e2m1fn_x2), + b._data.view(torch.float4_e2m1fn_x2), + a_scale_blocked.view(torch.float8_e4m3fn), + b_scale_blocked.view(torch.float8_e4m3fn), + bias=None if should_add_bias_separately else bias, + out_dtype=a._orig_dtype, + # scale_result=scale_result, # Not supported yet + ) + + if scale_result is not None: + result = result * scale_result.to(a._orig_dtype) + + # Add bias after scaling if needed + if should_add_bias_separately: + result = result + bias + + return result + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def nvfp4_linear(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + if not isinstance(weight_tensor, NVFP4Tensor): + raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") + + config = weight_tensor.mm_config + + if 
config == NVFP4MMConfig.WEIGHT_ONLY: + weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) + return torch.nn.functional.linear(input_tensor, weight_dequant, bias) + else: + input_quant = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) + return _addmm_nvfp4_dispatch(input_quant, weight_tensor, func, bias=bias) + + +@implements([aten.mm.default, aten.matmul.default]) +def nvfp4_mm(func, types, args, kwargs): + input_tensor, weight_tensor = args[0], args[1] + + if not isinstance(weight_tensor, NVFP4Tensor): + raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") + + config = weight_tensor.mm_config + + if config == NVFP4MMConfig.WEIGHT_ONLY: + weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) + if isinstance(input_tensor, NVFP4Tensor): + input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype) + return func(input_dequant, weight_dequant) + else: + return func(input_tensor, weight_dequant) + else: + if not isinstance(input_tensor, NVFP4Tensor): + input_tensor = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) + return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func) + + +@implements([aten.addmm.default]) +def nvfp4_addmm(func, types, args, kwargs): + bias, input_tensor, weight_tensor = args[0], args[1], args[2] + + if not isinstance(weight_tensor, NVFP4Tensor): + raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") + + config = weight_tensor.mm_config + + if config == NVFP4MMConfig.WEIGHT_ONLY: + weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) + if isinstance(input_tensor, NVFP4Tensor): + input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype) + return torch.addmm(bias, input_dequant, weight_dequant) + else: + return torch.addmm(bias, input_tensor, weight_dequant) + else: + if not isinstance(input_tensor, NVFP4Tensor): + input_tensor = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) + return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias) + + +def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor: + """Convert per-tensor amax to per-tensor scale. + Used to scale fp32 scales down to fp8 scales + + Args: + amax: Per-tensor amax tensor + + Returns: + torch.Tensor: Per-tensor scale tensor + """ + return torch.clamp(amax / F8E4M3_MAX, min=E4M3_EPS, max=F8E4M3_MAX).to( + torch.float32 + ) + + +def nvfp4_quantize( + data_hp: torch.Tensor, + block_size: int = 16, + per_tensor_scale: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """NVIDIA FP4 quantization with UE4M3 scales. + + Implements the NVIDIA algorithm for quantizing tensors to FP4 format + with unsigned E4M3 (UE4M3) scales. + + Args: + data_hp: High precision input tensor (bfloat16 or float32) + block_size: Block size for quantization (must be 16) + per_tensor_amax: Optional pre-computed absolute maximum for calibration. + If provided, uses per-tensor scaling. If None, uses block-wise scaling only. 
+
+    Returns:
+        tuple: A tuple containing:
+            - out_scales: Blockwise scales in float8_e4m3fn format (already divided
+              by the per-tensor scale when two-level scaling is used)
+            - data_lp: Packed FP4 data (2 values per byte)
+
+    Raises:
+        AssertionError: If input dtype is not supported, tensor size is not
+            divisible by block_size, tensor is not contiguous, or block_size != 16
+    """
+    assert data_hp.dtype in (torch.bfloat16, torch.float), (
+        f"{data_hp.dtype} not supported"
+    )
+    assert data_hp.numel() % block_size == 0, "unsupported"
+    assert data_hp.is_contiguous(), "unsupported"
+    assert block_size == 16, "NVFP4 requires block_size=16"
+
+    orig_shape = data_hp.shape
+    data_hp = data_hp.reshape(orig_shape[0], -1, block_size)
+
+    max_abs = torch.amax(torch.abs(data_hp), dim=-1)
+    # These scales are currently in fp32, we are going to `quantize` them to e4m3
+    block_scale = max_abs / F4_E2M1_MAX
+
+    out_scales = None
+    if per_tensor_scale is None:
+        # We are doing single level scaling
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
+        data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
+        out_scales = block_scale_fp8
+    else:
+        # We are doing two level scaling; this will likely be calibrated, but
+        # we want the per_tensor_scale ~= amax of the block_scale_fp32
+        block_scale_fp32 = block_scale.to(torch.float32)
+        # Quantize the blockwise scales w/ the per_tensor_scale
+        scaled_block_scales = block_scale_fp32 / per_tensor_scale
+        scaled_block_scales_fp8 = torch.clamp(
+            scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
+        # We "temporarily" dequant the scaled_block_scales_fp32 to get the total
+        # scale to apply to the data
+        total_scale = per_tensor_scale * scaled_block_scales_fp32
+        data_scaled = data_hp / total_scale.unsqueeze(-1)
+        out_scales = scaled_block_scales_fp8
+
+    data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
+    data_scaled = data_scaled.view(orig_shape)
+    data_lp = f32_to_f4_unpacked(data_scaled.float())
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales, data_lp
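
Usage sketch: a minimal example of the inference flow added in this patch, mirroring
test_inference_subclass_nvfp4 and test_nvfp4_reconstruction above. It assumes a CUDA
device and PyTorch 2.8+; NVFP4MMConfig.DYNAMIC additionally assumes an sm100+ GPU,
while WEIGHT_ONLY dequantizes the weight and computes in the original precision.

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )
    from torchao.quantization import quantize_

    # Module flow: swap the nn.Linear weight for an NVFP4Tensor (packed fp4 data,
    # float8_e4m3fn block scales, block_size=16 along the reduction dim).
    m = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(m, config=NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC))
    x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    y = m(x)  # activation is quantized to NVFP4 on the fly; mm runs via torch._scaled_mm

    # Tensor flow: two-level scaling with an explicit per-tensor scale derived from amax.
    w = torch.randn(256, 64, dtype=torch.bfloat16, device="cuda")
    w_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(w)))
    w_nvfp4 = NVFP4Tensor.to_nvfp4(w, per_tensor_scale=w_scale)
    w_hp = w_nvfp4.to_dtype(torch.bfloat16)  # dequantize back for inspection / emulation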