[inductor] tune_scaled_grouped_mm fails with memory layout assertion, despite memory layout assertions prior to op call passing #156325

@danielvegamyhre

🐛 Describe the bug

In the torchao float8 MoE training prototype, we have an autograd function that performs a differentiable scaled grouped GEMM, with dynamic fp8 rowwise quantization applied to the inputs of every GEMM in the forward and backward passes.

Recently, torch.compile started throwing an error when compiling this code, which it did not ~1-2 months ago. The failure appears to be related to the recently added autotuning support for scaled grouped mm.

Specifically, the error asserts that the "B" tensor in one of the scaled grouped GEMMs is not in column-major memory layout. However, my code has assertions before every torch._scaled_grouped_mm call validating that the "A" tensor is row-major and the "B" tensor is column-major, and these assertions do not trigger. The code works as expected in eager mode.

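The pre-call checks are simple stride checks along these lines (a minimal sketch rather than the exact torchao code; the helper names, dtypes, and shapes here are illustrative):

import torch

def is_row_major(x: torch.Tensor) -> bool:
    # Row-major: the innermost (last) dimension is contiguous.
    return x.stride(-1) == 1

def is_column_major(x: torch.Tensor) -> bool:
    # Column-major for the trailing 2D matrices: the second-to-last
    # dimension has stride 1.
    return x.stride(-2) == 1

# Illustrative shapes: A is an (M, K) activation matrix, B is a stack of
# (N, K) expert weights transposed into a column-major (E, K, N) view.
M, K, N, E = 64, 128, 256, 4
A = torch.randn(M, K, dtype=torch.bfloat16)                       # row-major
B = torch.randn(E, N, K, dtype=torch.bfloat16).transpose(-2, -1)  # column-major view, shape (E, K, N)

assert is_row_major(A), f"A must be row-major, got stride {A.stride()}"
assert is_column_major(B), f"B must be column-major, got stride {B.stride()}"

Checks of this form pass in eager mode immediately before each torch._scaled_grouped_mm call, yet the lowering error below reports a row-major stride for mat_b.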
Error:

torch._inductor.exc.InductorError: LoweringException: RuntimeError: mat_b must be col_major, got stride [67108864, 8192, 1]
  target: aten._scaled_grouped_mm.default

Full stack trace: https://www.internalfb.com/phabricator/paste/view/P1844809161
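For reference, the stride in the error corresponds to a fully contiguous (row-major) 3D tensor, whereas the column-major layout the pre-call assertions check for has stride 1 in the second-to-last dimension. The group count and dtype below are illustrative (only the per-group 8192 x 8192 shape is implied by the reported strides), and meta tensors are used so nothing is allocated:

import torch

E = 8  # illustrative number of groups/experts

# The layout reported by the lowering error: fully contiguous (row-major).
b_row_major = torch.empty(E, 8192, 8192, device="meta")
print(b_row_major.stride())  # (67108864, 8192, 1) -> matches the error message

# The column-major layout that the eager-mode assertions verify.
b_col_major = b_row_major.transpose(-2, -1)
print(b_col_major.stride())  # (67108864, 1, 8192) -> passes the col-major check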

Repro

Versions

torch: 2.8.0.dev20250611+cu126
torchao: 0.11.0+gitcdced21f

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

Labels: module: inductor, oncall: pt2, triaged