[sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity #2242
Conversation
Summary: We already have this gemm in torchao, but for weight sparsity, which assumes the weights are stored in row-major format and are sparse. For activation sparsity, we need the weights to be stored in column-major format so we can use the selective weight-loading kernel for decode.
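To make the layout change concrete, here is a minimal sketch of the operand layout switch, assuming B is the weight operand (the alias names are illustrative, not the PR's actual ones):

#include <cutlass/layout/matrix.h>

// Existing weight-sparsity kernel: weights are the sparse operand, row-major.
using LayoutB_WeightSparse = cutlass::layout::RowMajor;

// This PR (activation sparsity): activations are sparse, so the dense weights
// are stored column-major for the decode-time selective weight-loading path.
using LayoutB_ActivationSparse = cutlass::layout::ColumnMajor;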
Dr. CI: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2242. No failures, 4 pending as of commit e17ebfd (merge base f0f976c).
lgtm, left a couple minor comments
using ElementOut = cutlass::bfloat16_t;
using ElementAccumulator = float;

using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
how was this tile shape selected?
This is the default I copied over; I'm planning on adding some tuning in a subsequent PR for more perf.
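For context, a rough sketch of what that tuning could look like: instantiate the kernel over a few candidate CTA tile shapes and dispatch on problem size (the candidate shapes below are illustrative, not benchmarked):

#include <cute/tensor.hpp>

// The CTA tile is a compile-time cute::Shape<M, N, K>, so each candidate
// shape requires its own kernel instantiation.
using TileShapeDefault = cute::Shape<cute::_128, cute::_256, cute::_128>;  // current default
using TileShapeSmallM  = cute::Shape<cute::_64, cute::_128, cute::_128>;   // hypothetical candidate for small-M decode shapes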
cutlass::arch::OpClassSparseTensorOp,
ElementA,
cutlass::layout::RowMajor,
32,
nit: it would help with readability to give these constant args variable names IMO
yeah good point, will address these nits when I add in the tile config tuning; just want to get unblocked for now.
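As a sketch of what addressing the nit could look like (the names are hypothetical, and reading 32 as the per-operand alignment in elements is an assumption to check against the builder signature):

// Hypothetical named constants for the builder args shown above.
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 32;  // assumed: elements per vectorized access for operand A

// ... cutlass::arch::OpClassSparseTensorOp,
//     ElementA, LayoutA, AlignmentA, ...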
  device_guard.emplace(tensor_a.device());
}

using K = SparseRowwiseKernel<cutlass::float_e4m3_t>;
nit: more descriptive variable name would be helpful
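For instance (the alias name here is hypothetical):

// Spell out the element type and kernel family instead of a single letter.
using SparseRowwiseKernelE4M3 = SparseRowwiseKernel<cutlass::float_e4m3_t>;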