
Conversation

@ProExpertProg
Collaborator

@ProExpertProg ProExpertProg commented Dec 4, 2024

This PR fixes the fp8 case when cutlass_mm is not available. It contains the following fixes:

  • Removes the padding for fp8 torch._scaled_mm in the torch.compile case, as branch specialization might not work correctly, and it makes fusion difficult.
  • Implements redundant slice and slice_scatter elimination (see the sketch below for the kind of no-op involved); PyTorch has similar logic, but it does not cover all cases. Also renames RedundantReshapesPass to NoopEliminationPass.
  • Minor custom pass improvements.
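
As an illustration of the kind of no-op being eliminated (a minimal sketch, not the actual pass code; the helper full_dim_slice_is_noop is a made-up name here): a slice that covers a full dimension returns a view identical to its input, so the corresponding node can simply be replaced by its input in the FX graph.

import torch

def full_dim_slice_is_noop(shape: torch.Size, dim: int, start: int, end: int) -> bool:
    # A slice that starts at 0 and ends at (or beyond) the dimension size
    # returns a view identical to its input, so the corresponding FX node
    # can be replaced by its input.
    return start == 0 and end >= shape[dim]

x = torch.randn(16, 4096)
assert full_dim_slice_is_noop(x.shape, 1, 0, x.shape[1])
assert torch.equal(torch.ops.aten.slice.Tensor(x, 1, 0, x.shape[1]), x)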

This PR is a prerequisite to #10836, which enables torch.compile on AMD and uses the non-cutlass-fp8 path.

@github-actions

github-actions bot commented Dec 4, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from b8ab496 to e5ded5c on December 4, 2024 at 20:18
@mergify

mergify bot commented Feb 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 15, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from e5ded5c to a3cb530 on February 25, 2025 at 20:52
@mergify mergify bot removed the needs-rebase label Feb 25, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch 2 times, most recently from 65afeae to c22186b on February 26, 2025 at 19:08
@ProExpertProg ProExpertProg changed the title Fix for the padding in the non-cutlass-fp8 case [torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case Feb 26, 2025
Comment on lines 74 to 84
Collaborator

Are these always the right ops to use? e.g. is there a torch.ops.aten.slice.default or a torch.ops.aten.slice_scatter.Tensor?

Collaborator Author

I haven't seen them, so I'm not sure - I just went off what I saw. The other overloads could be added easily if we ever see them in the graph.
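
For what it's worth, a quick way to check which overloads actually exist in the installed PyTorch (the printed names below are what I'd expect, but the output is version-dependent):

import torch

# Each torch.ops.aten.<name> is an OpOverloadPacket; overloads() lists the
# overload names registered in the installed PyTorch build.
print(torch.ops.aten.slice.overloads())          # e.g. ['Tensor']
print(torch.ops.aten.slice_scatter.overloads())  # e.g. ['default']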

@mgoin mgoin added the ready label (ONLY add when PR is ready to merge / full CI is needed) Feb 26, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from 12e173e to 427bb9d on February 27, 2025 at 22:38
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
config = get_current_vllm_config().compilation_config
Member

Is this cached? It could be expensive on every forward call.

Collaborator Author

Yes, in eager mode this will get called on every forward pass, but it will only happen once when compiled. In eager mode there isn't really a better way that's still correct - the only way is to check the config context. I don't think this getter is significant but I haven't measured it.
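
For reference, a rough way to measure the per-call cost (a sketch only; it assumes vllm is installed and that VllmConfig, get_current_vllm_config, and set_current_vllm_config are importable from vllm.config, as in the snippet above):

import timeit

from vllm.config import (VllmConfig, get_current_vllm_config,
                         set_current_vllm_config)

# Time the lookup inside a config context, mimicking what happens during a
# forward pass in eager mode.
with set_current_vllm_config(VllmConfig()):
    n = 100_000
    per_call = timeit.timeit(
        lambda: get_current_vllm_config().compilation_config, number=n) / n

print(f"~{per_call * 1e6:.2f} us per lookup")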

Member

@tlrmchlsmth tlrmchlsmth Feb 28, 2025

We could pass in an allow_input_padding flag? I do think this is annoying though. I think it's worth it to do a quick check for performance regressions on a small-model eager-mode benchmark with cutlass_scaled_mm disabled.

Collaborator Author

I think we'd have to pass that flag through the whole call stack though, so I don't think it's worth it. I'll run a small model.

Member

@tlrmchlsmth tlrmchlsmth left a comment

Looks good overall, but I had a few minor comments:

- rename cutlass_fp8 test flag
- rename noop pass
- improve some comments

Signed-off-by: luka <[email protected]>
Member

@tlrmchlsmth tlrmchlsmth left a comment

Thanks for the great work! LGTM assuming we don't see any performance regression

@ProExpertProg
Collaborator Author

Yep, will post perf numbers once I have them, thanks!

@ProExpertProg ProExpertProg changed the title [torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case [torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass Feb 28, 2025
@mgoin mgoin merged commit bd56c98 into vllm-project:main Feb 28, 2025
40 checks passed
Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>
tlrmchlsmth added a commit that referenced this pull request Mar 5, 2025
…-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)"

This reverts commit bd56c98.
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
@BoyuanFeng
Contributor

BoyuanFeng commented Apr 13, 2025

nit: noop elimination for slice is incorrect when end = -1.

repro:

import torch

def dims_equivalent(dim, i_dim) -> bool:
    # Case 1 and 2
    if dim == i_dim or dim == -1:
        return True
    # Case 3
    return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

input = torch.randn((2,3,2))
dim_index, start, end = 0, 0, -1

input_shape = input.shape

i_dim = input_shape[dim_index]

# following https://github.com/vllm-project/vllm/blob/main/vllm/compilation/noop_elimination.py#L79
if start == 0 and dims_equivalent(end, i_dim):
    is_noop = True
else:
    is_noop = False

if is_noop:
    # input.shape: (2,3,2)
    # torch.ops.aten.slice.Tensor(input, dim_index, start, end).shape: (1,3,2)
    assert input.shape == torch.ops.aten.slice.Tensor(input, dim_index, start, end).shape

similarly for slice_scatter.

@ProExpertProg
Collaborator Author

> nit: noop elimination for slice is incorrect when end = -1.
> […]

Great find - I didn't realize slice handles end differently. Could you file an issue for this please?
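
For reference, a sketch of a stricter end check (just one possible fix direction, not necessarily what the actual fix will look like; the helper name slice_end_covers_dim is made up here):

import torch

def slice_end_covers_dim(end, dim_size: int) -> bool:
    # None and very large values both mean "slice to the end" for aten.slice,
    # while a negative end is relative to dim_size, so end = -1 drops the
    # last element and is never a full-dimension slice.
    if end is None:
        return True
    return end >= dim_size

x = torch.randn(2, 3, 2)
assert slice_end_covers_dim(x.shape[0], x.shape[0])   # end == size: no-op
assert not slice_end_covers_dim(-1, x.shape[0])       # end == -1: not a no-op
assert torch.ops.aten.slice.Tensor(x, 0, 0, -1).shape == (1, 3, 2)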

@BoyuanFeng
Contributor

issue: #17078

Btw, PyTorch supports noop elimination for view, slice, and slice_scatter now, which should be equivalent to noop_elimination.py.

@ProExpertProg
Collaborator Author

Okay, sounds good - we can probably deprecate the pass, although it's nice to have an easy place to add noop transforms to unblock pattern matching in the short term, before the fixes are upstreamed to torch.

shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>