MoE refactor to use grouped_mm and scaled_grouped_mm #2600
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2600. Note: links to docs will display an error until the docs builds have completed. As of commit 2bb805b with merge base c086ade: 8 new failures, 1 cancelled job (please retry). This comment was automatically generated by Dr. CI and updates every 15 minutes.
This one replaces #2325, right? I'm struggling to run the … Would you mind finding … As a side note, it seems that …
Can you run it with batch_size 1? I'll try the fix. Yeah, I haven't done the quantization dispatch stuff yet.
torchao/prototype/moe_quant/utils.py
Outdated
"""Configuration for applying quantization to MoE | ||
Args: | ||
`base_config`: normal AO Config | ||
class DummyModule(torch.nn.Module): |
I think a better solution is to make torchao APIs work on parameters. The current workaround is fine for prototype, but we'd want more proper support for non-prototype.
We can do that, but nothing right now uses grouped_mm; if there's a use case for it, it'll be pretty trivial to add.
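For illustration, here is a minimal sketch of the DummyModule-style workaround being discussed: wrap a bare parameter in a throwaway module so module-oriented quantization APIs can process it, then pull the weight back out. The helper names below are made up for the sketch and are not the actual prototype code in torchao/prototype/moe_quant/utils.py.

```python
import torch
from torch import nn


class _DummyModule(nn.Module):
    """Throwaway wrapper so a module-oriented quantize API can see the weight."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)


def quantize_param(param: torch.Tensor, quantize_fn):
    # quantize_fn is a placeholder for something like
    # `lambda m: quantize_(m, base_config)`; it mutates the dummy in place.
    dummy = _DummyModule(param.detach())
    quantize_fn(dummy)
    return dummy.weight  # possibly swapped for a quantized tensor subclass
```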
torchao/_models/mixtral-moe/model.py
Outdated
@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# T'(e) tokens for expert e

-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):
It seems unlikely that people are going to swap their MoE module to AO's version. Can we just target torch._grouped_mm calls directly without requiring a module swap?
What would it mean to "target" it specifically? If the given model is compiled, the compiled version of this operator will be used anyway; not sure what else torchao could do about it...
Nope, with both batch_size 1 and 8, it runs out of memory.
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8WeightOnlyConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
what does this do?
Initially, to support things generally, I created a tensor subclass called FakeExtraDimTensor that takes traditional 2D tensor subclasses and handles the slicing and indexing ops needed for MoE. This enables things like ARM quantization, where I wasn't able to procure a working dev env to actually natively support the 3D functionality. In practice it's slower than the native 3D support but more general, which may just not be worth it.
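To make the idea concrete, here is a minimal sketch of the concept behind FakeExtraDimTensor (not the actual torchao implementation, and deliberately not a real torch.Tensor subclass): present a stack of independent 2D tensors as one 3D-shaped object and forward the indexing MoE code relies on to the 2D members.

```python
import torch


class FakeExtraDim:
    """Pretend a list of 2D (possibly quantized) tensors is one 3D tensor."""

    def __init__(self, tensors_2d):
        self.tensors = list(tensors_2d)

    @property
    def shape(self):
        return (len(self.tensors), *self.tensors[0].shape)

    def __getitem__(self, idx):
        # Indexing the fake leading dim hands back the underlying 2D tensor,
        # so per-expert matmuls run through the existing 2D kernels.
        if isinstance(idx, int):
            return self.tensors[idx]
        return FakeExtraDim(self.tensors[idx])


experts = FakeExtraDim([torch.randn(64, 32) for _ in range(8)])
x = torch.randn(4, 32)
y = x @ experts[3].t()  # routes through the selected 2D tensor's own matmul
```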
elif "fp8wo-base" in moe_quant: | ||
config = MoEQuantConfig(Float8WeightOnlyConfig()) | ||
config.base_config = Float8WeightOnlyConfig() |
I'd personally focus on float8 dynamic quant for now, it's more important to get good perf + accuracy on a small set of techniques than to achieve broad coverage of techniques.
Yeah, I have done that. I thought we'd be able to dispatch grouped_mm -> linear -> quantized linear for the rest of the tensor subclasses, but it doesn't work, so I've gone back to what it was initially, and this PR mostly focuses on fp8dq and bf16.
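For reference, a hedged sketch of what the "linear decomposition" fallback looks like: loop over experts and issue a regular 2D linear per token group instead of one grouped_mm. The names and shapes here are illustrative, not the torchao dispatch code.

```python
import torch


def grouped_mm_via_linear(x, w, offs):
    # x:    [total_tokens, K]    tokens already sorted by expert
    # w:    [num_experts, N, K]  per-expert weights
    # offs: [num_experts]        cumulative token counts (int32)
    outs, start = [], 0
    for e, end in enumerate(offs.tolist()):
        if end > start:  # experts assigned 0 tokens are simply skipped
            outs.append(torch.nn.functional.linear(x[start:end], w[e]))
        start = end
    return torch.cat(outs) if outs else x.new_empty(0, w.shape[1])


x = torch.randn(10, 16)
w = torch.randn(4, 32, 16)
offs = torch.tensor([3, 3, 7, 10], dtype=torch.int32)  # expert 1 gets 0 tokens
y = grouped_mm_via_linear(x, w, offs)
print(y.shape)  # torch.Size([10, 32])
```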
"--compile_mode", | ||
type=str, | ||
default="reduce-overhead", | ||
help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.", |
do we expect a meaningful % of users to use max-autotune? if not, I'd say p1 and do it later.
@alexsamardzic added it, so I wanted to test and see the speedup. At least for these shapes it doesn't make a huge difference, so it's not hugely important.
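Illustrative only: how a flag like --compile_mode typically maps onto torch.compile. The helper name is made up; "reduce-overhead" enables CUDA graphs to cut launch overhead, while "max-autotune" additionally autotunes kernel configurations at compile time.

```python
import torch


def maybe_compile(model, use_compile: bool, compile_mode: str):
    if not use_compile:
        return model  # mirrors the flag doing nothing without --compile
    assert compile_mode in ("reduce-overhead", "max-autotune")
    return torch.compile(model, mode=compile_mode)
```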
# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
# expert_outs = self.cond_ffn(x, expert_indices)
# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
class ConditionalFeedForward(nn.Module):
why not just make this use torch._grouped_mm and remove the AOQuantizable version?
I was hoping to do MoE quant without direct model modification; previously I manually modified the model, but with the new swapping stuff it's no longer needed.
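Roughly, the grouped_mm path under discussion looks like the sketch below. The torch._grouped_mm signature is assumed here (2D sorted activations, 3D stacked weights, `offs` as cumulative int32 group sizes, bf16 inputs on CUDA), the gating projection is omitted, and none of this is the actual ConditionalFeedForward code — verify against your PyTorch build.

```python
import torch


def moe_grouped_forward(x, expert_indices, w1, w2):
    # x: [T, D] tokens; expert_indices: [T] chosen expert per token
    # w1: [E, H, D], w2: [E, D, H] stacked expert weights
    order = expert_indices.argsort()
    x_sorted = x[order]                                    # group tokens by expert
    counts = torch.bincount(expert_indices, minlength=w1.shape[0])
    offs = counts.cumsum(0).to(torch.int32)                # cumulative group sizes
    h = torch._grouped_mm(x_sorted, w1.transpose(-2, -1), offs=offs)
    h = torch.nn.functional.silu(h)
    y = torch._grouped_mm(h, w2.transpose(-2, -1), offs=offs)
    return y[order.argsort()]                              # restore original token order
```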
@@ -225,41 +225,39 @@ def forward(
y = self.wo(y)
return y

class MoEFeedForward(nn.Module):
is this the same code as in the original mixtral model definition? if not, do we have a test testing for numerical equivalency?
Yes, it's the original code. We don't have tests for numerical equivalence because I don't think that implementation is actually a source of truth; it can't actually do batch size 8 due to the crazy dim join going on.
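If an equivalence check were added, it could be as small as the sketch below. The module arguments are placeholders for the original Mixtral MoE block and the grouped_mm-based rewrite; neither is constructed here.

```python
import torch


def check_equivalent(ref_module, new_module, x, rtol=1e-3, atol=1e-3):
    # Compare outputs of two MoE implementations on the same input.
    ref_module.eval()
    new_module.eval()
    with torch.no_grad():
        torch.testing.assert_close(new_module(x), ref_module(x), rtol=rtol, atol=atol)
```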
)

@implements(grouped_mm)
Do you want to make changes to Float8Tensor directly, since we are moving away from AQT?
I'll do that in a new PR.
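For readers unfamiliar with the `@implements(...)` pattern in the diff above, here is a generic sketch of that kind of registration: handlers are recorded per op so a subclass's `__torch_function__` can route a call (e.g. grouped_mm) to a quantized kernel or a decomposition. This is not torchao's actual decorator or the real Float8Tensor code; `torch.mm` stands in for the grouped op.

```python
import torch

_HANDLERS = {}


def implements(ops):
    """Record a handler for one or more torch ops."""
    ops = ops if isinstance(ops, (list, tuple)) else [ops]

    def register(fn):
        for op in ops:
            _HANDLERS[op] = fn
        return fn

    return register


class RoutedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _HANDLERS:
            return _HANDLERS[func](func, types, args, kwargs)
        return super().__torch_function__(func, types, args, kwargs)


@implements(torch.mm)  # stand-in for the grouped_mm override in the diff
def _mm_handler(func, types, args, kwargs):
    a, b = args
    # A real handler would call the quantized/grouped kernel here.
    return torch.mm(a.as_subclass(torch.Tensor), b.as_subclass(torch.Tensor))


x = torch.randn(4, 8).as_subclass(RoutedTensor)
w = torch.randn(8, 3).as_subclass(RoutedTensor)
print(torch.mm(x, w).shape)  # routed through _mm_handler
```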
Summary:
Now that the pytorch grouped_mm kernels don't require padding, refactor the moe implementation to use that rather than what was there before.

DONE
- implement moe with grouped_mm [x]
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]

TODO
- add dispatch from grouped_mm to linear decomposition of quantized kernel
- compare linear decomposition vs new linear decomposition vs grouped_mm for eager, compile, autotuned compile linear decomposition
- compare linear decomposition vs new linear decomposition for quantized kernels
- add scaled_group_gemm and fbgemm kernel (probably in a new PR)

ISSUE: the autotuned grouped_mm kernels don't give the correct output, but then work in eager and compile with reduce-overhead. Why? See the new_run.log output: the first 2 runs are fine, line 144 is nonsense.

Test Plan: sh run.sh
To clarify, we no longer need to pad empty token groups for experts assigned 0 tokens so that the token group size isn't 0. However, padding to 16-byte alignment for the slowest moving dim (stride 1) is still required, at least for fp8 scaled grouped mm. Does this align with your understanding as well, or am I missing something?
@@ -255,71 +265,72 @@ def main(

if moe_quant:
torch._dynamo.config.capture_dynamic_output_shape_ops = True
config = None
config = MoEQuantConfig(
Minor point, but the training config is called MoETrainingConfig, so it would be more clear to call this MoEInferenceConfig, to remove any potential ambiguity present in the name MoEQuantConfig.
The limitations come from TMA; this test shows what's possible (the most relevant part is "2d/3d", in particular row-major/row-major), and what you stated is correct. Note that for this particular case, besides 0 tokens, a token group could consist of 1, 2, ... or any number of tokens, under the condition that the slowest moving dim is aligned to 16 bytes. Both the eager and compiled versions of the grouped MM kernels have checks in place to detect wrong alignments.
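A small sketch of the constraint being discussed: token groups may be empty, but the fp8 scaled grouped mm still requires 16-byte alignment along the stride-1 dim; with 1-byte fp8 elements that amounts to K % 16 == 0 in this sketch. This is an assumption for illustration — the kernel's own asserts are the source of truth.

```python
import torch


def check_fp8_grouped_layout(x, w, counts):
    T, K = x.shape          # x: [total_tokens, K] row-major fp8 activations
    E, N, Kw = w.shape      # w: [num_experts, N, K] stacked fp8 weights
    assert K == Kw and K % 16 == 0, "stride-1 dim must be 16-byte aligned for fp8"
    offs = counts.cumsum(0).to(torch.int32)   # zero-sized groups are allowed
    assert offs[-1].item() == T
    return offs


counts = torch.tensor([5, 0, 11, 4])          # expert 1 gets no tokens
x = torch.empty(20, 128, dtype=torch.float8_e4m3fn)
w = torch.empty(4, 256, 128, dtype=torch.float8_e4m3fn)
print(check_fp8_grouped_layout(x, w, counts))
```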
Summary:
Now that the pytorch grouped_mm kernels don't require padding, this refactors the moe implementation to use that rather than what was there before.
2a) Initially I was going to dispatch everything through grouped_mm; for non-dedicated kernels it would dispatch to a linear decomposition, but this was incredibly slow, so I left the single token path as a separate path. For single token inference, the non-grouped_mm path is often faster (see results).
While the grouped_mm implementation is a little slower than the decomposed one for bf16 single token, it's significantly faster for bf16 multi token and for fp8dq in general. Tests were done to compare the speed of the grouped_mm kernels for 3 vs 2 parameters (i.e. with/without the change #3 above), and the grouped_mm path was significantly faster; however, the results were mixed for the decomposed runs. Compared to #2043, int8wo gets a 14% improvement and is the fastest single token technique, while other techniques were hit or miss.
full results: https://gist.github.com/HDCharles/4a529e12709a490777f53aed0c28bc33
Test Plan:
sh run.sh
python test/quantization/test_moe_quant.py