
[sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity #2242


Merged
merged 8 commits into main on May 22, 2025

Conversation

jcaip
Contributor

@jcaip commented May 22, 2025

Summary:

We already have this gemm in torchao, but for weight sparsity: it assumes the weights are sparse and stored in row-major format.

For activation sparsity, we instead need the weights stored in column-major format, so that we can use the selective weight-loading kernel for decode.
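
A minimal sketch of the layout difference (the alias names below are illustrative, not the PR's actual code):

#include <cutlass/layout/matrix.h>

// Illustrative sketch. The existing weight-sparsity gemm takes the weight
// operand row-major; the activation-sparsity variant in this PR stores the
// dense fp8 weight column-major, so the selective weight-loading kernel can
// read just the weight columns it needs contiguously during decode.
using WeightLayoutForWeightSparsity     = cutlass::layout::RowMajor;
using WeightLayoutForActivationSparsity = cutlass::layout::ColumnMajor;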

Test Plan:

pytest test/sparsity/test_activation24.py


jcaip added 2 commits May 21, 2025 16:46
…sity


pytorch-bot bot commented May 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2242

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 4 Pending

As of commit e17ebfd with merge base f0f976c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 22, 2025
@jcaip added the topic: improvement label May 22, 2025
Contributor

@danielvegamyhre left a comment

lgtm, left a couple minor comments

using ElementOut = cutlass::bfloat16_t;
using ElementAccumulator = float;

using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
Contributor

how was this tile shape selected?

Contributor Author

This is the default I copied over; I'm planning on adding some tuning in a subsequent PR for more perf.
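
For reference, TileShape is the per-thread-block (M, N, K) tile: each CTA computes a 128x256 output tile, stepping through K in chunks of 128. A sketch of the kind of alternatives a later tuning pass might sweep (the smaller candidate here is illustrative, not from the PR):

#include <cute/tensor.hpp>

// The default quoted above, plus a hypothetical smaller candidate that a
// tuning sweep might compare for skinny decode-time GEMMs.
using TileShape_Default = cute::Shape<cute::_128, cute::_256, cute::_128>;
using TileShape_Small   = cute::Shape<cute::_64, cute::_128, cute::_128>;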

cutlass::arch::OpClassSparseTensorOp,
ElementA,
cutlass::layout::RowMajor,
32,
Contributor

nit: would help with readability to give these constant args variable names IMO

Contributor Author

yeah good point, will address these nits when I add in the tile config tuning; just want to get unblocked for now.
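
As an illustration of the nit above (not the PR's actual code), naming the positional arguments quoted in the snippet might look like:

#include <cutlass/arch/mma.h>
#include <cutlass/layout/matrix.h>

// Illustrative aliases for the magic arguments quoted above: the sparse
// tensor-op class, the row-major layout of operand A, and its alignment
// in elements (32 fp8 elements here).
using MmaOpClass = cutlass::arch::OpClassSparseTensorOp;
using LayoutA    = cutlass::layout::RowMajor;
static constexpr int kAlignmentA = 32;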

device_guard.emplace(tensor_a.device());
}

using K = SparseRowwiseKernel<cutlass::float_e4m3_t>;
Contributor

nit: more descriptive variable name would be helpful
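
For instance (illustrative only), the alias could spell out what it instantiates rather than shadowing the GEMM dimension name K:

// Hypothetical rename per the nit; SparseRowwiseKernel is the PR's template.
using Fp8SparseRowwiseKernel = SparseRowwiseKernel<cutlass::float_e4m3_t>;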

@jcaip merged commit 4c6188f into main May 22, 2025
35 checks passed