diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index a9324fe393..21ac2a297a 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -148,7 +148,7 @@ def run(
         )

         assert y_d0.dtype == torch.float8_e4m3fn
-        assert s_d0.dtype == torch.uint8
+        assert s_d0.dtype == torch.float8_e8m0fnu
         bytes_r = x.numel() * bytes_per_el_bf16
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
@@ -166,7 +166,7 @@ def run(
         )

         assert y_d1.dtype == torch.float8_e4m3fn
-        assert s_d1.dtype == torch.uint8
+        assert s_d1.dtype == torch.float8_e8m0fnu
         bytes_r = x.numel() * bytes_per_el_bf16
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py
index 580bff2172..bce0b3913c 100644
--- a/test/prototype/mx_formats/test_custom_cast.py
+++ b/test/prototype/mx_formats/test_custom_cast.py
@@ -459,9 +459,7 @@ def test_fp6_e3m2_pack_unpack():
 )
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
-# @pytest.mark.parametrize("M", (256,))
-# @pytest.mark.parametrize("K", (256,))
-def test_triton_mxfp8_dim1(M, K):
+def test_triton_mxfp8_dim1_randn(M, K):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py
index c3c987baf9..00a76c47c3 100644
--- a/torchao/prototype/mx_formats/custom_cast.py
+++ b/torchao/prototype/mx_formats/custom_cast.py
@@ -1126,7 +1126,9 @@ def _triton_calculate_scale(x, axis):
     scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
     scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

-    # TODO(future PR): add NaN handling here
+    # TODO(future PR): add NaN handling here,
+    # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
+    # get proper NaN propagation working

     # Calculate the scale in floating point.
     scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 4a9ff498d5..ebd45970e4 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -236,7 +236,6 @@ def to_mx(
     # Calculate the scale for different modes
     max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
     extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
-    extracted_pow2 = extracted_pow2.to(data_hp.dtype)

     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
         scale_e8m0_unbiased = extracted_pow2 - target_max_pow2