diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 23a4257497..e9a9ff3f15 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -36,7 +36,7 @@ jobs: # Install executorch first because it installs its own version # of torch and torchao, which we do not want to use pip install executorch - pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall + pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall pip install numpy pip install pytest pip install parameterized diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index d2dea3a5fe..fea0ea1a76 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -35,21 +35,32 @@ class Target(Enum): # AUTO target will automatically select a packing format # based on the available hardware. - # TODO: in future, add the ability to specify specific - # hardware targets AUTO = auto() + UNIVERSAL = auto() + KLEIDIAI = auto() # ATEN target will use the ATen operator ATEN = auto() +_TARGET_AND_STR = [ + (Target.AUTO, "auto"), + (Target.ATEN, "aten"), + (Target.UNIVERSAL, "universal"), + (Target.KLEIDIAI, "kleidiai"), +] + + +def target_to_str(target: Target) -> str: + target_to_str = {t: s for t, s in _TARGET_AND_STR} + return target_to_str[target] + + def target_from_str(target: str) -> Target: - if target.lower() == "auto": - return Target.AUTO - elif target.lower() == "aten": - return Target.ATEN - else: - raise ValueError(f"Invalid target: {target}") + str_to_target = {s: t for t, s in _TARGET_AND_STR} + if target.lower() in str_to_target: + return str_to_target[target.lower()] + raise ValueError(f"Invalid target: {target}") class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): @@ -146,10 +157,9 @@ def from_plain( ): assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" - assert layout.target in { - Target.AUTO, - Target.ATEN, - }, f"Unexpected target: {layout.target}" + assert layout.target in [ + t for t, _ in _TARGET_AND_STR + ], f"Unexpected target: {layout.target}" n, k = int_data.shape if layout.target == Target.ATEN: @@ -174,7 +184,7 @@ def from_plain( zero_point.reshape(-1).to(torch.int8) if layout.has_weight_zeros else None, layout.group_size, bias if layout.has_bias else None, - None, # target, if not passed a packing format will be chosen on C++ side + target_to_str(layout.target) if layout.target != Target.AUTO else None, ] packed_weight = getattr( @@ -223,7 +233,7 @@ def _linear_check(input_tensor, weight_tensor, bias): def _linear_impl(input_tensor, weight_tensor, bias): - def _impl_2d_auto(input_tensor, weight_tensor): + def _impl_2d_non_aten(input_tensor, weight_tensor): assert input_tensor.dim() == 2 assert weight_tensor.dim() == 2 @@ -272,8 +282,8 @@ def _impl_2d_aten(input_tensor, weight_tensor): if target == Target.ATEN: assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" _impl_2d = _impl_2d_aten - elif target == Target.AUTO: - _impl_2d = _impl_2d_auto + else: + _impl_2d = 
_impl_2d_non_aten if input_tensor.dim() == 2: res = _impl_2d(input_tensor, weight_tensor) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 2bdc7e52bd..0fa82fecfc 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -6,7 +6,7 @@ import logging import sys -from typing import Optional, Union +from typing import Callable, List, Mapping, Optional, Tuple, Union import torch import torch.nn as nn @@ -28,6 +28,35 @@ logger.addHandler(handler) +def _check_torchao_ops_loaded(): + # Check kernels are installed/loaded + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." + + " You can also set target to 'aten' if you are using ARM CPU." + ) + + +def _dtype_to_bit_width(dtype: torch.dtype) -> int: + dtype_to_bit_width = { + torch.int1: 1, + torch.int2: 2, + torch.int3: 3, + torch.int4: 4, + torch.int5: 5, + torch.int6: 6, + torch.int7: 7, + torch.int8: 8, + } + if dtype not in dtype_to_bit_width: + raise ValueError( + f"dtype must be one of {list(dtype_to_bit_width.keys())}, got {dtype}" + ) + return dtype_to_bit_width[dtype] + + def _quantize( vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True ): @@ -330,34 +359,43 @@ def quantize(self, model: nn.Module) -> nn.Module: return model -class _IntxWeightQuantizedEmbedding(nn.Module): +class QuantizedEmbedding(nn.Module): def __init__( self, - nbit, - pack_weights_op, - embedding_op, + bit_width, ): super().__init__() - self.nbit = nbit - self._pack_weights_op = pack_weights_op - self._embedding_op = embedding_op + self.bit_width = bit_width + self.pack_weights_op = getattr( + torch.ops.torchao, f"_pack_embedding_{bit_width}bit" + ) + self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit") - def quantize_and_pack_weights(self, weights, group_size): - self.group_size = group_size + def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros): + assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding" num_embeddings, embedding_dim = weights.shape + if group_size == -1: + group_size = embedding_dim + self.group_size = group_size weight_qvals, weight_scales, weight_zeros = _quantize( - weights, self.group_size, self.nbit, has_weight_zeros=True + weights, self.group_size, self.bit_width, has_weight_zeros=True ) - self.packed_weight_qvals = self._pack_weights_op(weight_qvals.to(torch.int8)) - self.num_embeddings = torch.empty(0, num_embeddings, dtype=torch.int8) - self.embedding_dim = torch.empty(0, embedding_dim, dtype=torch.int8) - self.weight_scales = weight_scales - self.weight_zeros = weight_zeros.to(torch.int8) + self.register_buffer( + "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8)) + ) + self.register_buffer( + "num_embeddings", torch.empty(0, num_embeddings, dtype=torch.int8) + ) + self.register_buffer( + "embedding_dim", torch.empty(0, embedding_dim, dtype=torch.int8) + ) + self.register_buffer("weight_scales", weight_scales) + self.register_buffer("weight_zeros", weight_zeros.to(torch.int8)) def forward(self, x): shape = x.shape - return self._embedding_op( + return self.embedding_op( self.packed_weight_qvals, self.num_embeddings, self.embedding_dim, @@ -367,20 +405,25 @@ def forward(self, x): ).reshape(*shape, -1) -class _IntxWeightQuantizedEmbeddingFallback(nn.Module): +class 
QuantizedEmbeddingFallback(nn.Module): def __init__( self, - nbit, + bit_width, ): super().__init__() - self.nbit = nbit + self.bit_width = bit_width - def quantize_and_pack_weights(self, weights, group_size): - self.group_size = group_size + def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros): + assert ( + has_weight_zeros + ), "has_weight_zeros must be True for QuantizedEmbeddingFallback" num_embeddings, embedding_dim = weights.shape + if group_size == -1: + group_size = embedding_dim + self.group_size = group_size weight_qvals, weight_scales, weight_zeros = _quantize( - weights, self.group_size, self.nbit, has_weight_zeros=True + weights, self.group_size, self.bit_width, has_weight_zeros=True ) self.weight_qvals = weight_qvals.to(torch.int32) self.weight_scales = weight_scales @@ -405,39 +448,95 @@ def forward(self, x): return torch.stack(res).reshape(*shape, -1) -def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}): - group_size = kwargs["group_size"] - nbit = kwargs["nbit"] +class QuantizedSharedEmbedding(nn.Module): + def __init__(self, bit_width, unembedding_packed_weights, group_size, n, k): + super().__init__() + self.bit_width = bit_width + self.register_buffer("unembedding_packed_weights", unembedding_packed_weights) + self.n = n + self.k = k + if group_size == -1: + self.group_size = k + else: + self.group_size = group_size + self.shared_embedding_op = getattr( + torch.ops.torchao, f"_shared_embedding_{bit_width}bit" + ) - assert not isinstance(module, nn.Embedding) - assert nbit >= 1 and nbit <= 8 + def forward(self, x): + shape = x.shape + return self.shared_embedding_op( + self.unembedding_packed_weights, + self.group_size, + self.n, + self.k, + x.reshape(-1), + ).reshape(*shape, -1) + + +def _replace_embedding_with_quantized_embedding( + module: nn.Module, + kwargs={}, + fqn: str = "", +): + group_size = kwargs.get("group_size", None) + bit_width = kwargs.get("bit_width", None) + use_fallback = kwargs.get("use_fallback", None) + has_weight_zeros = kwargs.get("has_weight_zeros", None) + embedding_fqn_to_quantized_unembedding = kwargs.get( + "embedding_fqn_to_quantized_unembedding", None + ) + assert not isinstance(module, nn.Embedding) for name, child in module.named_children(): + child_fqn = f"{fqn}.{name}" if fqn != "" else name + if not isinstance(child, nn.Embedding): - _replace_embedding_with_quantized_embedding(child, kwargs) + _replace_embedding_with_quantized_embedding(child, kwargs, child_fqn) else: - try: - qembedding = _IntxWeightQuantizedEmbedding( - nbit, - getattr(torch.ops.torchao, f"_pack_embedding_{nbit}bit"), - getattr(torch.ops.torchao, f"_embedding_{nbit}bit"), - ) - setattr(module, name, qembedding) - getattr(module, name).quantize_and_pack_weights( - child.weight, group_size - ) - except Exception as e: - logger.warning( - f"_IntxWeightQuantizedEmbedding raised an exception during quantize_and_pack_weights: {e}\n" - + "Falling back to **slow** implementation _IntxWeightQuantizedEmbeddingFallback." 
- ) - qembedding = _IntxWeightQuantizedEmbeddingFallback(nbit) + assert child.weight.device == torch.device("cpu"), "Only CPU is supported" + assert child.weight.dtype == torch.float32, "Only float32 is supported" + + if use_fallback: + qembedding = QuantizedEmbeddingFallback(bit_width) setattr(module, name, qembedding) getattr(module, name).quantize_and_pack_weights( - child.weight, group_size + child.weight, group_size, has_weight_zeros ) - - + else: + _check_torchao_ops_loaded() + if embedding_fqn_to_quantized_unembedding is None: + qembedding = QuantizedEmbedding(bit_width) + setattr(module, name, qembedding) + getattr(module, name).quantize_and_pack_weights( + child.weight, group_size, has_weight_zeros + ) + else: + if child_fqn not in embedding_fqn_to_quantized_unembedding: + continue + weight_tensor = embedding_fqn_to_quantized_unembedding[child_fqn] + n, k = weight_tensor.shape + group_size = weight_tensor.tensor_impl.get_layout().group_size + packed_weight = weight_tensor.tensor_impl.packed_weight + bit_width = weight_tensor.tensor_impl.get_layout().bit_width + + assert ( + n == child.num_embeddings + ), "num_embeddings must match n in shared_unembedding" + assert ( + k == child.embedding_dim + ), "embedding_dim must match k in shared_unembedding" + qembedding = QuantizedSharedEmbedding( + bit_width, + packed_weight, + group_size, + n, + k, + ) + setattr(module, name, qembedding) + + +# TODO: remove this (needed for BC) class IntxWeightEmbeddingQuantizer: def __init__( self, @@ -479,7 +578,44 @@ def quantize(self, model: nn.Module) -> nn.Module: model, kwargs={ "group_size": self.groupsize, - "nbit": self.bitwidth, + "bit_width": self.bitwidth, + "use_fallback": False, + "has_weight_zeros": True, + }, + ) + return model + + +class EmbeddingQuantizer: + def __init__( + self, + weight_dtype: torch.dtype = torch.int4, + granularity: Union[PerRow, PerGroup] = PerRow(), + has_weight_zeros: bool = True, + use_fallback: bool = False, + ): + bit_width = _dtype_to_bit_width(weight_dtype) + + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + elif isinstance(granularity, PerRow): + group_size = -1 + else: + raise ValueError(f"Unsupported granularity: {granularity}") + + self.bit_width = bit_width + self.group_size = group_size + self.use_fallback = use_fallback + self.has_weight_zeros = has_weight_zeros + + def quantize(self, model: nn.Module) -> nn.Module: + _replace_embedding_with_quantized_embedding( + model, + kwargs={ + "group_size": self.group_size, + "bit_width": self.bit_width, + "use_fallback": self.use_fallback, + "has_weight_zeros": self.has_weight_zeros, }, ) return model @@ -553,21 +689,7 @@ def _int8_dynamic_activation_intx_weight_transform( act_mapping_type = config.act_mapping_type layout = config.layout - dtype_to_bit_width = { - torch.int1: 1, - torch.int2: 2, - torch.int3: 3, - torch.int4: 4, - torch.int5: 5, - torch.int6: 6, - torch.int7: 7, - torch.int8: 8, - } - if weight_dtype not in dtype_to_bit_width: - raise ValueError( - f"weight_dtype must be one of {list(dtype_to_bit_width.keys())}, got {weight_dtype}" - ) - bit_width = dtype_to_bit_width[weight_dtype] + bit_width = _dtype_to_bit_width(weight_dtype) if isinstance(granularity, PerGroup): group_size = granularity.group_size @@ -602,14 +724,7 @@ def _int8_dynamic_activation_intx_weight_transform( tensor_impl_ctr_kwargs = {"bias": bias} if layout.target == Target.AUTO: - # Check kernels are installed/loaded - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except AttributeError: - 
raise Exception( - "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." - + " You can also set target to 'aten' if you are using ARM CPU." - ) + _check_torchao_ops_loaded() elif layout.target == Target.ATEN: # TODO: long term, we want to disfavor this route for using KleidiAI in torchao # KleidiAI kernels are accessible via Target.AUTO if torchao is built @@ -681,6 +796,189 @@ def _int8_dynamic_activation_intx_weight_transform( return module +from torchao.quantization.quant_api import quantize_ + + +def _get_fqns_with_filter( + module: nn.Module, + filter_fn: Callable[Tuple[str, nn.Module], bool], + fqn: str, + fqns: List[str], +): + for name, child in module.named_children(): + child_fqn = f"{fqn}.{name}" if fqn != "" else name + if filter_fn(child, child_fqn): + fqns.append(child_fqn) + else: + _get_fqns_with_filter(child, filter_fn, child_fqn, fqns) + + +def get_fqns_with_filter( + module: nn.Module, filter_fn: Callable[Tuple[str, nn.Module], bool] +) -> List[str]: + fqns = [] + _get_fqns_with_filter(module, filter_fn, "", fqns) + return fqns + + +class QuantizedLinear(nn.Module): + def __init__(self, weight, bias): + super().__init__() + self.n, self.k = weight.shape + self.group_size = weight.tensor_impl.get_layout().group_size + self.bit_width = weight.tensor_impl.get_layout().bit_width + self.register_buffer("packed_weight", weight.tensor_impl.packed_weight) + self.bias = bias + + def _forward_2d(self, x): + assert x.dim() == 2 + m, k = x.shape + assert k == self.k + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{self.bit_width}bit_weight" + )(x, self.packed_weight, self.group_size, self.n, self.k) + + def forward(self, x): + if x.dim() == 2: + res = self._forward_2d(x) + else: + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + assert k == self.k + res = self._forward_2d(x.reshape(-1, k)) + res = res.reshape(*lead_shape, m, self.n) + + if self.bias is not None: + res = res + self.bias + return res + + +from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor + + +def replace_linear_tensor_subclass_with_module(module: nn.Module): + assert not isinstance(module, nn.Linear) + for name, child in module.named_children(): + if not isinstance(child, nn.Linear): + replace_linear_tensor_subclass_with_module(child) + else: + if not isinstance(child.weight, AffineQuantizedTensor): + continue + if not isinstance( + child.weight.tensor_impl.get_layout(), + PackedLinearInt8DynamicActivationIntxWeightLayout, + ): + continue + if child.weight.tensor_impl.get_layout().target == Target.ATEN: + continue + setattr(module, name, QuantizedLinear(child.weight, child.bias)) + + +class SharedEmbeddingQuantizer: + def __init__( + self, + weight_dtype: torch.dtype = torch.int4, + granularity: Union[PerRow, PerGroup] = PerRow(), + has_weight_zeros: bool = True, + ): + self.weight_dtype = weight_dtype + self.granularity = granularity + self.has_weight_zeros = has_weight_zeros + + def quantize( + self, + model: nn.Module, + embedding_to_unembedding: Optional[Mapping[str, str]] = None, + ): + embedding_fqns = get_fqns_with_filter( + model, lambda m, fqn: isinstance(m, nn.Embedding) + ) + linear_fqns = get_fqns_with_filter( + model, lambda m, fqn: isinstance(m, nn.Linear) + ) + state_dict = model.state_dict() + + # If embedding_to_unembedding is not provided, automatically detect shared embeddings and unembeddings + if embedding_to_unembedding is None: + 
embedding_to_unembedding = {} + for embedding_fqn in embedding_fqns: + embedding_w = state_dict[embedding_fqn + ".weight"] + for linear_fqn in linear_fqns: + linear_w = state_dict[linear_fqn + ".weight"] + if embedding_w.shape == linear_w.shape and torch.allclose( + embedding_w, linear_w + ): + print( + f"Found shared embedding {embedding_fqn} and unembedding {linear_fqn}" + ) + if embedding_fqn not in embedding_to_unembedding: + embedding_to_unembedding[embedding_fqn] = linear_fqn + else: + raise ValueError( + f"Found multiple candidate unembeddings ({embedding_to_unembedding[embedding_fqn]}, {linear_fqn}) for embedding {embedding_fqn}. This is not supported yet. Please explicitly define the input embedding_to_unembedding." + ) + + # Construct reverse mapping + unembedding_to_embedding = {} + for v, k in embedding_to_unembedding.items(): + if k not in unembedding_to_embedding: + unembedding_to_embedding[k] = v + else: + raise ValueError( + f"Found multiple candidate embeddings ({unembedding_to_embedding[k]}, {v}) for unembedding {k}. This is not supported yet." + ) + + # Check that embeddings are shared, embeddings are embeddings, and unembeddings are linear ops + for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): + assert ( + embedding_fqn in embedding_fqns + ), f"Embedding {embedding_fqn} is not found in model" + assert ( + unembedding_fqn in linear_fqns + ), f"Unembedding {unembedding_fqn} is not found in model" + assert torch.allclose( + state_dict[embedding_fqn + ".weight"], + state_dict[unembedding_fqn + ".weight"], + ), f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" + + # Quantize unembeddings + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=self.weight_dtype, + granularity=self.granularity, + has_weight_zeros=self.has_weight_zeros, + # Only universal layout is supported for shared embedding + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="universal" + ), + ), + filter_fn=lambda m, fqn: isinstance(m, nn.Linear) + and fqn in list(embedding_to_unembedding.values()), + ) + + embedding_fqn_to_quantized_unembedding = {} + for fqn, t in model.state_dict().items(): + if ( + fqn.endswith(".weight") + and fqn[: -len(".weight")] in unembedding_to_embedding + ): + embedding_fqn = unembedding_to_embedding[fqn[: -len(".weight")]] + embedding_fqn_to_quantized_unembedding[embedding_fqn] = t + + _replace_embedding_with_quantized_embedding( + model, + kwargs={ + "embedding_fqn_to_quantized_unembedding": embedding_fqn_to_quantized_unembedding, + }, + ) + + # Remove subclasses. 
Otherwise there are two packed_weight objects in exported model, + # even though they have the same id in eager mode + replace_linear_tensor_subclass_with_module(model) + + class UIntxWeightOnlyQuantizedLinear(nn.Module): def __init__( self, diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 40bfc6f53e..403bc900b4 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -9,16 +9,23 @@ import unittest import torch +from torch.testing import FileCheck +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) from torchao.experimental.quant_api import ( - IntxWeightEmbeddingQuantizer, - _IntxWeightQuantizedEmbeddingFallback, + EmbeddingQuantizer, + Int8DynamicActivationIntxWeightConfig, + SharedEmbeddingQuantizer, ) +from torchao.quantization.granularity import PerGroup, PerRow +from torchao.quantization.quant_api import quantize_ class TestEmbeddingQuantizer(unittest.TestCase): def test_accuracy(self): - group_size = 128 + granularity = PerGroup(128) embedding_dim = 4096 num_embeddings = 131 model = torch.nn.Sequential( @@ -26,27 +33,42 @@ def test_accuracy(self): ) indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32) - for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: - print(f"Testing nbit={nbit}") + for weight_dtype in [ + torch.int1, + torch.int2, + torch.int3, + torch.int4, + torch.int5, + torch.int6, + torch.int7, + torch.int8, + ]: + print(f"Testing weight_dtype={weight_dtype}") quantized_model = copy.deepcopy(model) - quantizer = IntxWeightEmbeddingQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=nbit, - groupsize=group_size, + quantizer = EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=True, + use_fallback=False, ) quantized_model = quantizer.quantize(quantized_model) with torch.no_grad(): + reference_quantizer = EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=True, + use_fallback=True, + ) + reference_model = copy.deepcopy(model) + reference_model = reference_quantizer.quantize(reference_model) result = quantized_model(indices) - reference_impl = _IntxWeightQuantizedEmbeddingFallback(nbit) - reference_impl.quantize_and_pack_weights(model[0].weight, group_size) - expected_result = reference_impl(indices) + expected_result = reference_model(indices) self.assertTrue(torch.allclose(result, expected_result)) def test_export_compile_aoti(self): - nbit = 4 - group_size = 128 + weight_dtype = torch.int4 + granularity = PerRow() embedding_dim = 4096 num_embeddings = 131 model = torch.nn.Sequential( @@ -55,33 +77,112 @@ def test_export_compile_aoti(self): indices = torch.randint(0, num_embeddings, (42,), dtype=torch.int32) print("Quantizing model") - quantizer = IntxWeightEmbeddingQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=nbit, - groupsize=group_size, + quantizer = EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=True, + use_fallback=False, ) quantized_model = quantizer.quantize(model) + eager_results = model(indices) print("Exporting quantized model") - torch.export.export(quantized_model, (indices,), strict=True) + with torch.no_grad(): + exported_model = torch.export.export( + quantized_model, (indices,), strict=True + ) + exported_results = 
exported_model.module()(indices) + self.assertTrue(torch.allclose(eager_results, exported_results)) print("Compiling quantized model") quantized_model_compiled = torch.compile(quantized_model) with torch.no_grad(): quantized_model_compiled(indices) + compiled_results = quantized_model_compiled(indices) + self.assertTrue(torch.allclose(eager_results, compiled_results)) with tempfile.TemporaryDirectory() as tmpdirname: print("Exporting quantized model with AOTI") - torch._export.aot_compile( - quantized_model, - (indices,), - options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + package_path = f"{tmpdirname}/model.pt2" + torch._inductor.aoti_compile_and_package( + exported_model, package_path=package_path ) + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(indices) + self.assertTrue(torch.allclose(eager_results, aoti_results)) + + def test_shared_embedding(self): + weight_dtype = torch.int4 + granularity = PerRow() + has_weight_zeros = True + embedding_dim = 4096 + num_embeddings = 131 + embedding = torch.nn.Embedding(num_embeddings, embedding_dim) + unembedding = torch.nn.Linear(embedding_dim, num_embeddings) + unembedding.weight = copy.deepcopy(embedding.weight) + model = torch.nn.Sequential( + *[ + embedding, + torch.nn.Linear(embedding_dim, embedding_dim), + unembedding, + ] + ) + indices = torch.randint(0, num_embeddings, (42,), dtype=torch.int32) - print("Running quantized model in AOTI") - fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") - fn(indices) + # Reference implementation quantizes the embedding and unembedding + # layers separately + quantized_model_reference = copy.deepcopy(model) + EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + ).quantize(quantized_model_reference) + quantize_( + quantized_model_reference, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="universal" + ), + ), + filter_fn=lambda m, fqn: fqn == "2", + ) + + # Do shared embedding quantization + quantized_model = copy.deepcopy(model) + SharedEmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + ).quantize(quantized_model) + + # Check results are same and weights share the same id + with torch.no_grad(): + result = quantized_model(indices) + expected_result = quantized_model_reference(indices) + self.assertTrue(torch.allclose(result, expected_result)) + self.assertTrue( + id(quantized_model[0].unembedding_packed_weights) + == id(quantized_model[2].packed_weight) + ) + + # Test export + exported_program = torch.export.export(quantized_model, (indices,)) + exported_result = exported_program.module()(indices) + self.assertTrue(torch.allclose(result, exported_result)) + + # Check the shared_embedding and linear ops use the same lifted weight + weight = "b_getattr_l__fn_____0___unembedding_packed_weights" + expected_lines = [ + f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)", + f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)", + ] + for line in expected_lines: + FileCheck().check_count(line, 1, exactly=True).run( + exported_program.graph_module.code + ) if __name__ == "__main__":
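
A minimal usage sketch of the two quantizer entry points introduced in this patch, EmbeddingQuantizer and SharedEmbeddingQuantizer. The imports and constructor arguments follow the diff above; the toy Sequential model, its dimensions, and the group size are illustrative assumptions, not part of the patch. The non-fallback paths assume the experimental torchao kernels are available (built with `USE_CPP=1 pip install .` on an ARM CPU), as the new _check_torchao_ops_loaded helper enforces.

    import copy

    import torch
    import torch.nn as nn

    from torchao.experimental.quant_api import (
        EmbeddingQuantizer,
        SharedEmbeddingQuantizer,
    )
    from torchao.quantization.granularity import PerGroup, PerRow

    # Hypothetical toy model: the embedding (fqn "0") ties its weight to the
    # unembedding linear layer (fqn "2"), mirroring the setup in the new test.
    embedding_dim, num_embeddings = 256, 131
    embedding = nn.Embedding(num_embeddings, embedding_dim)
    unembedding = nn.Linear(embedding_dim, num_embeddings)
    unembedding.weight = copy.deepcopy(embedding.weight)
    model = nn.Sequential(
        embedding,
        nn.Linear(embedding_dim, embedding_dim),
        unembedding,
    )

    # Quantize embedding tables on their own: 4-bit weights with scales/zeros
    # per group of 32 along each row. use_fallback=True would instead exercise
    # the slow pure-PyTorch reference path, which needs no custom kernels.
    embedding_only = copy.deepcopy(model)
    EmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        use_fallback=False,
    ).quantize(embedding_only)

    # Quantize tied embedding/unembedding pairs once so both ops read the same
    # packed weight buffer. Shared pairs are detected automatically by comparing
    # weights; pass embedding_to_unembedding={"0": "2"} to be explicit.
    SharedEmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerRow(),
        has_weight_zeros=True,
    ).quantize(model)

    indices = torch.randint(0, num_embeddings, (8,), dtype=torch.int32)
    with torch.no_grad():
        out = model(indices)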