[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass #10902
Conversation
vllm/compilation/reshapes.py (outdated)
Are these always the right ops to use? e.g. is there a torch.ops.aten.slice.default or a torch.ops.aten.slice_scatter.Tensor?
I haven't seen them, so I am not sure - I just went off what I saw. The other overloads could be added easily if we ever see them in the graph.
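For context, here is a minimal sketch of the kind of overload check being discussed, assuming the `torch.ops.aten.slice.Tensor` / `torch.ops.aten.slice_scatter.default` overloads seen in the graph; `is_noop_slice` is a hypothetical helper for illustration, not the actual pass code:

```python
import torch
from torch import fx


def is_noop_slice(node: fx.Node) -> bool:
    """Hypothetical check: is this aten.slice.Tensor call a full-dim no-op?"""
    if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor:
        return False
    # aten.slice.Tensor(self, dim=0, start=None, end=None, step=1)
    args = list(node.args) + [None] * (5 - len(node.args))
    inp, dim, start, end, step = args[:5]
    if step not in (None, 1) or start not in (None, 0):
        return False
    # Only a statically known end that covers the whole dimension is a no-op;
    # in particular end = -1 drops the last element and must be kept.
    fake = inp.meta.get("val", None)
    if fake is None or not isinstance(end, int) or end < 0:
        return False
    return end >= fake.shape[dim or 0]
```

Other overloads could be supported by extending the target check if they ever show up in traced graphs.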
    # This could change in the future.
    # We also don't pad when using torch.compile,
    # as it breaks with dynamic shapes.
    config = get_current_vllm_config().compilation_config
Is this cached? It could be expensive on each forward call.
Yes, in eager mode this will get called on every forward pass, but it will only happen once when compiled. In eager mode there isn't really a better way that's still correct - the only way is to check the config context. I don't think this getter is significant but I haven't measured it.
We could add an `allow_input_padding` flag and pass it in? I do think this is annoying though. I think it's worth it to do a quick check for performance regressions on a small-model eager-mode benchmark with `cutlass_scaled_mm` disabled?
I think we'd have to pass that flag through the whole call stack though so I don't think it's worth it. I'll run a small model.
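For reference, one way the lookup could be amortized if it ever showed up in profiles. This is only a sketch under the assumption that the compilation level is fixed for the lifetime of the process (exactly the config-context caveat above); `_padding_disabled_by_compile` is a hypothetical helper:

```python
from functools import cache

from vllm.config import get_current_vllm_config  # getter from the snippet above


@cache
def _padding_disabled_by_compile() -> bool:
    # Resolved once per process. Only correct if the active vLLM config
    # context does not change after the first call; otherwise the
    # per-forward lookup shown in the diff is the safe choice.
    config = get_current_vllm_config().compilation_config
    return config.level > 0  # assumption: any compilation level disables padding
```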
tlrmchlsmth left a comment:

Looks good overall but I had a few minor comments.
- rename cutlass_fp8 test flag
- rename noop pass
- improve some comments

Signed-off-by: luka <[email protected]>
tlrmchlsmth left a comment:

Thanks for the great work! LGTM assuming we don't see any performance regression.
Yep, will post perf numbers once I have them, thanks!
nit: noop elimination for slice errors when end = -1; similarly for slice_scatter.
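The repro snippet itself is not reproduced above; a minimal sketch of the behavior it points at, assuming standard `aten.slice` / `aten.slice_scatter` semantics:

```python
import torch

x = torch.arange(6)

# end = -1 excludes the last element, so this slice is NOT a no-op ...
print(torch.ops.aten.slice.Tensor(x, 0, 0, -1))           # tensor([0, 1, 2, 3, 4])

# ... whereas slicing up to the full length is.
print(torch.ops.aten.slice.Tensor(x, 0, 0, x.shape[0]))   # tensor([0, 1, 2, 3, 4, 5])

# Same for slice_scatter: scattering into [0:-1] leaves the last element alone.
src = torch.zeros(5, dtype=x.dtype)
print(torch.ops.aten.slice_scatter.default(x, src, 0, 0, -1))  # tensor([0, 0, 0, 0, 0, 5])
```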
Great find - I didn't realize `slice` handles `end = -1` that way.
Issue: #17078. Btw, PyTorch supports noop elimination for `view`, `slice`, and `slice_scatter` now, which should be equivalent to this pass.
Okay, sounds good, we can probably deprecate the pass, although it's nice to have an easy place to add noop transforms to unblock pattern matching in the short term, before the fixes are upstreamed to torch.
This PR fixes RMSNorm + quant fusion in the `fp8` case when `cutlass_mm` is not available. It contains the following fixes:

- Don't pad the input to `fp8` `torch._scaled_mm` in the `torch.compile` case, as branch specialization might not work correctly, and it makes fusion difficult.
- Add `slice` and `slice_scatter` elimination, which is implemented in PyTorch but does not cover all cases. This renames the `RedundantReshapesPass` to `NoopEliminationPass`.

This PR is a prerequisite of #10836, which enables `torch.compile` on AMD and uses the non-cutlass-fp8 path.
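To illustrate the first fix, here is a sketch of the padding guard, assuming `get_current_vllm_config` is the getter from the review snippet above and that any non-zero compilation level means `torch.compile` is in use; the helper name and the minimum-row constant are made up for illustration:

```python
import torch
from vllm.config import get_current_vllm_config

# Hypothetical constant: some torch._scaled_mm paths prefer a minimum number
# of rows in eager mode; the real value and condition in vLLM may differ.
_MIN_SCALED_MM_ROWS = 17


def maybe_pad_fp8_input(qinput: torch.Tensor) -> torch.Tensor:
    # Padding helps the eager fp8 path, but under torch.compile it breaks
    # dynamic shapes and gets in the way of the RMSNorm + quant fusion,
    # so skip it whenever a compilation level is set.
    config = get_current_vllm_config().compilation_config
    if config.level > 0:
        return qinput
    pad = _MIN_SCALED_MM_ROWS - qinput.shape[0]
    if pad <= 0:
        return qinput
    # Pad the token (first) dimension of a 2D [num_tokens, hidden_size] input.
    return torch.nn.functional.pad(qinput, (0, 0, 0, pad))
```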