From 89293588f1e3a1c50c71e9b1ee89fab80fb26049 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 13 Jun 2024 19:31:59 -0700 Subject: [PATCH] Factor out dispatch and layout registration table Summary: att, after the refactor we can use common utils for new dtypes as well Test Plan: python test/dtypes/test_nf4.py python test/dtypes/test_aqt.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags: --- .../{test_aq.py => test_affine_quantized.py} | 2 +- torchao/dtypes/aqt.py | 527 ++++++++---------- torchao/dtypes/nf4tensor.py | 519 +++++++++-------- torchao/dtypes/utils.py | 65 +++ 4 files changed, 559 insertions(+), 554 deletions(-) rename test/dtypes/{test_aq.py => test_affine_quantized.py} (95%) create mode 100644 torchao/dtypes/utils.py diff --git a/test/dtypes/test_aq.py b/test/dtypes/test_affine_quantized.py similarity index 95% rename from test/dtypes/test_aq.py rename to test/dtypes/test_affine_quantized.py index f4211c8921..76b07f2e7a 100644 --- a/test/dtypes/test_aq.py +++ b/test/dtypes/test_affine_quantized.py @@ -7,7 +7,7 @@ import unittest -class TestAQ(TestCase): +class TestAffineQuantized(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_tensor_core_layout_transpose(self): t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda") diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index 8091946f77..4d83c92050 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -15,6 +15,12 @@ ) from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import find_multiple +from torchao.dtypes.utils import ( + _implements, + _ATEN_OP_OR_TORCH_FN_TABLE, + _register_layout_cls, + _get_layout_tensor_constructor, +) aten = torch.ops.aten @@ -42,57 +48,6 @@ def _aqt_is_uint4(aqt): aqt.quant_max is None or aqt.quant_max == 15 ) -# TODO: merge with nf4 implements decorator -# aten op to their __torch_dispatch__ implemnetations for the tensor subclass 
-_ATEN_OPS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) - -def implements_aten_ops(cls, aten_ops): - """Use this decorator to implement a function for an aten op in __torch_dispatch__""" - - def decorator(func): - for op in aten_ops: - _ATEN_OPS_TABLE[cls][op] = func - return func - - return decorator - -_TORCH_FUNCTIONS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) - -def implements_torch_function(cls, torch_function): - def decorator(func): - functools.update_wrapper(func, torch_function) - _TORCH_FUNCTIONS_TABLE[cls][torch_function] = func - return func - - return decorator - -def implements_aqt_aten_ops(aten_ops): - return implements_aten_ops(AffineQuantizedTensor, aten_ops) - -def implements_aqt_torch_function(torch_function): - return implements_torch_function(AffineQuantizedTensor, torch_function) - -""" -dict mapping from aqt layout type to the corresponding constructor (AQTLayout.from_plain) -""" -_AQT_LAYOUT_TO_CTR: Dict[str, Callable] = {} - -def register_aqt_layout_cls(extended_layout: str): - """ Register AQTLayout class - """ - def decorator(layout_cls): - layout_cls.extended_layout = extended_layout - _AQT_LAYOUT_TO_CTR[extended_layout] = layout_cls.from_plain - return layout_cls - return decorator - -def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable: - """Get Layout class constructor (LayoutClass.from_plain) for AffineQuantizedTensor - """ - if extended_layout not in _AQT_LAYOUT_TO_CTR: - raise ValueError(f"extended_layout: {extended_layout} is not supported yet") - return _AQT_LAYOUT_TO_CTR.get(extended_layout) - class AQTLayout(torch.Tensor): """ Base class for the layout tensor for `AffineQuantizedTensor` @@ -126,7 +81,237 @@ def _get_to_kwargs(self, *args, **kwargs): } return kwargs -@register_aqt_layout_cls("plain") +class AffineQuantizedTensor(torch.Tensor): + """ + Affine quantized tensor subclass. 
Affine quantization means we quantize the floating point tensor with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point + + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. + + fields: + layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, + e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device + and operator/kernel + block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + shape (torch.Size): the shape for the Tensor + quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object + dtype: dtype for external representation of the tensor, e.g. 
torch.float32 + """ + + @staticmethod + def __new__( + cls, + layout_tensor: AQTLayout, + block_size: Tuple[int, ...], + shape: torch.Size, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + dtype=None, + strides=None, + ): + kwargs = {} + kwargs["device"] = layout_tensor.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout + ) + kwargs["dtype"] = dtype + if strides is not None: + kwargs["strides"] = strides + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + layout_tensor: AQTLayout, + block_size: Tuple[int, ...], + shape: torch.Size, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + dtype=None, + strides=None, + ): + self.layout_tensor = layout_tensor + self.block_size = block_size + self.quant_min = quant_min + self.quant_max = quant_max + self.zero_point_domain = zero_point_domain + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def dequantize(self, output_dtype=None): + if output_dtype is None: + output_dtype = self.dtype + int_data, scale, zero_point = self.layout_tensor.get_plain() + return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) + + def __tensor_flatten__(self): + return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + layout_tensor = tensor_data_dict["layout_tensor"] + 
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes + return cls( + layout_tensor, + block_size, + shape if outer_size is None else outer_size, + quant_min, + quant_max, + zero_point_domain, + dtype=dtype, + strides=outer_stride, + ) + + @classmethod + def from_float( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + extended_layout: str = "plain", + # TODO: this is only for "tensor_core_tiled", need to figure out + # the proper API for this arg + inner_k_tiles: Optional[int] = None, + ): + original_shape = input_float.shape + if extended_layout == "tensor_core_tiled": + orig_out_features, orig_in_features = input_float.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input_float = torch.nn.functional.pad( + input_float, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + + scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + + layout_cls_ctr = get_layout_tensor_constructor(extended_layout) + # TODO: this is temporary, need to come up with the proper UX + if extended_layout == "tensor_core_tiled": + layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles) + else: + layout_tensor = layout_cls_ctr(int_data, scale, zero_point) + return cls( + layout_tensor, + block_size, + original_shape, + quant_min, + 
quant_max, + zero_point_domain, + dtype=input_float.dtype + ) + + @property + def layout(self) -> str: + return self.layout_tensor.extended_layout + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.layout_tensor.to(kwargs["device"]), + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + **kwargs, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.layout_tensor), + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # Note: we only added cpu path here for 8da4w, this is for executorch, in the future + # 1. we'll add cpu/cuda version (int4mm etc.) + # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like + # cpu device + et laytout --> gives current 8da4w executorch representation + # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. 
+ # cuda device + some layout --> gives cuda kernel + + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized + # kernels in CPU as well, see the note above + # 2 - we're given non-floats - quantizing long to int8 is crazy + + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError( + f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + ) + +def implements(aten_ops_or_torch_fn): + return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn) + +def register_layout_cls(extended_layout: str): + return _register_layout_cls(AffineQuantizedTensor, extended_layout) + +def get_layout_tensor_constructor(extended_layout: str): + return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout) + +@register_layout_cls("plain") class PlainAQTLayout(AQTLayout): """ Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point @@ -222,7 +407,7 @@ def from_plain( ): return cls(int_data, scale, zero_point) -@register_aqt_layout_cls("tensor_core_tiled") +@register_layout_cls("tensor_core_tiled") class TensorCoreTiledAQTLayout(AQTLayout): """ Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, @@ -347,230 +532,6 @@ def get_plain(self): int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) return int_data, scale, zero -class AffineQuantizedTensor(torch.Tensor): - """ - Base affine quantized tensor subclass. When the from_float method is used, - to create an instance of any AffineQuantizedTensor - - The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, - regardless of the internal representation's type or orientation. 
- - Affine quantization means we quantize the floating point tensor with an affine transformation: - quantized_tensor = float_tensor / scale + zero_point - - fields: - layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, - e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device - and operator/kernel - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. when size is the same as the input tensor dimension, we are using per tensor quantization - shape (torch.Size): the shape for the Tensor - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object - dtype: dtype for external representation of the tensor, e.g. 
torch.float32 - """ - - @staticmethod - def __new__( - cls, - layout_tensor: AQTLayout, - block_size: Tuple[int, ...], - shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - dtype=None, - strides=None, - ): - kwargs = {} - kwargs["device"] = layout_tensor.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout - ) - kwargs["dtype"] = dtype - if strides is not None: - kwargs["strides"] = strides - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - layout_tensor: AQTLayout, - block_size: Tuple[int, ...], - shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - dtype=None, - strides=None, - ): - self.layout_tensor = layout_tensor - self.block_size = block_size - self.quant_min = quant_min - self.quant_max = quant_max - self.zero_point_domain = zero_point_domain - - def __repr__(self): - return ( - f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " - f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def dequantize(self, output_dtype=None): - if output_dtype is None: - output_dtype = self.dtype - int_data, scale, zero_point = self.layout_tensor.get_plain() - return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) - - def __tensor_flatten__(self): - return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - layout_tensor = tensor_data_dict["layout_tensor"] - 
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes - return cls( - layout_tensor, - block_size, - shape if outer_size is None else outer_size, - quant_min, - quant_max, - zero_point_domain, - dtype=dtype, - strides=outer_stride, - ) - - @classmethod - def from_float( - cls, - input_float: torch.Tensor, - mapping_type: MappingType, - block_size: Tuple[int, ...], - target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - extended_layout: str = "plain", - # TODO: this is only for "tensor_core_tiled", need to figure out - # the proper API for this arg - inner_k_tiles: Optional[int] = None, - ): - original_shape = input_float.shape - if extended_layout == "tensor_core_tiled": - orig_out_features, orig_in_features = input_float.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input_float = torch.nn.functional.pad( - input_float, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - - scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) - int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) - - layout_cls_ctr = get_aqt_layout_cls_ctr(extended_layout) - # TODO: this is temporary, need to come up with the proper UX - if extended_layout == "tensor_core_tiled": - layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles) - else: - layout_tensor = layout_cls_ctr(int_data, scale, zero_point) - return cls( - layout_tensor, - block_size, - original_shape, - quant_min, - quant_max, 
- zero_point_domain, - dtype=input_float.dtype - ) - - @property - def layout(self) -> str: - return self.layout_tensor.extended_layout - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func in _TORCH_FUNCTIONS_TABLE[cls]: - return _TORCH_FUNCTIONS_TABLE[cls][func](*args, **kwargs) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.layout_tensor.to(kwargs["device"]), - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.layout_tensor), - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - dtype=self.dtype, - strides=self.stride(), - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - # Note: we only added cpu path here for 8da4w, this is for executorch, in the future - # 1. we'll add cpu/cuda version (int4mm etc.) - # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like - # cpu device + et laytout --> gives current 8da4w executorch representation - # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. 
- # cuda device + some layout --> gives cuda kernel - - # two scenarios where we currently fall back to vanilla mm: - # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized - # kernels in CPU as well, see the note above - # 2 - we're given non-floats - quantizing long to int8 is crazy - - if func in _ATEN_OPS_TABLE[cls]: - return _ATEN_OPS_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError( - f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" - ) - def _quantized_linear_op(input_tensor, weight_qtensor, bias): """ Quantized version of F.linear operator @@ -705,7 +666,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): raise NotImplementedError("No specialized dispatch found for quantized linear op") -@implements_aqt_torch_function(torch.nn.functional.linear) +@implements(torch.nn.functional.linear) def functional_linear(*args, **kwargs): input_tensor, weight_tensor, bias = ( args[0], @@ -724,7 +685,7 @@ def functional_linear(*args, **kwargs): weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default]) +@implements([aten.mm.default, aten.addmm.default]) def aten_mm(func, *args, **kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"{func} is not implemented for non floating point input") @@ -763,21 +724,21 @@ def aten_mm(func, *args, **kwargs): weight_tensor = weight_tensor.dequantize() return func(input_tensor, weight_tensor) -@implements_aqt_aten_ops([aten.detach.default]) +@implements([aten.detach.default]) def detach(func, *args, **kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) -@implements_aqt_aten_ops([aten.clone.default]) +@implements([aten.clone.default]) def clone(func, *args, **kwargs): return return_and_correct_aliasing( func, args, kwargs, 
args[0]._apply_fn_to_data(torch.clone) ) -@implements_aqt_aten_ops([aten._to_copy.default]) +@implements([aten._to_copy.default]) def _to_copy(func, *args, **kwargs): return return_and_correct_aliasing( func, @@ -786,7 +747,7 @@ def _to_copy(func, *args, **kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) -@implements_aqt_aten_ops([aten.t.default]) +@implements([aten.t.default]) def t(func, *args, **kwargs): block_size = args[0].block_size assert len(block_size) == 2 diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 84c398426f..f05599f6ef 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -11,6 +11,10 @@ from torch import Tensor from torch.distributed.device_mesh import DeviceMesh from torch._prims_common import make_contiguous_strides_for +from torchao.dtypes.utils import ( + _implements, + _ATEN_OP_OR_TORCH_FN_TABLE, +) aten = torch.ops.aten @@ -19,9 +23,6 @@ from typing import Any, Optional, Tuple, Union, List -NF4_OPS_TABLE: Dict[Any, Any] = {} - - _INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"] # Note: Quantize in Chunks @@ -43,17 +44,6 @@ def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): ) -def implements(aten_ops): - """Use this decorator to implement a function for an aten op in __torch_dispatch__""" - - def decorator(func): - for op in aten_ops: - NF4_OPS_TABLE[op] = func - return func - - return decorator - - def construct_nf4_args(nf4tensor: "NF4Tensor", kwargs: Optional[Dict[str, Any]] = None): if kwargs is None: kwargs = {} @@ -131,248 +121,6 @@ def wrapper(aten_op, args, kwargs=None): return decorator -@implements([torch.ops.aten.detach]) -def noop_detach(func, *args, **kwargs): - return args[0][0] - - -@implements( - [ - aten.detach.default, - ] -) -def nf4_detach(aten_op, args, kwargs=None): - nf4tensor = args[0] - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - return 
NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - - -@implements( - [ - aten.split.Tensor, - ] -) -def nf4_split(aten_op, args, kwargs=None): - if len(args) == 3 and args[2] != 0: - raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") - nf4tensor = args[0] - num_chunks = nf4tensor.size(0) // args[1] - - attr_to_chunks = {} - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}" - chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) - attr_to_chunks[attr] = chunks - - orig_dim = nf4tensor.dim() - if orig_dim == 1: - chunked_size = (nf4tensor.size(0) // num_chunks, ) - elif orig_dim == 2: - chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1)) - else: - chunked_size = () - raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}") - - nf4_chunks = [] - for idx in range(num_chunks): - updated_attrs = { - "size": chunked_size - } - for attr, chunks in attr_to_chunks.items(): - updated_attrs[attr] = chunks[idx] - nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))) - return nf4_chunks - -@implements( - [ - aten.new_zeros.default, - ] -) -@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") -def nf4_new_zeros(aten_op, args, kwargs=None): - nf4tensor = args[0] - new_size = tuple(args[1]) - new_size_dim = len(new_size) - if nf4tensor.numel() % math.prod(new_size) != 0: - raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") - ratio = nf4tensor.numel() // math.prod(new_size) - - updated_attrs = {} - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}" - inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) - updated_attrs[attr] = 
inner_tensor - updated_attrs["size"] = new_size - - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - -@implements( - [ - aten.slice.Tensor, - ] -) -@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step") -@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=") -@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=") -def nf4_slice(aten_op, args, kwargs=None): - nf4tensor = args[0] - # for tensor 512 x 512, tensor[:, :512] dispatch to - # aten.slice(dim = 0, end=sys.maxsize) - if not args[3] in [nf4tensor.size(0), sys.maxsize]: - raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") - return NF4Tensor(*construct_nf4_args(nf4tensor)) - -@implements( - [ - aten.view.default, - ] -) -@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") -def nf4_view(aten_op, args, kwargs=None): - nf4tensor = args[0] - size = args[1] - if size[0] != -1: - raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - updated_attrs.update({ - "size": [nf4tensor.numel()], - "stride": (1, ), - }) - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - -@implements( - [ - aten.as_strided.default, - ] -) -@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") -def nf4_as_strided(aten_op, args, kwargs=None): - nf4tensor = args[0] - size = args[1] - stride = tuple(args[2]) - storage_offset = args[3] - if math.prod(size) != nf4tensor.numel(): - raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") - if stride != make_contiguous_strides_for(size): - raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") - if nf4tensor.storage_offset() != storage_offset: - raise 
NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") - kwargs = { - "size": torch.Size(size), - "stride": stride, - "storage_offset": storage_offset, - } - return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) - - -@implements([torch.ops.aten._to_copy.default]) -def _to_copy(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return func(args[0][0].t()).t() - out = args[0][0].get_original_weight().to(args[1]["dtype"]) - if "device" in args[1]: - out = out.to(args[1]["device"]) - return out - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() - return args[0][0].get_original_weight().to(args[0][1]) - - -@implements([torch.ops.aten.t.default]) -def t_default(func, *args, **kwargs): - a = args[0][0] - tensor_meta = SubclassTensorArgs( - a.size(), - (a.stride(1), a.stride(0)), - a.storage_offset(), - a.dtype, - a.device, - a.requires_grad, - ) - b = NF4Tensor( - tensor_meta, - a.block_size, - a.n_blocks, - a.scaler_block_size, - a.quantized_scalers, - a.quantization_factor, - a.scaler_mean, - a.quantized_data, - a.nf4, - ) - return b - - -@implements([torch.ops.aten.mm.default]) -def mm_default(func, *args, **kwargs): - return linear_nf4(args[0][0], args[0][1]) - - -@implements( - [ - aten.copy_.default, - ] -) -def copy_(func, *args, **kwargs): - original: NF4Tensor = args[0][0] - copy_in: torch.Tensor = args[0][1] - - # Base Case - - if same_metadata(original, copy_in): - original_tensors = original.__tensor_flatten__()[0] - for tensor_name in original_tensors: - getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) - return - - # Convert Non NF4Tensor into NF4 for copy in - if not isinstance(copy_in, NF4Tensor): - copy_in_nf4 = 
NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size - ) - return original.copy_(copy_in_nf4) - - # Other Tensor is not a NF4Tensor - full_precision = copy_in.get_original_weight() - same_meta_nf4 = NF4Tensor.from_tensor( - full_precision, original.block_size, original.scaler_block_size - ) - return original.copy_(same_meta_nf4) - - -@implements( - [ - aten.is_pinned.default, - ] -) -def nf4_is_pinned(aten_op, args, kwargs=None): - nf4tensor = args[0] - for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: - inner_tensor = getattr(nf4tensor, attr) - if not aten_op(inner_tensor, *(args[1:]), **kwargs): - return False - return True - - -@implements( - [ - aten._pin_memory.default, - ] -) -def nf4_pin_memory(aten_op, args, kwargs=None): - nf4tensor = args[0] - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - - @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -759,7 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): """TODO we are not supporting torch dispatch at the moment instead we have created a Autograd.Function to handle the linear """ - # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs + # All ops in the _ATEN_OP_OR_TORCH_FN_TABLE expect NF4 Tensors as inputs # And don't support mixed tensor subclasses. 
This will trigger the handler for # the next type in the dispatch list @@ -775,8 +523,8 @@ def allowed_subclasses(type): if not all(allowed_subclasses(t) for t in types): return NotImplemented("Up to the next one to handle") - if func in NF4_OPS_TABLE: - return NF4_OPS_TABLE[func](func, args, kwargs) + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, args, kwargs) raise NotImplementedError( f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) @@ -789,8 +537,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} try: - if func in NF4_TORCH_FUNCTIONS: - return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) except NotImplementedError: pass @@ -886,20 +634,251 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) +def implements(aten_ops_or_torch_fn): + return _implements(NF4Tensor, aten_ops_or_torch_fn) -NF4_TORCH_FUNCTIONS = {} +@implements([torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + return args[0][0] -def implements_torch_function(torch_function): - def decorator(func): - functools.update_wrapper(func, torch_function) - NF4_TORCH_FUNCTIONS[torch_function] = func - return func +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) - return decorator +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + if len(args) == 3 and args[2] != 0: + raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") + nf4tensor = args[0] + num_chunks = 
nf4tensor.size(0) // args[1] + + attr_to_chunks = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}" + chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) + attr_to_chunks[attr] = chunks + + orig_dim = nf4tensor.dim() + if orig_dim == 1: + chunked_size = (nf4tensor.size(0) // num_chunks, ) + elif orig_dim == 2: + chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1)) + else: + chunked_size = () + raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}") + + nf4_chunks = [] + for idx in range(num_chunks): + updated_attrs = { + "size": chunked_size + } + for attr, chunks in attr_to_chunks.items(): + updated_attrs[attr] = chunks[idx] + nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))) + return nf4_chunks + +@implements( + [ + aten.new_zeros.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") +def nf4_new_zeros(aten_op, args, kwargs=None): + nf4tensor = args[0] + new_size = tuple(args[1]) + new_size_dim = len(new_size) + if nf4tensor.numel() % math.prod(new_size) != 0: + raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") + ratio = nf4tensor.numel() // math.prod(new_size) + + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}" + inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) + updated_attrs[attr] = inner_tensor + updated_attrs["size"] = new_size + + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.slice.Tensor, + ] +) +@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step") +@expect_arg_value_at_k(1, CompareOp.EQ, 0, 
"aten.slice(NF4Tensor) with dim=") +@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=") +def nf4_slice(aten_op, args, kwargs=None): + nf4tensor = args[0] + # for tensor 512 x 512, tensor[:, :512] dispatch to + # aten.slice(dim = 0, end=sys.maxsize) + if not args[3] in [nf4tensor.size(0), sys.maxsize]: + raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") + return NF4Tensor(*construct_nf4_args(nf4tensor)) + +@implements( + [ + aten.view.default, + ] +) +@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") +def nf4_view(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + updated_attrs.update({ + "size": [nf4tensor.numel()], + "stride": (1, ), + }) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.as_strided.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") +def nf4_as_strided(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + stride = tuple(args[2]) + storage_offset = args[3] + if math.prod(size) != nf4tensor.numel(): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") + if stride != make_contiguous_strides_for(size): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") + if nf4tensor.storage_offset() != storage_offset: + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") + kwargs = { + "size": torch.Size(size), + "stride": stride, + "storage_offset": storage_offset, + } + return NF4Tensor(*construct_nf4_args(nf4tensor, 
kwargs)) + + +@implements([torch.ops.aten._to_copy.default]) +def _to_copy(func, *args, **kwargs): + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return func(args[0][0].t()).t() + out = args[0][0].get_original_weight().to(args[1]["dtype"]) + if "device" in args[1]: + out = out.to(args[1]["device"]) + return out + + +@implements([torch.ops.aten.to.dtype]) +def to_dtype(func, *args, **kwargs): + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() + return args[0][0].get_original_weight().to(args[0][1]) + + +@implements([torch.ops.aten.t.default]) +def t_default(func, *args, **kwargs): + a = args[0][0] + tensor_meta = SubclassTensorArgs( + a.size(), + (a.stride(1), a.stride(0)), + a.storage_offset(), + a.dtype, + a.device, + a.requires_grad, + ) + b = NF4Tensor( + tensor_meta, + a.block_size, + a.n_blocks, + a.scaler_block_size, + a.quantized_scalers, + a.quantization_factor, + a.scaler_mean, + a.quantized_data, + a.nf4, + ) + return b + + +@implements([torch.ops.aten.mm.default]) +def mm_default(func, *args, **kwargs): + return linear_nf4(args[0][0], args[0][1]) + + +@implements( + [ + aten.copy_.default, + ] +) +def copy_(func, *args, **kwargs): + original: NF4Tensor = args[0][0] + copy_in: torch.Tensor = args[0][1] + + # Base Case + + if same_metadata(original, copy_in): + original_tensors = original.__tensor_flatten__()[0] + for tensor_name in original_tensors: + getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) + return + + # Convert Non NF4Tensor into NF4 for copy in + if not isinstance(copy_in, NF4Tensor): + copy_in_nf4 = NF4Tensor.from_tensor( + copy_in, original.block_size, original.scaler_block_size + ) + return original.copy_(copy_in_nf4) + + # Other Tensor is not a NF4Tensor + full_precision = copy_in.get_original_weight() + same_meta_nf4 = NF4Tensor.from_tensor( + full_precision, original.block_size, 
original.scaler_block_size + ) + return original.copy_(same_meta_nf4) + + +@implements( + [ + aten.is_pinned.default, + ] +) +def nf4_is_pinned(aten_op, args, kwargs=None): + nf4tensor = args[0] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + if not aten_op(inner_tensor, *(args[1:]), **kwargs): + return False + return True + + +@implements( + [ + aten._pin_memory.default, + ] +) +def nf4_pin_memory(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) -@implements_torch_function(torch.Tensor.to) +@implements(torch.Tensor.to) def function_to_dtype(*args, **kwargs): tensor = args[0] if isinstance(args[1], torch.dtype): @@ -925,7 +904,7 @@ def function_to_dtype(*args, **kwargs): ) -@implements_torch_function(torch.Tensor.cpu) +@implements(torch.Tensor.cpu) def function_cpu(*args, **kwargs): nf4tensor = args[0] updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py new file mode 100644 index 0000000000..1e4eb692a5 --- /dev/null +++ b/torchao/dtypes/utils.py @@ -0,0 +1,65 @@ +from typing import Dict, Callable +from collections import defaultdict +import functools + +""" +torch_function and torch_dispatch operator dispatch registrations + +first key is a tensor subclass type like AffineQuantizedTensor, +second key is a `func` in __torch_function__ or __torch_dispatch__, +value is a function that implements the dispatch +""" +_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict) + +def _implements(cls, aten_ops_or_torch_fns): + """Use this decorator to implement a function for aten ops in __torch_dispatch__ + (if user passed in a list of ops) + or torch function in __torch_function__ (if user passed in a single object) + """ + if not isinstance(aten_ops_or_torch_fns, 
(list, tuple)): + aten_ops_or_torch_fns = [aten_ops_or_torch_fns] + def decorator(func): + for op in aten_ops_or_torch_fns: + @functools.wraps(op) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper + return func + return decorator + +""" +layout tensor constructor registration for different tensor subclasses + +first key is a tensor subclass type like AffineQuantizedTensor +second key is an extended layout string, like tensor_core_tiled +value is a constructor for the LayoutTensor class, e.g. TensorCoreTiledAQTLayout.from_plain +""" +_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[str, Callable]] = defaultdict(dict) + +def _register_layout_cls(cls: Callable, extended_layout: str): + """Helper function for layout registrations, this is used to implement + register_layout_cls decorator for each tensor subclass, see aqt.py for example usage + + Args: + cls: Tensor subclass type + extended_layout: string name for the layout type + + Returns: + a decorator that registers the layout tensor constructor in the table + """ + def decorator(layout_cls): + layout_cls.extended_layout = extended_layout + _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout] = layout_cls.from_plain + return layout_cls + return decorator + +def _get_layout_tensor_constructor(cls: Callable, extended_layout: str) -> Callable: + """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `extended_layout` + """ + if cls not in _LAYOUT_CONSTRUCTOR_TABLE: + raise ValueError(f"no registered layout class constructor for: {cls}") + if extended_layout not in _LAYOUT_CONSTRUCTOR_TABLE[cls]: + raise ValueError(f"extended_layout: {extended_layout} is not supported yet for {cls}") + + return _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout]