Support mixed MX element dtype in mx_mm function and MXLinear. #1667

Conversation

@balancap balancap (Contributor) commented Feb 5, 2025

Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. This PR simply adds a more general interface to mx_mm. A similar choice could be made for MXLinear.

General issue: #1666
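
For readers skimming the thread, "MX element dtype" refers to the narrow per-element format in an MX (microscaling) block: each block of (typically) 32 values shares one power-of-two scale, and the elements themselves are stored in a format such as FP8/FP6/FP4. The sketch below is only an illustration of that idea in simulated form, not code from this PR: the helper name mx_fake_quant is made up, the scaling rule is simplified relative to the MX spec, and torch.float8_e4m3fn / torch.float8_e5m2 merely stand in for MX element dtypes.

import torch

def mx_fake_quant(x: torch.Tensor, elem_dtype: torch.dtype, block_size: int = 32) -> torch.Tensor:
    # Simulated ("fake") MX quantization: quantize, then immediately dequantize.
    orig_shape = x.shape
    x = x.reshape(-1, block_size)
    # One shared power-of-two scale per block, derived from the block max
    # (simplified compared to the actual MX scaling rule).
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=2.0**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)))
    x_q = (x / scale).to(elem_dtype).to(x.dtype) * scale
    return x_q.reshape(orig_shape)

# Different element dtypes for different tensor roles, e.g. a wider-range
# format for gradients, as is common in low-precision training recipes.
t = torch.randn(4, 32)
act_sim = mx_fake_quant(t, torch.float8_e4m3fn)
grad_sim = mx_fake_quant(t, torch.float8_e5m2)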

Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients.

pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1667

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit cbf6f0a with merge base 8afd10e:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 5, 2025
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# input, weight and grad_output have each their own MX element dtype.
Contributor

nit: "can have"?

Contributor Author

Done
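
To make the three-matmul structure above concrete, here is a rough sketch of an autograd.Function that lets each operand use its own simulated element dtype. This is not the mx_mm code from this PR: the names mixed_dtype_mm and fake_quant are made up, and fake_quant is just a float8 round-trip standing in for MX quantization.

import torch

def fake_quant(x: torch.Tensor, elem_dtype: torch.dtype) -> torch.Tensor:
    # Stand-in for MX quantization: cast to the narrow dtype and back.
    return x.to(elem_dtype).to(x.dtype)

class mixed_dtype_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, in_dtype, w_dtype, go_dtype):
        ctx.save_for_backward(input, weight)
        ctx.dtypes = (in_dtype, w_dtype, go_dtype)
        # 1. input @ weight_t = output (forward pass)
        return fake_quant(input, in_dtype) @ fake_quant(weight, w_dtype).t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        in_dtype, w_dtype, go_dtype = ctx.dtypes
        go_q = fake_quant(grad_output, go_dtype)
        # 2. grad_output @ weight = grad_input (backward pass)
        grad_input = go_q @ fake_quant(weight, w_dtype)
        # 3. input_t @ grad_output = grad_weight (backward pass)
        grad_weight = go_q.t() @ fake_quant(input, in_dtype)
        return grad_input, grad_weight, None, None, None

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(4, 16, requires_grad=True)
y = mixed_dtype_mm.apply(x, w, torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e5m2)
y.sum().backward()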

@vkuzo vkuzo (Contributor) commented Feb 5, 2025

this makes sense, it would be great to cover with a test

the easiest place to test it would be here (`def test_linear_eager(elem_dtype, bias, input_shape)`), and that requires adding this to MXLinear. Would you be interested in doing that in this PR?

by the way, pytorch/pytorch#146414 outlines bringing MX dtypes to PyTorch core, and we plan to evolve torchao/prototype/mx_formats/ accordingly
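
A rough sketch of what the extra parametrization might look like; the dtype values and the test body below are placeholders rather than the actual torchao test, which builds an MXLinear via from_float and compares against a float32 reference.

import pytest
import torch

# Placeholder cases: two single element dtypes plus one mixed
# (input, weight, grad_output) combination.
ELEM_DTYPE_CASES = [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e5m2),
]

@pytest.mark.parametrize("elem_dtype", ELEM_DTYPE_CASES)
@pytest.mark.parametrize("bias", [False, True])
def test_linear_eager(elem_dtype, bias):
    # Normalize: a single dtype applies to input, weight and grad_output alike.
    if not isinstance(elem_dtype, tuple):
        elem_dtype = (elem_dtype,) * 3
    assert len(elem_dtype) == 3
    # The real test would construct the quantized linear layer with these
    # dtypes here and check forward/backward numerics against the reference.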

@vkuzo vkuzo added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Feb 5, 2025
…er factory method.

Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface
of `MXLinear` and `swap_linear_with_mx_linear`.

Some additional unit test coverage has been added on MXLinear.
@balancap balancap changed the title Support mixed MX element dtype in mx_mm function. Support mixed MX element dtype in mx_mm function and MXLinear. Feb 5, 2025
@balancap balancap (Contributor Author) commented Feb 5, 2025

I added support for this feature in MXLinear too. To avoid breaking the interface (and to keep things simple in the single-dtype case), you can now pass either a single element dtype or a tuple of three.

I expanded the coverage in the test you mentioned (plus a small test on the factory side to check that the two cases above work properly).

Thanks for the link to the PyTorch MX plan 👍 I would assume the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing and for getting ready for MX hardware until it is widely available.
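
As a sketch of the "single dtype or tuple of three" handling described above (illustrative only, and note this interim API is replaced by explicit override arguments later in the thread):

def _normalize_elem_dtypes(elem_dtype):
    # Accept either one MX element dtype for everything, or an
    # (input, weight, grad_output) tuple of three.
    if isinstance(elem_dtype, (tuple, list)):
        assert len(elem_dtype) == 3, "expected (input, weight, grad_output) element dtypes"
        return tuple(elem_dtype)
    return (elem_dtype,) * 3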

"""

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size):
mod.__class__ = MXLinear
mod.elem_dtype = elem_dtype
# Single element dtype passed for input, weight and gradient.
@vkuzo vkuzo (Contributor) Feb 5, 2025

nit: can we do

def from_float(
    ...,
    elem_dtype,
    ...,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    ...
): ...

we plan to create a proper config object for this in the future, but for now it would be good to keep things simple and avoid mixing types in the API (such as dtype vs tuple)

Contributor Author

Should I then enforce named arguments in MXLinear.from_float and swap_linear_with_mx_linear for block_size and filter_fn? And have a default block_size=32?

Contributor

sounds reasonable!

Contributor Author

Just pushed a fix commit with:

def from_float(
        cls,
        mod,
        elem_dtype,
        elem_dtype_weight_override=None,
        elem_dtype_grad_output_override=None,
        *,
        block_size=32,
    ):

and similarly for swap_linear_with_mx_linear
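
For illustration, a call to the updated interface might look like the following. The import path and the dtype constants are my assumptions, not taken from this PR; torchao's mx_formats prototype defines its own set of supported element dtypes.

import torch
import torch.nn as nn
# Assumed import path under torchao/prototype/mx_formats/.
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Single element dtype for input, weight and grad_output, default block_size=32.
model_a = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 128))
swap_linear_with_mx_linear(model_a, torch.float8_e4m3fn)

# Override the grad_output element dtype, with the keyword-only block size.
model_b = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 128))
swap_linear_with_mx_linear(
    model_b,
    torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e5m2,
    block_size=32,
)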

@vkuzo vkuzo (Contributor) commented Feb 5, 2025

> I would assume that the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing + getting ready for MX hardware until it is widely available.

yep! great to hear this is useful.

@vkuzo vkuzo (Contributor) left a comment

looks good, thank you! Please feel free to merge if CI is green.

note: we will likely change the UX of this workflow in the near future (add a top-level config, etc.) as we add Blackwell support; we'll make sure to keep these options in the new UX!

@balancap balancap (Contributor Author) commented Feb 6, 2025

@vkuzo Unfortunately, the H100 Float8 test runner seems to have had an issue starting.

@vkuzo vkuzo merged commit 1d75c8f into pytorch:main Feb 6, 2025
16 of 17 checks passed
@vkuzo vkuzo (Contributor) commented Feb 6, 2025

failure is transient, I merged it. Thank you!
