Support mixed MX element dtype in mx_mm function and MXLinear #1667
Conversation
Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. This PR simply adds a more general interface to mx_mm; a similar choice could be made for MXLinear.
General issue: #1666
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1667
Note: links to docs will display an error until the docs builds have completed. ❌ 1 new failure as of commit cbf6f0a with merge base 8afd10e (the following job has failed).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
    # 1. input @ weight_t = output (forward pass)
    # 2. grad_output @ weight = grad_input (backward pass)
    # 3. input_t @ grad_output = grad_weight (backward pass)
    #
    # input, weight and grad_output have each their own MX element dtype.
nit: "can have"?
Done
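For context, here is a minimal, self-contained sketch of what a mixed-element-dtype mx_mm-style autograd function can look like. This is not the torchao implementation: the to_mx helper below is a stand-in for an MX quantize/dequantize round trip, and the class name, signature and shapes are illustrative assumptions only.

import torch


def to_mx(x, elem_dtype, block_size):
    # Stand-in for an MX quantize/dequantize round trip: a real implementation
    # would quantize x to elem_dtype in blocks of block_size that share a
    # scale, then dequantize back for the simulated matmul. Returning x as-is
    # keeps this sketch runnable end to end.
    return x


class MixedDtypeMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_hp, weight_hp, in_dtype, w_dtype, go_dtype, block_size):
        ctx.save_for_backward(input_hp, weight_hp)
        ctx.dtypes = (in_dtype, w_dtype, go_dtype)
        ctx.block_size = block_size
        # forward: output = input @ weight.t(), each operand cast to its own
        # MX element dtype before the matmul
        input_mx = to_mx(input_hp, in_dtype, block_size)
        weight_mx = to_mx(weight_hp, w_dtype, block_size)
        return torch.mm(input_mx, weight_mx.t())

    @staticmethod
    def backward(ctx, grad_output_hp):
        input_hp, weight_hp = ctx.saved_tensors
        in_dtype, w_dtype, go_dtype = ctx.dtypes
        # grad_output gets its own MX element dtype, independent of input/weight
        go_mx = to_mx(grad_output_hp, go_dtype, ctx.block_size)
        weight_mx = to_mx(weight_hp, w_dtype, ctx.block_size)
        input_mx = to_mx(input_hp, in_dtype, ctx.block_size)
        # grad_input = grad_output @ weight, grad_weight = grad_output.t() @ input
        grad_input = torch.mm(go_mx, weight_mx)
        grad_weight = torch.mm(go_mx.t(), input_mx)
        return grad_input, grad_weight, None, None, None, None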
this makes sense, it would be great to cover with a test. The easiest place to test it would be here (MXLinear). Would you be interested in doing that in this PR?
by the way, pytorch/pytorch#146414 outlines bringing MX dtypes to PyTorch core, and we plan to evolve …
…er factory method. Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`. Some additional unit test coverage has been added on MXLinear.
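To illustrate the tuple-based approach this commit describes (which is later revised to explicit override arguments in the review discussion below), here is a rough sketch of the normalization step; the helper name is made up:

def _normalize_elem_dtypes(elem_dtype):
    # Accept either a single MX element dtype (applied to input, weight and
    # grad_output alike, preserving the existing interface) or a tuple/list of
    # three dtypes in (input, weight, grad_output) order.
    if isinstance(elem_dtype, (tuple, list)):
        assert len(elem_dtype) == 3, "expected (input, weight, grad_output) dtypes"
        return tuple(elem_dtype)
    return (elem_dtype,) * 3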
I added support for this feature in MXLinear, and I expanded the coverage in the test you mentioned (plus a small test on the factory side to check that the 2 cases above are working properly). Thanks for the link on the PyTorch MX plan 👍 I would assume that the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing + getting ready for MX hardware until it is widely available.
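As a rough illustration of what such coverage can look like, here is a hypothetical, compressed test sketch; it is not the PR's test, and the import path, dtype subset and tensor sizes are assumptions:

import itertools

import pytest
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import MXLinear  # assumed import path

ELEM_DTYPES = [torch.float8_e4m3fn, torch.float8_e5m2]  # illustrative subset


@pytest.mark.parametrize(
    "in_dtype, w_dtype, go_dtype",
    list(itertools.product(ELEM_DTYPES, repeat=3)),
)
def test_mx_linear_mixed_elem_dtypes(in_dtype, w_dtype, go_dtype):
    m = nn.Linear(64, 64, bias=False)
    # from_float converts the module in place (mod.__class__ = MXLinear)
    MXLinear.from_float(
        m,
        in_dtype,
        elem_dtype_weight_override=w_dtype,
        elem_dtype_grad_output_override=go_dtype,
        block_size=32,
    )
    x = torch.randn(8, 64, requires_grad=True)
    y = m(x)
    y.sum().backward()
    assert y.shape == (8, 64)
    assert x.grad is not None and x.grad.shape == x.shape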
""" | ||
|
||
@classmethod | ||
@torch.no_grad() | ||
def from_float(cls, mod, elem_dtype, block_size): | ||
mod.__class__ = MXLinear | ||
mod.elem_dtype = elem_dtype | ||
# Single element dtype passed for input, weight and gradient. |
nit: can we do
def from_float(
    ...,
    elem_dtype,
    ...,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    ...
): ...
we plan to create a proper config object for this in the future, but for now it would be good to keep things simple and avoid mixing types in the API (such as dtype vs tuple)
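Purely as an illustration of the config-object direction mentioned here (nothing like this exists in the PR; the class and field names are made up):

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MXLinearConfig:
    # One element dtype per role; the overrides default to the base elem_dtype.
    elem_dtype: Any
    elem_dtype_weight_override: Optional[Any] = None
    elem_dtype_grad_output_override: Optional[Any] = None
    block_size: int = 32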
Should I then enforce keyword-only arguments in MXLinear.from_float and swap_linear_with_mx_linear for block_size and filter_fn? And have a default block_size=32?
sounds reasonable!
Just pushed a fix commit with:
def from_float(
    cls,
    mod,
    elem_dtype,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    *,
    block_size=32,
):
and similarly for swap_linear_with_mx_linear
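For reference, a usage sketch against the signature above; the swap_linear_with_mx_linear parameter names are assumed to mirror from_float, and the import path is an assumption:

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import (  # assumed import path
    MXLinear,
    swap_linear_with_mx_linear,
)

# Single element dtype for input, weight and grad_output (existing behavior):
lin = nn.Linear(128, 128, bias=False)
MXLinear.from_float(lin, torch.float8_e4m3fn, block_size=32)

# Mixed element dtypes: e5m2 for grad_output, e4m3 for input and weight:
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
swap_linear_with_mx_linear(
    model,
    torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e5m2,
    block_size=32,
)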
yep! great to hear this is useful.
looks good, thank you! Please feel free to merge if CI is green.
note: we will likely change the UX of this workflow in the near future (add a top-level config, etc.) as we add Blackwell support; we'll make sure to keep these options in the new UX!
@vkuzo Unfortunately, the H100 Float8 test runner seems to have had an issue starting.
failure is transient, I merged it. Thank you!