Skip to content

moe quant with dedicated kernels [wip] #2325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

HDCharles
Copy link
Contributor

@HDCharles HDCharles commented Jun 6, 2025

Summary:

current status:
both kernels are working. The padding is a significant issue with compile for the pytorch kernel while the fbgemm kernel doesn't seem compatible with compile. Hopefully this can be handled using the changes mentioned below to avoid the data dependent padding.

todo:
test the no-padding compilable pytorch kernel

change base integration to grouped_gemm (another PR)

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Copy link

pytorch-bot bot commented Jun 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2325

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 186708f with merge base f0f1f6c (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 6, 2025
isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor) and
isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout)
):
final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, scores)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to call this op without modifying the source model?

is there a gropup_mm for bfloat16 that we can overwrite and dispatch to scaled_grouped_mmm?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is _grouped_mm in PyTorch core that does that.

Copy link
Contributor Author

@HDCharles HDCharles Jun 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a better integration point but i'm not sure i'll be able to complete that before i have to head out on leave.

also i'd probably make that a separate PR instead of combining everything into one since that would be a significant change to the base moe integration.

@alexsamardzic
Copy link
Collaborator

PR to hopefully remove need for padding groups is here: pytorch/pytorch#155466.

alignment = 16
if _torchtitan_available:
num_ranks = 1
padded_indices, m_offsets = torchtitan_pad(num_tokens_per_expert, alignment, num_ranks)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

heads up, soon we won't need padding once #155466 lands

input_fp8[valid_values] = q_input_data[token_shuffle]
input_scale[valid_values] = q_input_scale[token_shuffle] if q_input_scale.numel()>1 else q_input_scale

if use_fbgemm_kernel:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have fbgemm-like kernels available via autotuning in torch.compile, thanks to #155138, do you think we still need separate fbgemm path?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does that require max-autotune?

Summary:

extending the torchao moe support to have more performant kernels. This
PR supports both scaled_grouped_mm and fbgemm's grouped_gemm_fp8_rowwise
though it seems like grouped_gemm_fp8_rowwise is a bit buggy (need to
make a clear repro)

todo: run benchmarks, debug fbgemm kernel, unit tests

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@alexsamardzic
Copy link
Collaborator

alexsamardzic commented Jun 18, 2025

PR pytorch/pytorch#155466, that makes it possible to avoid padding, is merged. Here is a quick patch to remove padding (note that it also disables FBGEMM altogether, so _scaled_grouped_mm implementation from PyTorch core is used):

096_fuse_moeb-diff.txt

I'm not completely sure that scale tensor shape adjustment I made here are correct, but in any case this patch, together with latest PyTorch used, will make all the tests in test_moe_quant.py pass.

@ngimel
Copy link

ngimel commented Jun 18, 2025

@alexsamardzic we still would need padding for backward where K could possibly become 0?

@alexsamardzic
Copy link
Collaborator

alexsamardzic commented Jun 19, 2025

@alexsamardzic we still would need padding for backward where K could possibly become 0?

This PR is not concerned about backward, but I would say @danielvegamyhre is touching on it: #2405. In any case, you have a point, here is a diff for _grouped_mm tests in PyTorch to demonstrate the issue:

diff
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 4e64c807425..96667a79440 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -354,15 +354,15 @@ class TestMatmulCuda(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize("strided", [False, True])
-    @parametrize("a_row_major", [False, True])
-    @parametrize("b_row_major", [False, True])
-    @parametrize("use_torch_compile", [False, True])
+    @parametrize("strided", [True])
+    @parametrize("a_row_major", [True])
+    @parametrize("b_row_major", [True])
+    @parametrize("use_torch_compile", [True, False])
     def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, use_torch_compile):
         device = "cuda"
         dtype = torch.bfloat16
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 64, 4
+        m, n, k, n_groups = 3, 32, 64, 5
         if a_row_major:
             a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k]
         else:
@@ -388,6 +388,7 @@ class TestMatmulCuda(TestCase):
             a.grad = None
             b.grad = None
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            offs = torch.tensor([0, 1, 6, 6, 15], device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]

If offsets changed say to [1, 3, 5, 6, 15], so that there are no zero sizes, it would work fine. I'm going to see is it possible to further refine these checks, in order to make it work for this case.

Edit: see here.

@alexsamardzic
Copy link
Collaborator

Here is slightly changed diff: 096_fuse_moeb-diff.txt. To be applied after PR rebased on latest main.

Some end-to-end performance numbers for Mixtral model, for current version of the PR:

First, this patch is to be applied to enforce auto-tuning for all cases:

diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py
index 11a53043..10da20f7 100644
--- a/torchao/_models/mixtral-moe/generate.py
+++ b/torchao/_models/mixtral-moe/generate.py
@@ -337,10 +337,10 @@ def main(
 
         if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
             decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode="max-autotune"
             )
         else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode="max-autotune")
 
         if args.compile_prefill:
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

For each run below, --checkpoint_path=.../checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1/model.pth should be added to command line, skipped for brevity.

First, a variant that goes through this branch of the model forward function:

$ python generate.py --compile --moe_quant fp8dq

Average tokens/sec: 10.23
Memory used: 51.21 GB
model size: 48.37

Then a variant that goes through this branch:

$ python generate.py --compile --moe_quant fp8dq-base

Average tokens/sec: 56.34
Memory used: 59.14 GB
model size: 48.37

Finally, for a variant that will utilize auto-tuned _scaled_grouped_mm i.e. go through this branch of the model, one could skip the first branch by adding if False and... here, and then:

$ python generate.py --compile --moe_quant fp8dq-base

Average tokens/sec: 101.24
Memory used: 59.14 GB
model size: 48.37

If auto-tuning for _scaled_grouped_mm disabled (e.g. by simply removing the meta registration), the tokens/sec is practically the same, and that is kind of expected as the batches are of very small size.

Again, this all could be further improved, leaving it at that for now.

@ngimel
Copy link

ngimel commented Jun 25, 2025

@alexsamardzic so pytorch now with removed paddiing restrictions is strictly better than fbgemm?

@alexsamardzic
Copy link
Collaborator

@alexsamardzic so pytorch now with removed paddiing restrictions is strictly better than fbgemm?

The FBGEMM kernel is not included in the results above. To have it activated, on top of changes mentioned for the last run (that was about using _scaled_grouped_mm), the False here should be changed to True. After these changes:

$ python generate.py --compile --moe_quant fp8dq-base

Average tokens/sec: 23.51
Memory used: 59.14 GB
model size: 48.37

Plus, the output is garbage. However, regarding the performance, note this, i.e. the compilation is at the moment disabled around calls to FBGEMM kernel; and if enabled it would error out. I'm not sure @HDCharles would be interested in working further on that branch, but IMO both FBGEMM and PyTorch Triton kernels should have similar performance, so FBGEMM kernel usage may be safely skipped.

Copy link

meta-cla bot commented Aug 8, 2025

Hi @HDCharles!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants