🐛 Describe the bug
In the torchao float8 MoE training prototype, we have an autograd function that performs a differentiable scaled grouped mm, with dynamic fp8 rowwise quantization applied to the inputs before every GEMM in the forward and backward passes.
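For readers unfamiliar with the pattern, dynamic fp8 rowwise quantization roughly means computing a per-row scale from the tensor's absmax and casting to float8 right before the GEMM. A generic sketch (not the torchao prototype's actual code; the helper name is made up):

```python
import torch

def quantize_fp8_rowwise(x: torch.Tensor):
    # Dynamic per-row scaling: map each row's absmax to the float8_e4m3fn
    # max representable value, then cast. The returned reciprocal scale is
    # what the scaled GEMM uses to dequantize the accumulator.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).to(torch.float32)
    scale = fp8_max / amax
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()
```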
Recently, torch.compile started throwing an error when compiling this, whereas it compiled fine ~1-2 months ago. It appears to be related to the addition of autotuner support for scaled grouped mm.
Specifically, the error asserts that the "B" tensor in one of the scaled grouped GEMMs is not in column-major memory layout. However, my code has assertions before every torch._scaled_grouped_mm call validating that the "A" tensor is row-major and the "B" tensor is column-major, and these assertions do not trigger. The code works as expected in eager mode.
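A minimal sketch of what those layout assertions look like in terms of strides (the wrapper and the exact `_scaled_grouped_mm` keyword arguments are assumptions for illustration, not the torchao code):

```python
import torch

def assert_row_major(t: torch.Tensor) -> None:
    # Row-major in the last two dims: innermost stride 1, second-innermost
    # stride equal to the last dim's size.
    assert t.stride(-1) == 1 and t.stride(-2) == t.size(-1), f"not row-major: {t.stride()}"

def assert_column_major(t: torch.Tensor) -> None:
    # Column-major in the last two dims: stride 1 along dim -2, and the
    # last dim's stride equal to dim -2's size.
    assert t.stride(-2) == 1 and t.stride(-1) == t.size(-2), f"not column-major: {t.stride()}"

def checked_scaled_grouped_mm(a_fp8, b_fp8, scale_a, scale_b, offs):
    # Hypothetical wrapper mirroring the assertions described above; these
    # pass in eager mode, yet the Inductor lowering still rejects mat_b.
    assert_row_major(a_fp8)
    assert_column_major(b_fp8)
    return torch._scaled_grouped_mm(
        a_fp8, b_fp8, scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16
    )
```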
Error:
torch._inductor.exc.InductorError: LoweringException: RuntimeError: mat_b must be col_major, got stride [67108864, 8192, 1]
target: aten._scaled_grouped_mm.default
Full stack trace: https://www.internalfb.com/phabricator/paste/view/P1844809161
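For context, the stride reported in the error corresponds to a plain contiguous (row-major) layout rather than a column-major one. A small stride comparison, assuming the [num_groups, 8192, 8192] shape implied by the stride values (inferred from the error, not taken from the repro):

```python
import torch

E, K, N = 2, 8192, 8192  # shape inferred from the stride in the error message

# Contiguous (row-major) layout: strides (K*N, N, 1) == (67108864, 8192, 1),
# which is exactly what the lowering reports for mat_b.
b_row_major = torch.empty(E, K, N, dtype=torch.bfloat16)
print(b_row_major.stride())  # (67108864, 8192, 1)

# Column-major in the last two dims: allocate as (E, N, K) and transpose,
# giving strides (K*N, 1, K) == (67108864, 1, 8192), which is the layout
# torch._scaled_grouped_mm expects for mat_b.
b_col_major = torch.empty(E, N, K, dtype=torch.bfloat16).transpose(-2, -1)
print(b_col_major.stride())  # (67108864, 1, 8192)
```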
Repro
- Check out the PR in torchao (skip if the PR has landed by the time you see this): [float8 moe training] make using triton kernels for per-group scaling configurable ao#2405
- Run `python benchmark_scaled_grouped_mm.py --compile` from `~/ao/torchao/prototype/moe_training/benchmarks`
Versions
torch: 2.8.0.dev20250611+cu126
torchao: 0.11.0+gitcdced21f
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov