MoE refactor to use grouped_mm and scaled_grouped_mm #2600


Open · HDCharles wants to merge 12 commits into main from 092_BE_MoE

Conversation

@HDCharles (Contributor) commented Jul 25, 2025

Summary:

Now that the PyTorch grouped_mm kernels don't require padding, this refactors the MoE implementation to use them rather than the previous approach.

  1. Made AOQuantizable modules use grouped_mm + transpose. This lets us compile without graph breaks for multi-token inference, where a relevant kernel exists, though it requires that the tensor subclass support transpose (a minimal sketch of the grouped_mm expert computation follows this list).
  2. Made grouped_mm dispatch to relevant dedicated kernels (bf16 and fp8dq only right now).
    2a) Initially I was going to dispatch everything through grouped_mm, falling back to a linear decomposition where no dedicated kernel exists, but this was incredibly slow, so the single-token path was left as a separate path. For single-token inference, the non-grouped_mm path is often faster (see results).
  3. Fused 2 grouped_mm's together, so now we only have 2 parameters/ops.
  4. Added handling for a generic module swap to AOQuantizable modules plus quantization using a MoEMapping, which significantly simplifies llama4_quant.py.
  5. Added new unit tests and refactored the existing unit tests.
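For reference, here is a minimal sketch of the grouped_mm-based expert computation described in points 1 and 3, simplified to top-1 routing. It assumes an interface like `torch._grouped_mm(activations_2d, stacked_expert_weights_3d, offs=cumulative_token_counts)`, which may differ slightly across PyTorch versions; this is not the exact code in the PR.

```python
import torch
import torch.nn.functional as F

def moe_forward(x, expert_indices, w_up, w_down, num_experts):
    # x: [T, D] tokens (bf16), expert_indices: [T] top-1 routing decisions
    # w_up: [E, D, 2*H] fused w1/w3 weights (change #3), w_down: [E, H, D]
    order = expert_indices.argsort()                    # group tokens by expert
    x_sorted = x[order]
    counts = torch.bincount(expert_indices, minlength=num_experts)
    offs = counts.cumsum(0).to(torch.int32)             # per-expert group boundaries

    up = torch._grouped_mm(x_sorted, w_up, offs=offs)   # [T, 2*H]
    h1, h3 = up.chunk(2, dim=-1)
    y = torch._grouped_mm(F.silu(h1) * h3, w_down, offs=offs)  # [T, D]

    out = torch.empty_like(y)
    out[order] = y                                      # restore original token order
    return out
```

The fused `w_up` corresponds to change #3: w1 and w3 are stacked so a single grouped_mm produces both halves of the gated activation. Benchmark results for the new paths are below.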
| technique | tok/s (bsz 1) | mem GB (bsz 1) | tok/s (bsz 8) | mem GB (bsz 8) |
|---|---|---|---|---|
| bf16 (grouped_mm) | 69.39 | 97.15 | 24.15 | 95.25 |
| bf16 (decomposed) | 76.84 | 95.28 | 16.72 | 95.25 |
| fp8dq (scaled_grouped_mm) | 100.74 | 72.76 | 42.83 | 72.73 |
| fp8dq (decomposed) | 45.55 | 72.76 | 5.99 | 72.74 |
| noquant-linear | 76.84 | 95.28 | 16.72 | 95.25 |
| int8wo-linear | 112.89 | 57.74 | 5.09 | 57.71 |
| int4wo-linear | 80.17 | 40.86 | 11.86 | 94.91 |
| fp8wo-linear | 5.49 | 72.76 | 1.43 | 72.73 |

While the grouped_mm implementation is a little slower than the decomposed one for bf16 single token, it's significantly faster for bf16 multi token and for fp8dq in general. Tests comparing the grouped_mm kernels with 3 vs. 2 parameters (i.e. without vs. with change #3 above) found the fused 2-parameter version significantly faster, though results for the decomposed runs were mixed. Compared to #2043, int8wo gets a 14% improvement and is the fastest single-token technique, while other techniques were hit or miss.

full results: https://gist.github.com/HDCharles/4a529e12709a490777f53aed0c28bc33

Test Plan:

sh run.sh

python test/quantization/test_moe_quant.py



pytorch-bot bot commented Jul 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2600

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures, 1 Cancelled Job

As of commit 2bb805b with merge base c086ade:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 25, 2025
@HDCharles HDCharles requested a review from alexsamardzic July 25, 2025 02:30
@alexsamardzic (Collaborator)

This one replaces #2325, right?

I'm struggling to run the run.sh script (i.e. the generate.py script); I keep getting "CUDA out of memory" errors on an H100... Are you using PyTorch built from source, and if not, which PyTorch package version do you have installed?

Would you mind finding the mm_grouped.py file in your PyTorch installation, changing the can_use_triton_kernel() function there to just return False, and then re-trying? This will force the eager (non-Triton) version of the grouped MM kernel to be used even for max-autotune; namely, I suspect that the garbage output may not be from the grouped MM Triton kernel itself, but from max-autotuning the whole layer, and that would test it.
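(For illustration only: the suggested edit would look roughly like the stub below; the exact location and signature of can_use_triton_kernel() inside PyTorch's mm_grouped.py may differ between versions.)

```python
# In mm_grouped.py of the local PyTorch installation, stub out the check so the
# eager (non-Triton) grouped MM path is used even under max-autotune:
def can_use_triton_kernel(*args, **kwargs):
    return False
```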

As a side note, it seems that MoEFeedForwardAOQuantizable should be imported for this and this.

@HDCharles (Contributor, author) commented Jul 26, 2025

> This one replaces #2325, right?
>
> I'm struggling to run the run.sh script (i.e. the generate.py script); I keep getting "CUDA out of memory" errors on an H100... Are you using PyTorch built from source, and if not, which PyTorch package version do you have installed?
>
> Would you mind finding the mm_grouped.py file in your PyTorch installation, changing the can_use_triton_kernel() function there to just return False, and then re-trying? This will force the eager (non-Triton) version of the grouped MM kernel to be used even for max-autotune; namely, I suspect that the garbage output may not be from the grouped MM Triton kernel itself, but from max-autotuning the whole layer, and that would test it.
>
> As a side note, it seems that MoEFeedForwardAOQuantizable should be imported for this and this.

Can you run it with batch_size 1?

I'll try the fix.

Yeah, I haven't done the quantization dispatch stuff yet.

"""Configuration for applying quantization to MoE
Args:
`base_config`: normal AO Config
class DummyModule(torch.nn.Module):
Contributor:

I think a better solution is to make torchao APIs work on parameters. The current workaround is fine for a prototype, but we'd want more proper support outside of prototype.
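For context, a minimal sketch of the workaround being discussed: wrap a bare parameter in a dummy module so the existing module-based quantization API can be pointed at it. The class name and usage here are illustrative, not necessarily the exact code in this PR.

```python
import torch
from torch import nn

class DummyModule(nn.Module):
    """Wraps a single weight so module-level quantization APIs can target it."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight)

# hypothetical usage (assuming torchao's quantize_ API):
#   from torchao.quantization import quantize_
#   dummy = DummyModule(expert_weight_2d)
#   quantize_(dummy, base_config)
#   quantized_weight = dummy.weight
```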

Contributor (author):

We can do that, but nothing right now uses grouped_mm; if there's a use case for it, it'll be pretty trivial to add.

@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# T'(e) tokens for expert e


-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):
Contributor:

It seems unlikely that people are going to swap their MoE module to AO's version. Can we just target torch._grouped_mm calls directly without requiring a module swap?

Collaborator:

What would it mean to "target" it specifically? If the given model is compiled, the compiled version of this operator will be used anyway; not sure what else torchao could do about it...

Contributor (author):

We can do that, but nothing right now uses grouped_mm; if there's a use case for it, it'll be pretty trivial to add.

@alexsamardzic (Collaborator)

> Can you run it with batch_size 1?

Nope, with both batch_size 1 and 8, it runs out of memory.

@HDCharles HDCharles force-pushed the 092_BE_MoE branch 3 times, most recently from 74fa86e to ed6996a Compare August 1, 2025 17:42
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8WeightOnlyConfig()
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
Contributor:

What does this do?

Contributor (author):

Initially, to support things generally, I created a tensor subclass called FakeExtraDimTensor that takes traditional 2D tensor subclasses and handles the slicing and indexing ops needed for MoE. This enables things like ARM quantization, where I wasn't able to procure a working dev env to natively support the 3D functionality. In practice it's slower than the native 3D support but more general, which may just not be worth it.
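A rough sketch of the idea (not the actual torchao FakeExtraDimTensor implementation): hold one 2D, possibly quantized, tensor per expert and expose indexing/slicing over a fake leading dim.

```python
import torch

class FakeExtraDim:
    """Presents a list of 2D (possibly quantized) tensors as one 3D-like object."""

    def __init__(self, tensors_2d):
        self.tensors_2d = list(tensors_2d)  # one 2D tensor per expert

    def __getitem__(self, idx):
        # an integer index drops the fake expert dim; a slice keeps it
        if isinstance(idx, int):
            return self.tensors_2d[idx]
        return FakeExtraDim(self.tensors_2d[idx])

    @property
    def shape(self):
        return (len(self.tensors_2d), *self.tensors_2d[0].shape)
```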


elif "fp8wo-base" in moe_quant:
config = MoEQuantConfig(Float8WeightOnlyConfig())
config.base_config = Float8WeightOnlyConfig()
Contributor:

I'd personally focus on float8 dynamic quant for now; it's more important to get good perf + accuracy on a small set of techniques than to achieve broad coverage of techniques.

Contributor (author):

Yeah, I have done that. I thought we'd be able to dispatch grouped_mm -> linear -> quantized linear for the rest of the tensor subclasses, but it doesn't work, so I've gone back to what it was initially; this PR mostly focuses on fp8dq and bf16.
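For reference, a rough sketch of what the "decomposed" fallback amounts to: running each expert's token group through an ordinary matmul instead of a single grouped_mm call (illustrative, not the exact torchao code path).

```python
import torch

def decomposed_grouped_mm(x_sorted, weights_3d, offs):
    # x_sorted: [T, D] tokens sorted by expert, weights_3d: [E, D, H],
    # offs: [E] cumulative token counts per expert
    outs, start = [], 0
    for e, end in enumerate(offs.tolist()):
        if end > start:
            outs.append(x_sorted[start:end] @ weights_3d[e])  # per-expert matmul
        start = end
    if not outs:
        return x_sorted.new_zeros(0, weights_3d.shape[-1])
    return torch.cat(outs, dim=0)
```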

"--compile_mode",
type=str,
default="reduce-overhead",
help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.",
Contributor:

Do we expect a meaningful % of users to use max-autotune? If not, I'd say p1 and do it later.

Contributor (author):

@alexsamardzic added it, so I wanted to test and see the speedup; at least for these shapes it doesn't make a huge difference, so it's not hugely important.

# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
# expert_outs = self.cond_ffn(x, expert_indices)
# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
class ConditionalFeedForward(nn.Module):
Contributor:

Why not just make this use torch._grouped_mm and remove the AOQuantizable version?

Contributor (author):

I was hoping to do MoE quant without direct model modification; previously I manually modified the model, but with the new swapping stuff it's no longer needed.

@@ -225,41 +225,39 @@ def forward(
y = self.wo(y)
return y

class MoEFeedForward(nn.Module):
Contributor:

Is this the same code as in the original Mixtral model definition? If not, do we have a test for numerical equivalency?

Contributor (author):

Yes, it's the original code. We don't have tests for numerical equivalence because I don't think that implementation is actually a source of truth; it can't actually do batch size 8 due to the crazy dim join going on.

@HDCharles HDCharles force-pushed the 092_BE_MoE branch 3 times, most recently from db7ef2c to 22f837e Compare August 6, 2025 15:11
@HDCharles HDCharles changed the title wip MoE refactor MoE refactor to use grouped_mm and scaled_grouped_mm Aug 6, 2025
@HDCharles HDCharles added the topic: new feature, enhancement, performance, and topic: improvement labels Aug 6, 2025
@HDCharles HDCharles force-pushed the 092_BE_MoE branch 4 times, most recently from 8cd3640 to 4efb10f Compare August 6, 2025 17:44
@HDCharles HDCharles force-pushed the 092_BE_MoE branch 3 times, most recently from de51ba4 to fe93531 Compare August 7, 2025 16:28
)


@implements(grouped_mm)
Contributor:

Do you want to make changes to Float8Tensor directly, since we are moving away from AQT?

Contributor (author):

I'll do that in a new PR.
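For context, a generic sketch of the registration pattern behind an `@implements(grouped_mm)` override on a tensor subclass; the registry and hook names here are illustrative, and torchao's actual decorator/dispatch plumbing (e.g. for Float8Tensor) differs in detail.

```python
import torch

_OVERRIDES = {}

def implements(op):
    """Register a handler for `op`, to be consulted by a tensor subclass's
    __torch_function__/__torch_dispatch__ hook."""
    def decorator(fn):
        _OVERRIDES[op] = fn
        return fn
    return decorator

@implements(torch._grouped_mm)
def _grouped_mm_override(func, types, args, kwargs):
    # a quantized subclass would route to its dedicated kernel here,
    # e.g. a scaled grouped mm for fp8 dynamic quant
    raise NotImplementedError("sketch only")
```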

@HDCharles HDCharles force-pushed the 092_BE_MoE branch 3 times, most recently from 32a4df5 to 50b60db Compare August 8, 2025 02:13
Summary:

Now that the PyTorch grouped_mm kernels don't require padding, refactor the MoE implementation to use them rather than the previous approach.

DONE
- implement MoE with grouped_mm [x]
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]

TODO
- add dispatch from grouped_mm to linear decomposition of quantized kernels
- compare linear decomposition vs new linear decomposition vs grouped_mm for eager, compile, and autotuned compile
- compare linear decomposition vs new linear decomposition for quantized kernels
- add scaled_group_gemm and fbgemm kernel (probably in a new PR)

ISSUE:
The autotuned grouped_mm kernels don't give the correct output, but they work in eager and in compile with reduce-overhead. Why?

see new_run.log output, first 2 runs are fine, line 144 is nonsense

Test Plan:

sh run.sh

@danielvegamyhre (Contributor)

> Now that the PyTorch grouped_mm kernels don't require padding

To clarify: we no longer need to pad empty token groups for experts assigned 0 tokens (padding that previously ensured no token group had size 0). However, padding to 16-byte alignment for the slowest moving dim (stride 1) is still required, at least for fp8 scaled grouped mm. Does this align with your understanding as well, or am I missing something?

@@ -255,71 +265,72 @@ def main(

if moe_quant:
torch._dynamo.config.capture_dynamic_output_shape_ops = True
config = None
config = MoEQuantConfig(
Contributor:

Minor point, but the training config is called MoETrainingConfig, so it would be clearer to call this MoEInferenceConfig to remove any potential ambiguity in the name MoEQuantConfig.

@alexsamardzic (Collaborator)

> To clarify: we no longer need to pad empty token groups for experts assigned 0 tokens (padding that previously ensured no token group had size 0). However, padding to 16-byte alignment for the slowest moving dim (stride 1) is still required, at least for fp8 scaled grouped mm. Does this align with your understanding as well, or am I missing something?

The limitations come from TMA; this test shows what's possible (the most relevant part is "2d/3d", in particular row-major/row-major), and what you stated is correct. Note that for this particular case, besides 0 tokens, a token group could consist of 1, 2, ... or any number of tokens, under the condition that the slowest moving dim is aligned to 16 bytes. Both the eager and compiled versions of the grouped MM kernels have checks in place to detect wrong alignments.
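As a small illustration of the alignment constraint discussed here (assuming 1 byte per fp8 element and 2 per bf16 element; the helper name is hypothetical):

```python
def round_up_to_16_bytes(num_elems: int, elem_size_bytes: int) -> int:
    """Round an element count up so the dim spans a multiple of 16 bytes."""
    elems_per_16b = 16 // elem_size_bytes  # e.g. 16 for fp8, 8 for bf16
    return ((num_elems + elems_per_16b - 1) // elems_per_16b) * elems_per_16b

# e.g. for fp8: round_up_to_16_bytes(37, 1) -> 48
```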


meta-cla bot commented Aug 8, 2025

Hi @HDCharles!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Labels: CLA Signed, enhancement, performance, topic: improvement, topic: new feature