From 6358a71dcfde6c6361b41b6c53d9054f8ef5f808 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 5 Feb 2025 14:04:38 +0000 Subject: [PATCH 1/3] Support mixed MX element dtype in `mx_mm` function. Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. --- torchao/prototype/mx_formats/mx_linear.py | 28 +++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index b69441e018..f22ae2e71d 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function): # 1. input @ weight_t = output (forward pass) # 2. grad_output @ weight = grad_input (backward pass) # 3. input_t @ grad_output = grad_weight (backward pass) + # + # input, weight and grad_output have each their own MX element dtype. @staticmethod def forward( ctx, input_hp: torch.Tensor, weight_hp: torch.Tensor, - elem_dtype: Any, + in_elem_dtype: Any, + w_elem_dtype: Any, + grad_elem_dtype: Any, block_size: int, ): ctx.save_for_backward(input_hp, weight_hp) - ctx.elem_dtype = elem_dtype + ctx.in_elem_dtype = in_elem_dtype + ctx.w_elem_dtype = w_elem_dtype + ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size # input @ weight_t = output input_orig_shape = input_hp.shape input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) - input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size) - weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size) + input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size) + weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) @@ -51,7 +57,9 @@ def forward( def backward(ctx, grad_output_hp: torch.Tensor): input_hp, weight_hp = ctx.saved_tensors weight_hp_t_c = weight_hp.t().contiguous() - elem_dtype = ctx.elem_dtype + in_elem_dtype = ctx.in_elem_dtype + w_elem_dtype = ctx.w_elem_dtype + grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size grad_output_orig_shape = grad_output_hp.shape @@ -61,8 +69,8 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1]) # grad_output @ weight = grad_input - grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size) - weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size) + grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, grad_elem_dtype, block_size) + weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -70,15 +78,15 @@ def backward(ctx, grad_output_hp: torch.Tensor): # input_t @ grad_output = grad_weight grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), elem_dtype, block_size + grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size ) input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), elem_dtype, block_size + input_hp_r.t().contiguous(), in_elem_dtype, block_size ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None + return grad_input, grad_weight, None, None, None, None class 
MXLinear(torch.nn.Linear): From 5c8eb6dee5e6ab994ddac22ee028d8c35d496df6 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 5 Feb 2025 16:29:39 +0000 Subject: [PATCH 2/3] Support (input, weight, gradient) element dtype tuple in MXLinear layer factory method. Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`. Some additional unit test coverage has been added on MXLinear. --- test/prototype/mx_formats/test_mx_linear.py | 22 +++++++++++++++-- torchao/prototype/mx_formats/mx_linear.py | 27 +++++++++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 35afeb7959..e1b000d54a 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import itertools import pytest import torch @@ -41,13 +42,16 @@ def run_around_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize( + "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3) +) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) def test_linear_eager(elem_dtype, bias, input_shape): """ Smoke test for training linear module with mx weight """ + # elem_dtype is a tuple of (input, weight, gradient) dtypes. grad_shape = list(input_shape) grad_shape[-1] = 6 @@ -72,7 +76,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad) x_g_sqnr = compute_error(x_ref.grad, x.grad) - if elem_dtype is torch.float8_e4m3fn: + if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn): assert y_sqnr >= 18.0 assert w_g_sqnr >= 18.0 assert x_g_sqnr >= 12.0 @@ -219,6 +223,20 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 13.5 +def test_mx_linear_input_weight_gradient_dtypes(): + m = nn.Sequential(nn.Linear(32, 32)) + swap_linear_with_mx_linear(m, tuple(SUPPORTED_ELEM_DTYPES[:3]), 32) + assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] + assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] + assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] + + m = nn.Sequential(nn.Linear(32, 32)) + swap_linear_with_mx_linear(m, torch.float8_e4m3fn, 32) + assert m[0].in_elem_dtype == torch.float8_e4m3fn + assert m[0].w_elem_dtype == torch.float8_e4m3fn + assert m[0].grad_elem_dtype == torch.float8_e4m3fn + + def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index f22ae2e71d..3706c30d15 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -23,8 +23,8 @@ class mx_mm(torch.autograd.Function): # 1. input @ weight_t = output (forward pass) # 2. grad_output @ weight = grad_input (backward pass) # 3. input_t @ grad_output = grad_weight (backward pass) - # - # input, weight and grad_output have each their own MX element dtype. + # + # input, weight and grad_output can have each their own MX element dtype. 
@staticmethod def forward( @@ -69,7 +69,9 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1]) # grad_output @ weight = grad_input - grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, grad_elem_dtype, block_size) + grad_output_mx_dim0 = MXTensor.to_mx( + grad_output_hp_r, grad_elem_dtype, block_size + ) weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( @@ -95,13 +97,20 @@ class MXLinear(torch.nn.Linear): matmul is emulated since there is no hardware support yet. Activations, weights and grads are casted to MX and back to high precision for each matmul. + + Input, weight and grad_output can have each their own MX element dtype + by passing a tuple of `elem_dtype` to the factory method `from_float`. """ @classmethod @torch.no_grad() def from_float(cls, mod, elem_dtype, block_size): mod.__class__ = MXLinear - mod.elem_dtype = elem_dtype + # Single element dtype passed for input, weight and gradient. + if not isinstance(elem_dtype, (tuple, list)): + elem_dtype = (elem_dtype, elem_dtype, elem_dtype) + # Unpack input, weight and gradient element dtypes. + mod.in_elem_dtype, mod.w_elem_dtype, mod.grad_elem_dtype = elem_dtype mod.block_size = block_size return mod @@ -114,7 +123,14 @@ def forward(self, x): else: w = self.weight - y = mx_mm.apply(x, w, self.elem_dtype, self.block_size) + y = mx_mm.apply( + x, + w, + self.in_elem_dtype, + self.w_elem_dtype, + self.grad_elem_dtype, + self.block_size, + ) if self.bias is not None: y = y + self.bias return y @@ -181,6 +197,7 @@ def _is_linear(mod, fqn): def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None): + # `elem_dtype` can be a single dtype or a tuple of 3 for (input, weight, gradient). if filter_fn is None: combined_filter_fn = _is_linear else: From cbf6f0a18783676e3405ad3df21fd9178a45bd0d Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 5 Feb 2025 20:36:40 +0000 Subject: [PATCH 3/3] Using default `elem_dtype` argument and optional weight/grad overrides. 
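A minimal sketch of the intended call pattern after this change (dtype choices are illustrative, and `torch.float8_e5m2` is assumed to be among the supported element dtypes):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Default: one element dtype shared by input, weight and grad_output.
m = nn.Sequential(nn.Linear(32, 32)).cuda()
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)

# Optional overrides: keep e4m3 for input/weight, cast grad_output to e5m2.
m2 = nn.Sequential(nn.Linear(32, 32)).cuda()
swap_linear_with_mx_linear(
    m2,
    torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e5m2,
    block_size=32,
)
```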
--- test/prototype/mx_formats/test_mx_linear.py | 14 ++++---- torchao/prototype/mx_formats/README.md | 9 +++-- torchao/prototype/mx_formats/mx_linear.py | 40 +++++++++++++++------ 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index e1b000d54a..17a76a750d 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -60,7 +60,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): ) m_mx = copy.deepcopy(m) block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) @@ -98,7 +98,7 @@ def test_activation_checkpointing(): nn.Linear(6, 6, bias=True, device="cuda"), ) block_size = 2 - swap_linear_with_mx_linear(m, elem_dtype, block_size) + swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -134,7 +134,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast): nn.Linear(K, N, bias=bias, device="cuda"), ) block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -225,13 +225,13 @@ def test_inference_compile_simple(elem_dtype): def test_mx_linear_input_weight_gradient_dtypes(): m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, tuple(SUPPORTED_ELEM_DTYPES[:3]), 32) + swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32) assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, torch.float8_e4m3fn, 32) + swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32) assert m[0].in_elem_dtype == torch.float8_e4m3fn assert m[0].w_elem_dtype == torch.float8_e4m3fn assert m[0].grad_elem_dtype == torch.float8_e4m3fn @@ -245,7 +245,9 @@ def test_filter_fn(): m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 - swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn) + swap_linear_with_mx_linear( + m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn + ) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index b750c26af2..32f45e3755 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -2,8 +2,8 @@ This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch. -Note that the current version of the code is written for readability and -numerical correctness and not yet for optimal performance. We welcome +Note that the current version of the code is written for readability and +numerical correctness and not yet for optimal performance. We welcome contributions on performance improvements. 
Note that there are no BC guarantees at the moment and we plan to evolve @@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() elem_dtype = torch.float8_e4m3fn -block_size = 32 -swap_linear_with_mx_linear(m, elem_dtype, block_size) +swap_linear_with_mx_linear(m, elem_dtype, block_size=32) # training loop (not shown) ``` @@ -93,7 +92,7 @@ python torchao/prototype/mx_formats/benchmarks/bench_qdq.py ## floating point format convenience functions -We have a convenience script which summarizes the various properties of +We have a convenience script which summarizes the various properties of floating point formats: ```bash diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 3706c30d15..d7aa744334 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -98,19 +98,24 @@ class MXLinear(torch.nn.Linear): weights and grads are casted to MX and back to high precision for each matmul. - Input, weight and grad_output can have each their own MX element dtype - by passing a tuple of `elem_dtype` to the factory method `from_float`. + Input, weight and grad_output can have each their own MX element dtype. """ @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size): + def from_float( + cls, + mod, + elem_dtype, + elem_dtype_weight_override=None, + elem_dtype_grad_output_override=None, + *, + block_size=32, + ): mod.__class__ = MXLinear - # Single element dtype passed for input, weight and gradient. - if not isinstance(elem_dtype, (tuple, list)): - elem_dtype = (elem_dtype, elem_dtype, elem_dtype) - # Unpack input, weight and gradient element dtypes. - mod.in_elem_dtype, mod.w_elem_dtype, mod.grad_elem_dtype = elem_dtype + mod.in_elem_dtype = elem_dtype + mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype + mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype mod.block_size = block_size return mod @@ -196,8 +201,15 @@ def _is_linear(mod, fqn): return isinstance(mod, torch.nn.Linear) -def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None): - # `elem_dtype` can be a single dtype or a tuple of 3 for (input, weight, gradient). +def swap_linear_with_mx_linear( + model, + elem_dtype, + elem_dtype_weight_override=None, + elem_dtype_grad_output_override=None, + *, + block_size=32, + filter_fn=None, +): if filter_fn is None: combined_filter_fn = _is_linear else: @@ -208,7 +220,13 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXLinear.from_float(mod, elem_dtype, block_size), + lambda mod: MXLinear.from_float( + mod, + elem_dtype, + elem_dtype_weight_override, + elem_dtype_grad_output_override, + block_size=block_size, + ), combined_filter_fn, )
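Taken together, the three patches let each `MXLinear` carry separate MX element dtypes for the input, weight and grad_output casts while keeping the single-dtype call unchanged. A minimal end-to-end sketch of the `MXLinear.from_float` interface from patch 3 (a CUDA device and `torch.float8_e5m2` support are assumed; shapes are chosen so every dimension is divisible by the block size):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import MXLinear

# Convert a single nn.Linear in place; the weight stays in high precision and
# is cast to MX on the fly inside mx_mm.
lin = nn.Linear(32, 32, bias=False, device="cuda")
lin = MXLinear.from_float(
    lin,
    torch.float8_e4m3fn,                                # default element dtype
    elem_dtype_grad_output_override=torch.float8_e5m2,  # override grad_output cast
    block_size=32,
)
assert lin.in_elem_dtype == torch.float8_e4m3fn
assert lin.w_elem_dtype == torch.float8_e4m3fn          # falls back to default
assert lin.grad_elem_dtype == torch.float8_e5m2

x = torch.randn(32, 32, device="cuda", requires_grad=True)
y = lin(x)          # forward: input and weight cast with their own element dtypes
y.sum().backward()  # backward: grad_output cast with grad_elem_dtype
print(x.grad.shape, lin.weight.grad.shape)
```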