diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 3bc8ded793..2cd34f427d 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1033,6 +1033,7 @@ def _test_lin_weight_subclass_api_impl( @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 35b0107836..6cdd9b148f 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -118,6 +118,26 @@ def forward(self, x): x = self.linear2(x) return x + +def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): + """ + The deprecated implementation for int8 dynamic quant API, used as a reference for + numerics and performance + """ + from torchao.quantization.quant_api import _in_features_greater_than_16 + from torchao.quantization.quant_api import _is_linear + from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight + + if filter_fn is None: + filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( + *args + ) + + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) + class TestQuantFlow(unittest.TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() @@ -492,8 +512,8 @@ def test_quantized_tensor_subclass_int8(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_dyn_quant(self): - # use 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + # use multiples of 1024 so that we don't need padding + m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") @@ -525,6 +545,44 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # make sure it compiles torch._export.aot_compile(m_unwrapped, example_inputs) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation") + def test_quantized_tensor_subclass_int8_dyn_quant_perf(self): + m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + m_ref = copy.deepcopy(m) + # setting batch_size to 20 to be compatible with the kernel + example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + + from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors + change_linear_weights_to_int8_dqtensors(m) + + # reference + _ref_change_linear_weights_to_int8_dqtensors(m_ref) + + res = m(*example_inputs) + ref = m_ref(*example_inputs) + + self.assertTrue(torch.equal(res, ref)) + + # perf 
comparison + from torchao.utils import benchmark_model + # warmup + WARMUP = 5 + RUNS = 100 + input_tensor = example_inputs[0] + m = torch.compile(m, mode='max-autotune', fullgraph=True) + + benchmark_model(m, WARMUP, input_tensor) + elapsed_time = benchmark_model(m, RUNS, input_tensor) + + m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True) + benchmark_model(m_ref, WARMUP, input_tensor) + ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor) + + print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}") + self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time) + if __name__ == "__main__": diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index f4b758ddca..f660a759c2 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn): fn(self.zero_point), ) + def _change_shape(self, shape): + return self.__class__( + self.int_data.view(shape), self.scale, self.zero_point + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.view.default: + assert len(args) == 2 + new = args[0]._change_shape(args[1]) + return return_and_correct_aliasing(func, args, kwargs, new) + raise NotImplementedError( f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" ) @@ -245,6 +255,7 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] + # TODO: fix the unflatten logic return cls(packed_weight, scale_and_zero) def to(self, *args, **kwargs): @@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) + def _change_shape(self, shape, block_size): + return self.__class__( + self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride() + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): # Note: we only added cpu path here for 8da4w, this is for executorch, in the future @@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) -@implements_aqt_torch_function(torch.nn.functional.linear) -def functional_linear(*args, **kwargs): - input_tensor, weight_qtensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) +def _quantized_linear_op(input_tensor, weight_qtensor, bias): is_cuda = weight_qtensor.is_cuda is_cpu = weight_qtensor.device == torch.device("cpu") if isinstance(weight_qtensor, AffineQuantizedTensor): @@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs): # if input tensor is quantized, either dispatch to the int8 mm kernel # or just dequantize the input tensor input_is_int8 = _aqt_is_int8_reduced_range(input_tensor) - input_tensor_dtype_is_expected = input_tensor.dtype in [ - torch.float, - torch.bfloat16 - ] if ( is_cuda and input_is_int8 and - input_tensor_dtype_is_expected and + input_tensor.dtype == weight_qtensor.dtype and input_tensor.layout == "plain" and weight_qtensor.layout == "plain" ): @@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs): weight_qtensor.block_size[1] == weight_qtensor.shape[1] and weight_qtensor.layout == "plain" ): - # TODO: enable mps path 
as well + # TODO: enable cpu and mps efficient path # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale) - else: - weight_tensor = weight_qtensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - else: + w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous() + orig_dtype = input_tensor.dtype + y = ( + torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + * weight_qtensor.scale + ) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y.to(orig_dtype) + + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale) + + raise NotImplementedError("No specialized dispatch found for quantized linear op") + + +@implements_aqt_torch_function(torch.nn.functional.linear) +def functional_linear(*args, **kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - @implements_aqt_aten_ops([aten.mm.default, aten.addmm.default]) def aten_mm(func, *args, **kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"{func} is not implemented for non floating point input") + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[1], args[2], args[0], ) + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(bias, input_tensor, weight_tensor) else: - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[0], args[1], - None if len(args) == 2 else args[2], + None ) - weight_tensor = weight_qtensor.dequantize() - return func(input_tensor, weight_tensor, bias) + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if 
isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(input_tensor, weight_tensor) @implements_aqt_aten_ops([aten.detach.default]) def detach(func, *args, **kwargs): @@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs): @implements_aqt_aten_ops([aten.t.default]) def t(func, *args, **kwargs): - # TODO: need to implement this - # args[0].transposed = not args[0].transposed - # new = args[0]._change_shape(args[0].shape[::-1]) - # return return_and_correct_aliasing(func, args, kwargs, new) - raise Exception("transpose not implemented yet") + block_size = args[0].block_size + assert len(block_size) == 2 + transposed_block_size = (block_size[1], block_size[0]) + new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size) + return return_and_correct_aliasing(func, args, kwargs, new) to_aq = AffineQuantizedTensor.from_float diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7ec88c7498..907a666492 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -25,7 +25,11 @@ from typing import Any, Callable from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from .utils import ( + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + unwrap_tensor_subclass, +) from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, @@ -187,9 +191,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): *args ) - _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn - ) + if TORCH_VERSION_AFTER_2_4: + quantize(model, get_apply_int8dyn_quant(), filter_fn) + unwrap_tensor_subclass(model, filter_fn) + else: + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): @@ -282,7 +290,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) # apply to modules under block0 submodule def filter_fn(module, fqn): diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index ee13512e9f..972699f0bf 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -610,6 +610,7 @@ def __new__( dtype = original_weight_tensor.dtype kwargs["dtype"] = dtype kwargs["requires_grad"] = False + kwargs["device"] = original_weight_tensor.device shape = original_weight_tensor.shape return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] @@ -664,6 +665,27 @@ def _apply_fn_to_data(self, fn): self.input_quant_func, ) + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + 
dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.original_weight_tensor.to(**kwargs), + self.input_quant_func, + ) + def __torch_dispatch__(cls, func, types, args, kwargs): if ( func in [aten.mm.default, aten.addmm.default] @@ -674,25 +696,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"need mat1 shape: {args[1].shape} final" f"dim to match mat2 shape: {args[2].shape} first dim " ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[1], args[2], args[0], ) - aqt = self.input_quant_func(input_tensor) - return func(bias, aqt, weight_tensor) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(bias, aqt, original_weight_tensor) else: + # aten.mm.default assert args[0].shape[-1] == args[1].shape[0], ( f"need mat1 shape: {args[0].shape} final dim" f"to match mat2 shape: {args[1].shape} first dim" ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor = ( args[0], args[1], - None if len(args) == 2 else args[2], ) - aqt = self.input_quant_func(input_tensor) - return func(aqt, weight_tensor, bias) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(aqt, original_weight_tensor) if func is aten.detach.default: return return_and_correct_aliasing( @@ -704,6 +730,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) + if func is aten._to_copy.default: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + if func is aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + raise NotImplementedError( f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 78a76863f3..e6787b0cf9 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -133,11 +133,14 @@ def right_inverse(self, tensor): def unwrap_tensor_subclass(model, filter_fn=None): for name, child in model.named_children(): + # make sure child.weight is a tensor subclass if ( isinstance(child, torch.nn.Linear) and hasattr(child, "weight") and type(child.weight) is not torch.Tensor and - isinstance(child.weight, torch.Tensor) + type(child.weight) is not torch.nn.Parameter and + isinstance(child.weight, torch.Tensor) and + issubclass(type(child.weight), torch.Tensor) ): parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass()) unwrap_tensor_subclass(child)
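
Note on the central refactor in `torchao/dtypes/aqt.py`: all specialized kernels now live in a single `_quantized_linear_op` that raises `NotImplementedError` when no dispatch condition matches, and the `__torch_function__` / `__torch_dispatch__` entry points catch the failure and fall back to dequantize-then-`F.linear`. Below is a minimal standalone sketch of that pattern, not the torchao code itself (it catches `NotImplementedError` explicitly where the patch uses a bare `except`):

```python
import torch
import torch.nn.functional as F

def _quantized_linear_op(input_tensor, weight_tensor, bias):
    # In the real implementation each branch checks device, dtype, layout and
    # block_size before calling a fused kernel; raising here means "no
    # specialized path found", so the generic fallback lives in one place.
    raise NotImplementedError("No specialized dispatch found for quantized linear op")

def functional_linear(input_tensor, weight_tensor, bias=None):
    try:
        return _quantized_linear_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        # Generic fallback: dequantize whichever operand is a quantized
        # subclass and run the regular floating point linear.
        if hasattr(input_tensor, "dequantize"):
            input_tensor = input_tensor.dequantize()
        if hasattr(weight_tensor, "dequantize"):
            weight_tensor = weight_tensor.dequantize()
        return F.linear(input_tensor, weight_tensor, bias)

print(functional_linear(torch.randn(2, 4), torch.randn(3, 4)).shape)  # torch.Size([2, 3])
```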
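
The new per-channel int8 weight-only branch replaces the `_weight_int8pack_mm` call (left commented out because of an `is_contiguous()` issue on cpu/mps) with plain `torch.mm` in the activation dtype: transpose the int8 weight, matmul, multiply by the per-channel scale, reshape, add bias, cast back. The following self-contained check (illustrative names, symmetric per-channel quantization assumed) verifies that this recipe matches a float `F.linear` up to quantization error:

```python
import torch

torch.manual_seed(0)
x = torch.randn(20, 1024)                # activations; bf16 in the actual test
w = torch.randn(512, 1024)               # weight of a Linear(1024, 512)

# symmetric per-output-channel int8 quantization of the weight
scale = w.abs().amax(dim=1) / 127.0      # one scale per output channel
w_int8 = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)

# same sequence of ops as the new branch in _quantized_linear_op
w_vals_int8_t = w_int8.t().contiguous()
orig_dtype = x.dtype
y = torch.mm(x.reshape(-1, x.shape[-1]), w_vals_int8_t.to(x.dtype)) * scale
y = y.reshape(*x.shape[:-1], y.shape[-1]).to(orig_dtype)

y_ref = torch.nn.functional.linear(x, w)
print((y - y_ref).abs().max())           # small, dominated by int8 rounding error
```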
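
`aten.t.default` on `AffineQuantizedTensor` is now implemented by reversing both the shape and the `block_size`, since the quantization groups follow the data when the weight is transposed. A plain-tensor illustration of the invariant for a per-output-channel weight (no torchao helpers involved):

```python
import torch

out_f, in_f = 8, 16
w = torch.randn(out_f, in_f)
block_size = (1, in_f)                          # one scale per output channel

scale = w.abs().amax(dim=1) / 127.0
w_int8 = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)

# dequantize then transpose ...
dq_then_t = (w_int8.float() * scale[:, None]).t()
# ... equals transpose then dequantize, provided the groups (and hence the
# block_size the subclass tracks) are transposed along with the data
t_then_dq = w_int8.t().float() * scale[None, :]
transposed_block_size = (block_size[1], block_size[0])  # (in_f, 1)

assert torch.equal(dq_then_t, t_then_dq)
print(transposed_block_size)
```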
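
The new locally-run perf test compiles both models with `mode='max-autotune'`, calls `benchmark_model` once for warmup and once for the timed runs, and asserts the refactored path is within 5% of the deprecated reference. For reproducing the measurement without `torchao.utils.benchmark_model`, a rough stand-in (not the torchao helper; returns average seconds per run) could look like:

```python
import time
import torch

def simple_benchmark(model, num_runs, input_tensor):
    if input_tensor.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            model(input_tensor)
    if input_tensor.is_cuda:
        # make sure all launched kernels finished before stopping the clock
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_runs

# same usage shape as the test: warm up first, then time
# simple_benchmark(compiled_model, 5, input_tensor)
# elapsed = simple_benchmark(compiled_model, 100, input_tensor)
```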
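
The user-facing entry point exercised by the updated tests is unchanged; on 2.4+ it now routes through `quantize()` plus `unwrap_tensor_subclass()` internally instead of the subclass inserter. A usage sketch (assumes CUDA and bfloat16 as in the tests, and a layer large enough to pass the `in_features > 16` filter; expect a small numeric difference from the float model):

```python
import copy
import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

class TinyModel(torch.nn.Module):
    def __init__(self, k=1024):
        super().__init__()
        self.linear = torch.nn.Linear(k, k, bias=False)

    def forward(self, x):
        return self.linear(x)

m = TinyModel().eval().to(torch.bfloat16).to("cuda")
m_float = copy.deepcopy(m)
change_linear_weights_to_int8_dqtensors(m)     # mutates the model in place

x = torch.randn(20, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    print((m(x) - m_float(x)).abs().max())     # small int8 quantization error
```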