diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index f5c2b0bf63..b4912523bb 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -64,6 +64,12 @@
     _linear_int8_act_int8_weight_check,
     _linear_int8_act_int8_weight_impl,
 )
+from torchao.dtypes.uintx.q_dq_layout import (
+    _embedding_check as _embedding_q_dq_check,
+)
+from torchao.dtypes.uintx.q_dq_layout import (
+    _embedding_impl as _embedding_q_dq_impl,
+)
 from torchao.dtypes.uintx.q_dq_layout import (
     _linear_check as _linear_q_dq_check,
 )
@@ -263,6 +269,9 @@ def _(func, types, args, kwargs):
 
 @implements(torch.nn.functional.embedding)
 def _(func, types, args, kwargs):
+    if _embedding_q_dq_check(args, kwargs):
+        return _embedding_q_dq_impl(args, kwargs)
+
     # new_arg1 = args[1].dequantize()
     # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
     assert isinstance(
diff --git a/torchao/dtypes/uintx/q_dq_layout.py b/torchao/dtypes/uintx/q_dq_layout.py
index d0a58c2e18..1d5b2048b0 100644
--- a/torchao/dtypes/uintx/q_dq_layout.py
+++ b/torchao/dtypes/uintx/q_dq_layout.py
@@ -50,3 +50,16 @@ def _linear_impl(input_tensor, weight_tensor, bias):
     if isinstance(weight_tensor, AffineQuantizedTensor):
         weight_tensor = weight_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+def _embedding_check(args, kwargs):
+    _, weight_tensor = args
+    layout = weight_tensor.tensor_impl.get_layout()
+    return isinstance(layout, QDQLayout)
+
+
+def _embedding_impl(args, kwargs):
+    input_tensor, weight_tensor = args
+    if isinstance(weight_tensor, AffineQuantizedTensor):
+        weight_tensor = weight_tensor.dequantize()
+    return torch.nn.functional.embedding(input_tensor, weight_tensor, **kwargs)
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 8091042738..e45a8d2bef 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -15,7 +15,7 @@
     quantize_per_channel_group,
 )
 
-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 logger = logging.getLogger(__name__)
@@ -366,32 +366,44 @@ def __init__(
     ):
         super().__init__()
         self.bit_width = bit_width
-        self.pack_weights_op = getattr(
-            torch.ops.torchao, f"_pack_embedding_{bit_width}bit"
-        )
-        self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit")
 
     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
         assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding"
         num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size
 
-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
+        embedding.weight = weights
+        quantize_(
+            embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        weight_qvals, weight_scales, weight_zeros = (
+            embedding.weight.tensor_impl.get_plain()
         )
+        weight_scales = weight_scales.reshape(num_embeddings, -1)
+        weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8)
         self.register_buffer(
-            "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
+            "packed_weight_qvals",
+            getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")(
+                weight_qvals.to(torch.int8)
+            ),
         )
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.register_buffer("weight_scales", weight_scales)
-        self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))
+        self.register_buffer("weight_zeros", weight_zeros)
 
     def forward(self, x):
         shape = x.shape
-        return self.embedding_op(
+        return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")(
             self.packed_weight_qvals,
             self.num_embeddings,
             self.embedding_dim,
@@ -410,38 +422,23 @@ def __init__(
         self.bit_width = bit_width
 
     def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
-        assert (
-            has_weight_zeros
-        ), "has_weight_zeros must be True for QuantizedEmbeddingFallback"
-        num_embeddings, embedding_dim = weights.shape
-        if group_size == -1:
-            group_size = embedding_dim
-        self.group_size = group_size
-
-        weight_qvals, weight_scales, weight_zeros = _quantize(
-            weights, self.group_size, self.bit_width, has_weight_zeros=True
+        self.embedding = torch.nn.Embedding(*weights.shape)
+        self.embedding.weight = weights
+        quantize_(
+            self.embedding,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{self.bit_width}"),
+                granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
+                zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                mapping_type=MappingType.ASYMMETRIC,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
         )
-        self.weight_qvals = weight_qvals.to(torch.int32)
-        self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int32)
 
     def forward(self, x):
-        shape = x.shape
-        res = []
-        for i in x:
-            res.append(
-                dequantize_per_channel_group(
-                    w_int8=self.weight_qvals[i, :].reshape(1, -1),
-                    scales=self.weight_scales[i, :].reshape(1, -1),
-                    zero_points=self.weight_zeros[i, :].reshape(1, -1),
-                    quant_min=None,  # TODO: why is this an arg for this function
-                    quant_max=None,  # TODO: why is this an arg for this function
-                    dtype=None,  # TODO: why is this an arg for this function
-                    group_size=self.group_size,
-                    output_dtype=torch.float32,
-                ).reshape(-1)
-            )
-        return torch.stack(res).reshape(*shape, -1)
+        return self.embedding(x)
 
 
 class QuantizedSharedEmbedding(nn.Module):
@@ -586,7 +583,7 @@ class EmbeddingQuantizer:
     def __init__(
         self,
         weight_dtype: torch.dtype = torch.int4,
-        granularity: Union[PerRow, PerGroup] = PerRow(),
+        granularity: Granularity = PerAxis(0),
        has_weight_zeros: bool = True,
         use_fallback: bool = False,
     ):
@@ -594,7 +591,8 @@ def __init__(
 
         if isinstance(granularity, PerGroup):
             group_size = granularity.group_size
-        elif isinstance(granularity, PerRow):
+        elif isinstance(granularity, PerAxis):
+            assert granularity.axis == 0
             group_size = -1
         else:
             raise ValueError(f"Unsupported granularity: {granularity}")
@@ -630,6 +628,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
             to_linear_activation_quantized,
         )
         from torchao.quantization.quant_api import (
+            IntxWeightOnlyConfig,
             MappingType,
             ZeroPointDomain,
             to_affine_quantized_intx,
diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py
index 9a744643c8..1b25dc1371 100644
--- a/torchao/experimental/quant_passes.py
+++ b/torchao/experimental/quant_passes.py
@@ -215,3 +215,101 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass(
 
     # Re-export
     return torch.export.export(gm, *ep.example_inputs)
+
+
+def _get_q_dq_embedding_patterns_replacements_and_filters(
+    weight_bit_width,
+):
+    w_quant_min = -(1 << (weight_bit_width - 1))
+    w_quant_max = (1 << (weight_bit_width - 1)) - 1
+    w_target_dtype = torch.int8
+
+    def pattern(
+        indices,
+        w_int_data,
+        w_block_size,
+        w_scale,
+        w_zero_point,
+    ):
+        dq_w = torch.ops.quant.dequantize_affine.default(
+            w_int_data,
+            w_block_size,
+            w_scale,
+            w_zero_point,
+            w_target_dtype,
+            w_quant_min,
+            w_quant_max,
+        )
+        return torch.ops.aten.embedding.default(dq_w, indices)
+
+    def replacement(
+        indices,
+        w_int_data,
+        w_block_size,
+        w_scale,
+        w_zero_point,
+    ):
+        num_embeddings, embedding_dim = w_int_data.size()
+        packed_weight_qvals = getattr(
+            torch.ops.torchao, f"_pack_embedding_{weight_bit_width}bit"
+        )(w_int_data)
+        out_shape = indices.shape + (embedding_dim,)
+        group_size = w_block_size[-1]
+        n_groups = embedding_dim // group_size
+        w_scale = w_scale.reshape(-1, n_groups)
+        w_zero_point = w_zero_point.reshape(-1, n_groups)
+        return getattr(torch.ops.torchao, f"_embedding_{weight_bit_width}bit")(
+            packed_weight_qvals,
+            num_embeddings,
+            embedding_dim,
+            w_scale,
+            w_zero_point,
+            indices.reshape(-1),
+        ).reshape(out_shape)
+
+    def match_filter(match, x, y):
+        def get_val(name):
+            node = [n for n in match.nodes_map if n.name == name][0]
+            return match.nodes_map[node]
+
+        # We only want w_block_size with shape [1, group_size]
+        w_block_size = get_val("w_block_size")
+        if len(w_block_size) != 2 or w_block_size[0] != 1:
+            return False
+
+        return True
+
+    return pattern, replacement, match_filter
+
+
+def replace_q_dq_patterns_with_quantized_embedding_ops_pass(
+    ep: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+    """
+    This replaces Q/DQ patterns with torchao quantized embedding ops.
+    It is intended for converting Q/DQ nodes exported with QDQLayout to
+    the lowbit quantized embedding ops.
+    """
+    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
+    # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/
+    assert (
+        len(ep.range_constraints) == 0
+    ), "ExportedProgram with range constraints are not supported"
+
+    # ep.module() unlifts the weight inputs, which we need for constant folding
+    gm = ep.module()
+    for weight_bit_width in range(1, 9):
+        pattern, replacement, match_filter = (
+            _get_q_dq_embedding_patterns_replacements_and_filters(
+                weight_bit_width,
+            )
+        )
+        subgraph_rewriter.replace_pattern_with_filters(
+            gm, pattern, replacement, match_filters=[match_filter]
+        )
+
+    # Constant fold evaluates and removes the packing ops
+    constant_fold(gm)
+
+    # Re-export
+    return torch.export.export(gm, *ep.example_inputs)
diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
index 844c96760f..8f4afcda04 100644
--- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py
+++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -19,7 +19,7 @@
     Int8DynamicActivationIntxWeightConfig,
     SharedEmbeddingQuantizer,
 )
-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import PerAxis, PerGroup, PerRow
 from torchao.quantization.quant_api import quantize_
 
 
@@ -68,7 +68,7 @@ def test_accuracy(self):
 
     def test_export_compile_aoti(self):
         weight_dtype = torch.int4
-        granularity = PerRow()
+        granularity = PerAxis(0)
         embedding_dim = 4096
         num_embeddings = 131
         model = torch.nn.Sequential(
@@ -113,7 +113,6 @@ def test_export_compile_aoti(self):
 
     def test_shared_embedding(self):
         weight_dtype = torch.int4
-        granularity = PerRow()
         has_weight_zeros = True
         embedding_dim = 4096
         num_embeddings = 131
@@ -134,14 +133,14 @@ def test_shared_embedding(self):
         quantized_model_reference = copy.deepcopy(model)
         EmbeddingQuantizer(
             weight_dtype=weight_dtype,
-            granularity=granularity,
+            granularity=PerAxis(0),
             has_weight_zeros=has_weight_zeros,
         ).quantize(quantized_model_reference)
         quantize_(
             quantized_model_reference,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
-                granularity=granularity,
+                granularity=PerRow(),
                 has_weight_zeros=has_weight_zeros,
                 round_weight_scale_to_bf16=False,
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
@@ -155,7 +154,7 @@ def test_shared_embedding(self):
         quantized_model = copy.deepcopy(model)
         SharedEmbeddingQuantizer(
             weight_dtype=weight_dtype,
-            granularity=granularity,
+            granularity=PerRow(),
             has_weight_zeros=has_weight_zeros,
         ).quantize(quantized_model)
 
diff --git a/torchao/experimental/tests/test_quant_passes.py b/torchao/experimental/tests/test_quant_passes.py
index 3262e2bf7b..35282f331f 100644
--- a/torchao/experimental/tests/test_quant_passes.py
+++ b/torchao/experimental/tests/test_quant_passes.py
@@ -7,6 +7,7 @@
 import unittest
 
 import torch
+from parameterized import param, parameterized
 from torch.testing import FileCheck
 
 from torchao.experimental.q_dq_layout import QDQLayout
@@ -14,10 +15,16 @@
     Int8DynamicActivationIntxWeightConfig,
 )
 from torchao.experimental.quant_passes import (
+    replace_q_dq_patterns_with_quantized_embedding_ops_pass,
     replace_q_dq_patterns_with_quantized_linear_ops_pass,
 )
-from torchao.quantization.granularity import PerGroup, PerRow
-from torchao.quantization.quant_api import quantize_
+from torchao.quantization.granularity import PerAxis, PerGroup, PerRow
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    MappingType,
+    ZeroPointDomain,
+    quantize_,
+)
 
 
 class TestQuantPasses(unittest.TestCase):
@@ -77,6 +84,65 @@ def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
         exported_results = exported.module()(activations)
         self.assertTrue(torch.allclose(exported_results, eager_results))
 
+    @parameterized.expand(
+        [
+            param(weight_dtype=weight_dtype, granularity=granularity)
+            for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]
+            for granularity in [PerAxis(0), PerGroup(32)]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_replace_q_dq_patterns_with_quantized_embedding_ops_pass(
+        self, weight_dtype, granularity
+    ):
+        # Calling torch.export many times in a parametrized test causes
+        # torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached error
+        # Setting cache_size_limit to a large number to avoid this error
+        torch._dynamo.config.cache_size_limit = 10000
+
+        mapping_type = MappingType.ASYMMETRIC
+        zero_point_domain = ZeroPointDomain.INT
+
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(5000, 512), torch.nn.Linear(512, 512)]
+        )
+        indices = torch.randint(0, 5000, (4, 5, 17), dtype=torch.int32)
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                zero_point_domain=zero_point_domain,
+                mapping_type=mapping_type,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        eager_results = model(indices)
+
+        exported = torch.export.export(model, (indices,), strict=True)
+        exported = replace_q_dq_patterns_with_quantized_embedding_ops_pass(exported)
+
+        # We should not find pack op because it gets constant folded
+        FileCheck().check_not("torch.ops.torchao._pack_embedding").run(
+            exported.graph_module.code
+        )
+
+        # We should find the quantized embedding op
+        FileCheck().check_count(
+            "torch.ops.torchao._embedding", count=1, exactly=True
+        ).run(exported.graph_module.code)
+
+        # We should not find Q/DQ ops
+        FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run(
+            exported.graph_module.code
+        )
+
+        # Numerics should match
+        exported_results = exported.module()(indices)
+        self.assertTrue(torch.allclose(exported_results, eager_results))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9bbdd3dfbf..f325a587b4 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,6 +36,7 @@
     MarlinQQQLayout,
     MarlinSparseLayout,
     PlainLayout,
+    QDQLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
@@ -75,6 +76,9 @@
     Int8DynActInt4WeightQuantizer,
 )
 from .granularity import (
+    Granularity,
+    PerAxis,
+    PerGroup,
     PerRow,
     PerTensor,
 )
@@ -86,6 +90,7 @@
     intx_quantization_aware_training,
 )
 from .quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
@@ -1569,6 +1574,102 @@ def _uintx_weight_only_transform(
     return module
 
 
+@dataclass
+class IntxWeightOnlyConfig(AOBaseConfig):
+    """
+    Configuration for quantizing weights to torch.intx, with 1 <= x <= 8.
+    Weights are quantized with scales and optionally zeros (controlled by zero_point_domain) in a groupwise or channelwise
+    manner using the number of bits specified by weight_dtype.
+    args:
+        weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8.
+            torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6
+        granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0).
+        zero_point_domain: The zero point domain to use for weight quantization.
+            Must be ZeroPointDomain.INT (if quantized weights have zeros) or ZeroPointDomain.NONE (if quantized weights do not have zeros).
+        mapping_type: The type of mapping to use for the weight quantization.
+            Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
+        scale_dtype: The dtype to use for the weight scale.
+        layout: The layout to use for the packed weight tensor:
+            - QDQLayout: represents the quantization with Q/DQ quant primitives and is intended
+                for export applications like ExecuTorch.
+    """
+
+    weight_dtype: torch.dtype = torch.int8
+    granularity: Granularity = PerAxis(0)
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.NONE
+    mapping_type: MappingType = MappingType.SYMMETRIC
+    scale_dtype: Optional[torch.dtype] = None
+    layout: Layout = QDQLayout()
+
+    def __post_init__(self):
+        assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
+        assert (
+            self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)]
+        ), f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
+        assert isinstance(
+            self.granularity, (PerAxis, PerGroup)
+        ), f"granularity must be PerAxis or PerGroup, but got {self.granularity}"
+        if isinstance(self.granularity, PerAxis):
+            assert (
+                self.granularity.axis == 0
+            ), f"axis must be 0 with PerAxis, but got {self.granularity.axis}"
+        assert (
+            self.zero_point_domain in [ZeroPointDomain.INT, ZeroPointDomain.NONE]
+        ), f"zero_point_domain must be ZeroPointDomain.INT or ZeroPointDomain.NONE, but got {self.zero_point_domain}"
+        assert (
+            self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+        ), f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}"
+        if self.mapping_type == MappingType.SYMMETRIC:
+            assert (
+                self.zero_point_domain == ZeroPointDomain.NONE
+            ), f"zero_point_domain must be ZeroPointDomain.NONE when mapping_type is MappingType.SYMMETRIC, but got {self.zero_point_domain}"
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    zero_point_domain = config.zero_point_domain
+    mapping_type = config.mapping_type
+    scale_dtype = config.scale_dtype
+    layout = config.layout
+
+    assert (
+        weight.dim() == 2
+    ), f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
+    if isinstance(granularity, PerGroup):
+        group_size = granularity.group_size
+    elif isinstance(granularity, PerAxis):
+        assert (
+            granularity.axis == 0
+        ), f"axis must be 0 with PerAxis, but got {granularity.axis}"
+        group_size = weight.shape[-1]
+    else:
+        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
+
+    quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
+    has_weight_zeros = zero_point_domain == ZeroPointDomain.INT
+    weight = to_affine_quantized_intx(
+        input_float=weight,
+        mapping_type=mapping_type,
+        block_size=(1, group_size),
+        target_dtype=torch.int8,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=torch.finfo(torch.float32).eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=torch.int8 if has_weight_zeros else None,
+        preserve_zero=has_weight_zeros or (mapping_type == MappingType.SYMMETRIC),
+        zero_point_domain=zero_point_domain,
+        _layout=layout,
+    )
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
+
+
 @dataclass
 class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
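
Example usage (not part of the patch): a minimal sketch of the new embedding flow, adapted from the test added in torchao/experimental/tests/test_quant_passes.py above. The 4-bit dtype, group size of 32, and model shapes are illustrative assumptions; the imports and ops come from this diff.

import torch

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_passes import (
    replace_q_dq_patterns_with_quantized_embedding_ops_pass,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    MappingType,
    ZeroPointDomain,
    quantize_,
)

# Quantize only the embedding table, using the QDQLayout so the weight is
# represented with Q/DQ nodes that the export pass can match.
model = torch.nn.Sequential(torch.nn.Embedding(5000, 512), torch.nn.Linear(512, 512))
quantize_(
    model,
    IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        zero_point_domain=ZeroPointDomain.INT,
        mapping_type=MappingType.ASYMMETRIC,
        layout=QDQLayout(),
    ),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

indices = torch.randint(0, 5000, (4, 17), dtype=torch.int32)
eager_out = model(indices)

# Export, then rewrite the dequantize_affine + aten.embedding pattern into the
# packed torchao._embedding_4bit op; the packing op is constant folded away.
exported = torch.export.export(model, (indices,), strict=True)
exported = replace_q_dq_patterns_with_quantized_embedding_ops_pass(exported)
assert torch.allclose(exported.module()(indices), eager_out)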