diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 35afeb7959..17a76a750d 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import itertools
 
 import pytest
 import torch
@@ -41,13 +42,16 @@ def run_around_tests():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+@pytest.mark.parametrize(
+    "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
+)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
 def test_linear_eager(elem_dtype, bias, input_shape):
     """
     Smoke test for training linear module with mx weight
     """
+    # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
     grad_shape[-1] = 6
 
@@ -56,7 +60,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     )
     m_mx = copy.deepcopy(m)
     block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
@@ -72,7 +76,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad)
     x_g_sqnr = compute_error(x_ref.grad, x.grad)
 
-    if elem_dtype is torch.float8_e4m3fn:
+    if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn):
         assert y_sqnr >= 18.0
         assert w_g_sqnr >= 18.0
         assert x_g_sqnr >= 12.0
@@ -94,7 +98,7 @@ def test_activation_checkpointing():
         nn.Linear(6, 6, bias=True, device="cuda"),
     )
     block_size = 2
-    swap_linear_with_mx_linear(m, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -130,7 +134,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
     block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -219,6 +223,20 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
+def test_mx_linear_input_weight_gradient_dtypes():
+    m = nn.Sequential(nn.Linear(32, 32))
+    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
+    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
+    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
+    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
+
+    m = nn.Sequential(nn.Linear(32, 32))
+    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
+    assert m[0].in_elem_dtype == torch.float8_e4m3fn
+    assert m[0].w_elem_dtype == torch.float8_e4m3fn
+    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
+
+
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -227,7 +245,9 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn)
+    swap_linear_with_mx_linear(
+        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
+    )
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index b750c26af2..32f45e3755 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -2,8 +2,8 @@
 
 This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.
 
-Note that the current version of the code is written for readability and 
-numerical correctness and not yet for optimal performance. We welcome 
+Note that the current version of the code is written for readability and
+numerical correctness and not yet for optimal performance. We welcome
 contributions on performance improvements.
 
 Note that there are no BC guarantees at the moment and we plan to evolve
@@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 elem_dtype = torch.float8_e4m3fn
-block_size = 32
-swap_linear_with_mx_linear(m, elem_dtype, block_size)
+swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
 
 # training loop (not shown)
 ```
@@ -93,7 +92,7 @@ python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
 
 ## floating point format convenience functions
 
-We have a convenience script which summarizes the various properties of 
+We have a convenience script which summarizes the various properties of
 floating point formats:
 
 ```bash
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index b69441e018..d7aa744334 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
     # 1. input @ weight_t = output (forward pass)
     # 2. grad_output @ weight = grad_input (backward pass)
     # 3. input_t @ grad_output = grad_weight (backward pass)
+    #
+    # input, weight and grad_output can each have their own MX element dtype.
 
     @staticmethod
     def forward(
         ctx,
         input_hp: torch.Tensor,
         weight_hp: torch.Tensor,
-        elem_dtype: Any,
+        in_elem_dtype: Any,
+        w_elem_dtype: Any,
+        grad_elem_dtype: Any,
         block_size: int,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
-        ctx.elem_dtype = elem_dtype
+        ctx.in_elem_dtype = in_elem_dtype
+        ctx.w_elem_dtype = w_elem_dtype
+        ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
 
-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
 
@@ -51,7 +57,9 @@ def forward(
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
         weight_hp_t_c = weight_hp.t().contiguous()
-        elem_dtype = ctx.elem_dtype
+        in_elem_dtype = ctx.in_elem_dtype
+        w_elem_dtype = ctx.w_elem_dtype
+        grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
 
         grad_output_orig_shape = grad_output_hp.shape
@@ -61,8 +69,10 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
 
         # grad_output @ weight = grad_input
-        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_output_mx_dim0 = MXTensor.to_mx(
+            grad_output_hp_r, grad_elem_dtype, block_size
+        )
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
@@ -70,15 +80,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
 
         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), elem_dtype, block_size
+            input_hp_r.t().contiguous(), in_elem_dtype, block_size
         )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
 
-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight, None, None, None, None
 
 
 class MXLinear(torch.nn.Linear):
@@ -87,13 +97,25 @@ class mx_mm(torch.autograd.Function):
     matmul is emulated since there is no hardware support yet. Activations,
     weights and grads are casted to MX and back to high precision for each
     matmul.
+
+    Input, weight and grad_output can each have their own MX element dtype.
     """
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size):
+    def from_float(
+        cls,
+        mod,
+        elem_dtype,
+        elem_dtype_weight_override=None,
+        elem_dtype_grad_output_override=None,
+        *,
+        block_size=32,
+    ):
         mod.__class__ = MXLinear
-        mod.elem_dtype = elem_dtype
+        mod.in_elem_dtype = elem_dtype
+        mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
+        mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
         mod.block_size = block_size
         return mod
 
@@ -106,7 +128,14 @@ def forward(self, x):
         else:
             w = self.weight
 
-        y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+        y = mx_mm.apply(
+            x,
+            w,
+            self.in_elem_dtype,
+            self.w_elem_dtype,
+            self.grad_elem_dtype,
+            self.block_size,
+        )
         if self.bias is not None:
             y = y + self.bias
         return y
@@ -172,7 +201,15 @@ def _is_linear(mod, fqn):
     return isinstance(mod, torch.nn.Linear)
 
 
-def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None):
+def swap_linear_with_mx_linear(
+    model,
+    elem_dtype,
+    elem_dtype_weight_override=None,
+    elem_dtype_grad_output_override=None,
+    *,
+    block_size=32,
+    filter_fn=None,
+):
     if filter_fn is None:
         combined_filter_fn = _is_linear
     else:
@@ -183,7 +220,13 @@ def __fn(mod, fqn):
 
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXLinear.from_float(mod, elem_dtype, block_size),
+        lambda mod: MXLinear.from_float(
+            mod,
+            elem_dtype,
+            elem_dtype_weight_override,
+            elem_dtype_grad_output_override,
+            block_size=block_size,
+        ),
         combined_filter_fn,
     )