
Commit 06ac6dd

mx: small speedup with dim0 cast
Summary:

Removes the unnecessary cast to bfloat16 in the MX dim0 casting code. This is a 2.6% speedup on the 16k by 16k shape: https://www.internalfb.com/phabricator/paste/view/P1769373804

Note: this PR also includes a couple of cleanups around the e8m0 dtype and NaN handling that I found while coding this PR. They are kept here instead of in a separate PR since they are all safe.

Test Plan:

```bash
(pytorch) [[email protected] ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 152.90741052631583
mem_bw_gbps 5321.488168553876

(pytorch) [[email protected] ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 149.03950980392162
mem_bw_gbps 5459.5924065404415
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 47fb1df
ghstack-comment-id: 2762318741
Pull Request resolved: #1980
1 parent 1c5627c commit 06ac6dd
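
For readers without context: a "dim0 MX cast" converts a high-precision tensor to MX format along rows, producing float8 data plus one shared power-of-two (e8m0) scale per contiguous block of 32 elements. Below is a minimal sketch of that flow, assuming a floor-style scale choice; the helper name is hypothetical and this is not the torchao implementation:

```python
import torch

def mx_dim0_cast_sketch(x: torch.Tensor, block_size: int = 32):
    # Assumes the row length is divisible by block_size; view the tensor
    # as blocks of `block_size` contiguous elements.
    x_blocks = x.reshape(-1, block_size)

    # One power-of-two scale per block, derived from the block's amax.
    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-30)
    # Floor-style scale: 2^(floor(log2(amax)) - 8); 8 ~= log2(448), e4m3's max.
    scale_pow2 = torch.floor(torch.log2(amax)) - 8
    data = (x_blocks / torch.exp2(scale_pow2)).to(torch.float8_e4m3fn)

    # Store each scale as a biased-exponent byte, reinterpreted as e8m0.
    scale_e8m0 = (scale_pow2 + 127).to(torch.uint8).view(torch.float8_e8m0fnu)
    return data.reshape(x.shape), scale_e8m0
```

The speedup in this commit comes from keeping the intermediate exponent arithmetic in integer dtypes inside this flow; see the mx_tensor.py hunk below.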

File tree

4 files changed: +6 -7 lines


benchmarks/mx_formats/cast_bench.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -148,7 +148,7 @@ def run(
         )

         assert y_d0.dtype == torch.float8_e4m3fn
-        assert s_d0.dtype == torch.uint8
+        assert s_d0.dtype == torch.float8_e8m0fnu
         bytes_r = x.numel() * bytes_per_el_bf16
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
@@ -166,7 +166,7 @@ def run(
         )

         assert y_d1.dtype == torch.float8_e4m3fn
-        assert s_d1.dtype == torch.uint8
+        assert s_d1.dtype == torch.float8_e8m0fnu
         bytes_r = x.numel() * bytes_per_el_bf16
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
```
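
The updated asserts reflect that the benchmarked cast now returns its scales directly as torch.float8_e8m0fnu, PyTorch's exponent-only fp8 dtype (8 exponent bits, no sign, no mantissa, bias 127), instead of raw uint8 bytes. A quick illustration of how the two relate; the decode line is a manual check, not a library call:

```python
import torch

# Each e8m0 element is a biased exponent byte: value = 2**(byte - 127),
# so a uint8 tensor of biased exponents can be reinterpreted in place.
biased = torch.tensor([127, 130, 119], dtype=torch.uint8)
scales = biased.view(torch.float8_e8m0fnu)  # zero-copy view, same bytes
assert scales.dtype == torch.float8_e8m0fnu  # what the benchmark now checks

# Decode by hand to confirm the encoding: 2**0, 2**3, 2**-8.
print(torch.exp2(biased.to(torch.float32) - 127).tolist())  # [1.0, 8.0, 0.00390625]
```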

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -459,9 +459,7 @@ def test_fp6_e3m2_pack_unpack():
 )
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
-# @pytest.mark.parametrize("M", (256,))
-# @pytest.mark.parametrize("K", (256,))
-def test_triton_mxfp8_dim1(M, K):
+def test_triton_mxfp8_dim1_randn(M, K):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
```

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1126,7 +1126,9 @@ def _triton_calculate_scale(x, axis):
     scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
     scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

-    # TODO(future PR): add NaN handling here
+    # TODO(future PR): add NaN handling here,
+    # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
+    # get proper NaN propagation working

     # Calculate the scale in floating point.
     scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
```
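
The `scale_fp` line after the TODO uses a standard bit trick: shift the biased exponent into the float32 exponent field (just above the 23 mantissa bits) and reinterpret the bits, which yields exactly 2^(exponent - 127). A standalone check of the trick in plain PyTorch (outside Triton; `FP32_MBITS` mirrors the kernel's `fp32_mbits`):

```python
import torch

FP32_MBITS = 23  # number of mantissa bits in float32

biased_exp = torch.tensor([127, 130, 1], dtype=torch.int32)
# Exponent goes to bits 23..30, mantissa stays zero, then reinterpret.
scale_fp = (biased_exp << FP32_MBITS).view(torch.float32)
print(scale_fp)  # 2**0, 2**3, 2**-126
```

This also shows why the TODO matters: e8m0's NaN encoding is the all-ones byte 0xFF, and 255 << 23 decodes to +inf (exponent all ones, mantissa zero) rather than NaN, so NaN propagation needs explicit handling.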

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -236,7 +236,6 @@ def to_mx(
     # Calculate the scale for different modes
     max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
     extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
-    extracted_pow2 = extracted_pow2.to(data_hp.dtype)

     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
         scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
```
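
This deletion is the speedup from the summary: `extracted_pow2` only feeds integer-style subtractions afterwards (e.g. `extracted_pow2 - target_max_pow2`), so casting it to the bfloat16 `data_hp.dtype` just materialized an extra tensor. The surviving line reads the biased exponent directly out of the float bits; a standalone check for bfloat16 inputs (7 mantissa bits and bias 127, mirroring `hp_mbits` and `hp_exp_bias`):

```python
import torch

BF16_MBITS, BF16_EXP_BIAS = 7, 127

# The real code applies this to max_abs, which is non-negative, so the
# sign bit is 0 and the arithmetic right shift below is safe.
max_abs = torch.tensor([0.5, 1.0, 448.0], dtype=torch.bfloat16)
as_int = max_abs.view(torch.int16)
extracted_pow2 = ((as_int >> BF16_MBITS) & 0b11111111) - BF16_EXP_BIAS
print(extracted_pow2)  # -1, 0, 8 (stays int16, no bfloat16 round trip)
```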
