metal lowbit kernels: optimized 2-bit, 3-bit and 4-bit shaders #1422
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1422
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 8fda452 with merge base 603d908. This comment was automatically generated by Dr. CI and updates every 15 minutes.
When adding an optimized version of something, we should have some sort of benchmarking numbers. Ideally I would like those to come from a standalone benchmark, but for now you can report what you got from torchchat.
I trust that the comments resolved as "will be fixed" will get fixed. The rest looks OK.
- Adapts the optimized 4-bit shader from PyTorch (MLX-inspired) and adds a similarly optimized 2-bit shader.
- Adds an optimized 3-bit shader.
- Restricts N to be a multiple of 4 and adjusts tests accordingly.
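For context on what these bit widths mean for the weight layout, here is a minimal Python sketch of sub-byte weight packing. This is illustrative only, not the actual Metal shader code or torchao's packing format: 4-bit values pack two per byte, while 3-bit values need groups of 8 packed into a 24-bit word, which is why 3-bit kernels are typically the trickiest to optimize.

```python
def pack_4bit(vals):
    """Pack pairs of 4-bit values (0..15) into single bytes, low nibble first."""
    assert len(vals) % 2 == 0
    return bytes((vals[i] & 0xF) | ((vals[i + 1] & 0xF) << 4)
                 for i in range(0, len(vals), 2))

def unpack_4bit(data):
    """Recover the 4-bit values from packed bytes."""
    out = []
    for b in data:
        out.append(b & 0xF)
        out.append(b >> 4)
    return out

def pack_3bit(vals):
    """Pack groups of eight 3-bit values (0..7) into three bytes via a 24-bit word."""
    assert len(vals) % 8 == 0
    out = bytearray()
    for i in range(0, len(vals), 8):
        word = 0
        for j, v in enumerate(vals[i:i + 8]):
            word |= (v & 0x7) << (3 * j)
        out += word.to_bytes(3, "little")
    return bytes(out)

def unpack_3bit(data):
    """Recover the 3-bit values from packed 24-bit words."""
    out = []
    for i in range(0, len(data), 3):
        word = int.from_bytes(data[i:i + 3], "little")
        out.extend((word >> (3 * j)) & 0x7 for j in range(8))
    return out
```

A fast kernel reads these packed words and unpacks several weights per memory transaction, which is where most of the speedup over a naive per-element load comes from.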
Performance (tokens/sec via torchchat):

| Bit width | Llama 3.2 1B (llama3.2-1b-base) | Llama 3.1 8B (llama3.1-base) |
|-----------|---------------------------------|------------------------------|
| 1-bit     | 28.0688                         | 7.4459                       |
| 2-bit     | 31.2422                         | 15.6508                      |
| 3-bit     | 30.1294                         | 15.3086                      |
| 4-bit     | 30.7905                         | 16.1268                      |
| 5-bit     | 28.1504                         | 6.7308                       |
| 6-bit     | 28.4321                         | 6.4887                       |
| 7-bit     | 27.3991                         | 6.4537                       |
Notice that performance is similar across all n-bit kernels for the 1B-parameter model; the optimization is felt when running the 8B-parameter model, where throughput jumps from 6-7 tok/sec with the non-optimized kernels to 15-16 tok/sec with the optimized 2-bit, 3-bit and 4-bit kernels.
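A plausible reading of these numbers (my back-of-envelope reasoning, not stated in the PR): token generation on an 8B-parameter model is usually memory-bandwidth bound, so throughput tracks how many bytes of weights must be streamed per token, and a kernel that fails to exploit the packed layout leaves that advantage on the table. A rough sketch of the weight footprint per bit width:

```python
PARAMS = 8e9  # Llama 3.1 8B, approximate parameter count

def weight_bytes(bits, params=PARAMS):
    """Total bytes of packed weights at a given bit width (ignores scales/zeros)."""
    return params * bits / 8

# At 2-4 bits the whole weight set is a few GB, a fraction of the fp16 footprint.
for bits in (2, 3, 4, 16):
    print(f"{bits:2d}-bit: {weight_bytes(bits) / 1e9:.1f} GB")
```

The 1B model is small enough at every bit width that the kernels are not bandwidth-limited in the same way, which would explain why all variants land around 28-31 tok/sec there.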