From 042816482d172804ff35ec1191a4e19b43c1604c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 22 May 2025 20:04:06 -0700 Subject: [PATCH 01/12] Add support for fbgemm int4 mm kernel Summary: we also plan to expose some other kernels like fp8xint4 and bf16xfp8, fp8xfp8 to compare with existing torchao kernels Test Plan: test/dtypes/test_fbgemm_int4_tensor.py Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_fbgemm_int4_tensor.py | 27 ++++ torchao/_models/llama/generate.py | 11 ++ torchao/dtypes/__init__.py | 2 + torchao/dtypes/fbgemm_int4_tensor.py | 164 +++++++++++++++++++++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 29 +++++ 6 files changed, 235 insertions(+) create mode 100644 test/dtypes/test_fbgemm_int4_tensor.py create mode 100644 torchao/dtypes/fbgemm_int4_tensor.py diff --git a/test/dtypes/test_fbgemm_int4_tensor.py b/test/dtypes/test_fbgemm_int4_tensor.py new file mode 100644 index 0000000000..2b62fd76b5 --- /dev/null +++ b/test/dtypes/test_fbgemm_int4_tensor.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.quantization import ( + FbgemmConfig, + quantize_, +) + + +class TestFbgemmInt4Tensor(TestCase): + def test_linear(self): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + config = FbgemmConfig(io_dtype="bf16i4bf16", is_grouped_mm=False) + quantize_(linear, config) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index dc03204b46..4f7c985912 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -439,6 +439,17 @@ def ffn_or_attn_only(mod, fqn): f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" ) quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) + elif "fbgemm" in quantization: + from torchao.quantization import FbgemmConfig + + _, precision, group_size = quantization.split("-") + group_size = int(group_size) + if precision == "int4": + quantize_(model, FbgemmConfig("bf16i4bf16", group_size)) + else: + raise NotImplementedError( + f"FbegemmConfig({precision=}) not supported yet" + ) elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index eb253c11bc..9a9b0d8fcf 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -8,6 +8,7 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) +from .fbgemm_int4_tensor import to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, Float8Layout, @@ -61,4 +62,5 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", + "to_fbgemm_int4", ] diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py new file mode 100644 index 0000000000..658d1a5908 --- /dev/null +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
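Before reading the new tensor subclass below, it helps to see the asymmetric group-wise 4-bit scheme that `int4_row_quantize` (defined just after the imports) implements. The following is a reference-only sketch of the same math on a single small group; `int4_group_quant_reference` is a name introduced here for illustration and is not part of the patch:

```python
import torch

def int4_group_quant_reference(x: torch.Tensor, group_size: int):
    # Same math as int4_row_quantize below: per-group asymmetric 4-bit,
    # recentered by 2**(n_bit - 1) = 8 so the values fit in signed int8.
    g = x.reshape(-1, group_size).to(torch.float)
    max_val = g.amax(dim=1, keepdim=True)
    min_val = g.amin(dim=1, keepdim=True)
    scale = (max_val - min_val).clamp(min=1e-6) / 15   # 15 = 2**4 - 1 steps
    zero = min_val + scale * 8                         # zero point in the original scale
    q = g.sub(min_val).div(scale).round().clamp_(0, 15) - 8  # signed range [-8, 7]
    return q.to(torch.int8).reshape(x.shape), scale, zero

x = torch.tensor([[-1.0, 0.5, 2.0, 0.0]])
q, scale, zero = int4_group_quant_reference(x, group_size=4)
# For this group, scale is 3/15 = 0.2 and zero is 0.6 (up to float rounding).
# Dequantization is q * scale + zero: the group min and max are recovered
# exactly, everything else to within one quantization step.
x_hat = q.to(torch.float) * scale + zero
```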
+ +from typing import Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "to_fbgemm_int4", +] + +aten = torch.ops.aten + + +def int4_row_quantize( + x: torch.Tensor, + group_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + n_bit = 4 # Number of target bits. + to_quant = x.reshape(-1, group_size).to(torch.float) + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + + zeros = min_val + scales * (2 ** (n_bit - 1)) + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + + # Recenter output and move to int8. + out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape) + + # Cutlass expects column major layout for scale and zero point, + # so we transpose here and make them contiguous. + scales = scales.view(x.shape[0], -1).t().contiguous() + zeros = zeros.view(x.shape[0], -1).t().contiguous() + + return out, scales, zeros + + +def pack_int4(x: torch.Tensor) -> torch.Tensor: + # Given int8 x, pack adjacent int4 values into a single int8. + low_x = x[:, ::2] + high_x = x[:, 1::2] + + # High bits need to left shift, this also masks off extra bits. + high_x = torch.bitwise_left_shift(high_x, 4) + # Low bits need to have sign bits removed. + low_x = torch.bitwise_and(low_x, 0xF) + + # Recombine into a single value with bitwise or. + return torch.bitwise_or(low_x, high_x).contiguous() + + +class FbgemmInt4Tensor(TorchAOBaseTensor): + tensor_data_attrs = ["packed_weight", "scale", "zero_point"] + tensor_attributes = ["group_size"] + + def __new__(cls, packed_weight, scale, zero_point, group_size): + shape = packed_weight.shape + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, packed_weight, scale, zero_point, group_size): + self.packed_weight = packed_weight + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + + def __tensor_flatten__(self): + return self.tensor_data_attrs, [ + getattr(self, attr) for attr in self.tensor_attributes + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_data_attrs], + *tensor_attributes, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], + *[getattr(self, attr) for attr in self.tensor_attributes], + ) + + def __repr__(self): + raise NotImplementedError("Subclasses must implement __repr__") + + @classmethod + def from_float(cls, w: torch.Tensor, group_size: int = 128): + if w.ndim >= 3: + wq, scale, zero_point = zip( + *[int4_row_quantize(i, group_size) for i in w], strict=False + ) + wq = torch.stack([pack_int4(i) for i in wq], dim=0) + scale = torch.stack(scale, dim=0) + zero_point = torch.stack(zero_point, dim=0) + else: + wq, scale, zero_point = int4_row_quantize(w, group_size) + wq = pack_int4(wq) + del w + return FbgemmInt4Tensor( + packed_weight=wq, + scale=scale, + zero_point=zero_point, + group_size=group_size, + ) + + +implements = FbgemmInt4Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + 
input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( + input_tensor, + weight_tensor.packed_weight, + weight_tensor.scale, + weight_tensor.zero_point, + ) + if bias is not None: + res = res + bias + return res + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements([aten.clone.default, aten.copy_.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +to_fbgemm_int4 = FbgemmInt4Tensor.from_float diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b4d46d8263..73ccd2e0ff 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -40,6 +40,7 @@ ) from .quant_api import ( CutlassInt4PackedLayout, + FbgemmConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, @@ -148,6 +149,7 @@ "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", "ModuleFqnToConfig", + "FbgemmConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f2aca97782..c105965b23 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -45,6 +45,7 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, + to_fbgemm_int4, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -142,6 +143,7 @@ "Int8DynActInt4WeightGPTQQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", + "FbgemmConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -1967,6 +1969,33 @@ def _fpx_weight_only_transform( return module +@dataclass +class FbgemmConfig(AOBaseConfig): + io_dtype: str + group_size: int = 128 + is_grouped_mm: bool = False + + +@register_quantize_module_handler(FbgemmConfig) +def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: + if config.io_dtype == "bf16i4bf16" and not config.is_grouped_mm: + try: + import fbgemm_gpu.experimental.gen_ai # noqa: F401 + + logger.info("Using efficient FP8 or INT4 operators in fbgemm-gpu.") + except ImportError: + logger.error( + "No efficient FP8 or INT4 operators. Please install fbgemm-gpu." 
+ ) + raise + + weight = to_fbgemm_int4(module.weight, config.group_size) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + else: + raise NotImplementedError(f"{config} not supported yet") + + @dataclass class ModuleFqnToConfig(AOBaseConfig): """Per module configurations for torchao quantize_ API From 8b59bbae17efc6fcd9327442105f9682cbc42aeb Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 23 May 2025 16:30:59 -0700 Subject: [PATCH 02/12] fix and test --- test/dtypes/test_fbgemm_int4_tensor.py | 14 +++++++++++++- torchao/dtypes/fbgemm_int4_tensor.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_fbgemm_int4_tensor.py b/test/dtypes/test_fbgemm_int4_tensor.py index 2b62fd76b5..9e0bd4fe86 100644 --- a/test/dtypes/test_fbgemm_int4_tensor.py +++ b/test/dtypes/test_fbgemm_int4_tensor.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import unittest + import torch from torch.testing._internal.common_utils import ( TestCase, @@ -14,13 +16,23 @@ FbgemmConfig, quantize_, ) +from torchao.quantization.utils import compute_error +from torchao.utils import is_sm_at_least_90 class TestFbgemmInt4Tensor(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") def test_linear(self): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) config = FbgemmConfig(io_dtype="bf16i4bf16", is_grouped_mm=False) quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) if __name__ == "__main__": diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py index 658d1a5908..db71701bde 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -136,7 +136,7 @@ def _(func, types, args, kwargs): f"{func} is not implemented for non floating point input" ) - res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( + res = torch.ops.fbgemm.bf16i4bf16_rowwise( input_tensor, weight_tensor.packed_weight, weight_tensor.scale, From 2e63833e72fb3c004a1f4ec4b191f1a72ca98fef Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 24 May 2025 15:44:23 -0700 Subject: [PATCH 03/12] fix dtype --- torchao/_models/llama/generate.py | 2 +- torchao/dtypes/fbgemm_int4_tensor.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 4f7c985912..c17de52028 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -1174,7 +1174,7 @@ def callback(x): help=( "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx--, uintx---hqq, sparse-marlin, spinquant, " - + "embed-int8wo, marlin_qqq, gemlite---, float8dq, int4dq-" + + "embed-int8wo, marlin_qqq, gemlite---, float8dq, int4dq-, fbgemm-int4-" ), ) parser.add_argument( diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py index db71701bde..bfc5c777f5 100644 --- 
a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -43,7 +43,7 @@ def int4_row_quantize( scales = scales.view(x.shape[0], -1).t().contiguous() zeros = zeros.view(x.shape[0], -1).t().contiguous() - return out, scales, zeros + return out, scales.to(x.dtype), zeros.to(x.dtype) def pack_int4(x: torch.Tensor) -> torch.Tensor: @@ -68,6 +68,7 @@ def __new__(cls, packed_weight, scale, zero_point, group_size): shape = packed_weight.shape kwargs = {} kwargs["device"] = packed_weight.device + kwargs["dtype"] = scale.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] @@ -98,7 +99,10 @@ def _apply_fn_to_data(self, fn): ) def __repr__(self): - raise NotImplementedError("Subclasses must implement __repr__") + return ( + f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " + f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) @classmethod def from_float(cls, w: torch.Tensor, group_size: int = 128): @@ -136,6 +140,9 @@ def _(func, types, args, kwargs): f"{func} is not implemented for non floating point input" ) + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + res = torch.ops.fbgemm.bf16i4bf16_rowwise( input_tensor, weight_tensor.packed_weight, @@ -144,7 +151,7 @@ def _(func, types, args, kwargs): ) if bias is not None: res = res + bias - return res + return res.reshape(*orig_act_size[:-1], orig_out_features) @implements([aten.detach.default, aten.alias.default]) From bdf19aa5b6715c6545edbf0a66b5b19b1e5db6d9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 16:37:54 -0700 Subject: [PATCH 04/12] use importlib --- test/dtypes/test_fbgemm_int4_tensor.py | 2 +- torchao/quantization/quant_api.py | 30 ++++++++++++++++---------- torchao/utils.py | 10 +++++++++ 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/test/dtypes/test_fbgemm_int4_tensor.py b/test/dtypes/test_fbgemm_int4_tensor.py index 9e0bd4fe86..ae0f1a36f0 100644 --- a/test/dtypes/test_fbgemm_int4_tensor.py +++ b/test/dtypes/test_fbgemm_int4_tensor.py @@ -29,7 +29,7 @@ def test_linear(self): input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) - config = FbgemmConfig(io_dtype="bf16i4bf16", is_grouped_mm=False) + config = FbgemmConfig(io_dtype="bf16i4bf16") quantize_(linear, config) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c105965b23..f297ee90cd 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -15,6 +15,7 @@ and mixed GEMM kernels """ +import importlib.util import logging import types import warnings @@ -1971,24 +1972,31 @@ def _fpx_weight_only_transform( @dataclass class FbgemmConfig(AOBaseConfig): + """Quantization Config for fbgemm-genai kernels + Args: + io_dtype (str): The input output dtype of the input, weight and output of the kernel, + for example: bf16i4bf16 means input is bf16, weight is int4 and output is bf16. 
+ Currently available options are ["bf16i4bf16"] + group_size (int): The group size for weight + """ + io_dtype: str group_size: int = 128 - is_grouped_mm: bool = False @register_quantize_module_handler(FbgemmConfig) def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: - if config.io_dtype == "bf16i4bf16" and not config.is_grouped_mm: - try: - import fbgemm_gpu.experimental.gen_ai # noqa: F401 - - logger.info("Using efficient FP8 or INT4 operators in fbgemm-gpu.") - except ImportError: - logger.error( - "No efficient FP8 or INT4 operators. Please install fbgemm-gpu." - ) - raise + # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when + # https://github.com/pytorch/FBGEMM/issues/4198 is fixed + if importlib.util.find_spec("fbgemm_gpu") is None: + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + + import fbgemm_gpu.experimental.gen_ai # noqa: F401 + + if fbgemm_gpu.__version__ < "1.2.0": + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + if config.io_dtype == "bf16i4bf16": weight = to_fbgemm_int4(module.weight, config.group_size) module.weight = torch.nn.Parameter(weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) diff --git a/torchao/utils.py b/torchao/utils.py index 280da4e632..1fa395cb8a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import functools +import importlib import itertools import re import time @@ -40,6 +41,7 @@ "is_MI300", "is_sm_at_least_89", "is_sm_at_least_90", + "is_package_at_least", ] @@ -694,3 +696,11 @@ def check_xpu_version(device, version="2.8.0"): TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") + + +def is_package_at_least(package_name: str, min_version: str): + package_exists = importlib.util.find_spec(package_name) is not None + if not package_exists: + return False + + return version(package_name) >= min_version From 82b1958ac37d8a4a79a1b2d55a2dc9d387dcbc25 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 16:51:45 -0700 Subject: [PATCH 05/12] add links to fbgemm code --- torchao/dtypes/fbgemm_int4_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py index bfc5c777f5..d7eb2e4d04 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -18,6 +18,7 @@ aten = torch.ops.aten +# copied from https://github.com/pytorch/FBGEMM/blob/2bf4d9aa739b3e78362ca801a72dacb16c67346f/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L60 def int4_row_quantize( x: torch.Tensor, group_size: int = 128, @@ -46,6 +47,7 @@ def int4_row_quantize( return out, scales.to(x.dtype), zeros.to(x.dtype) +# copied from https://github.com/pytorch/FBGEMM/blob/2bf4d9aa739b3e78362ca801a72dacb16c67346f/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L18 def pack_int4(x: torch.Tensor) -> torch.Tensor: # Given int8 x, pack adjacent int4 values into a single int8. 
low_x = x[:, ::2] From e69b30266ef2a8bcb726ac23e3876eb804fb1509 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 17:55:54 -0700 Subject: [PATCH 06/12] update io_dtype type --- test/dtypes/test_fbgemm_int4_tensor.py | 3 +- torchao/dtypes/fbgemm_int4_tensor.py | 58 ++++++-------------------- torchao/quantization/quant_api.py | 11 ++++- 3 files changed, 24 insertions(+), 48 deletions(-) diff --git a/test/dtypes/test_fbgemm_int4_tensor.py b/test/dtypes/test_fbgemm_int4_tensor.py index ae0f1a36f0..51d996285f 100644 --- a/test/dtypes/test_fbgemm_int4_tensor.py +++ b/test/dtypes/test_fbgemm_int4_tensor.py @@ -29,7 +29,8 @@ def test_linear(self): input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) - config = FbgemmConfig(io_dtype="bf16i4bf16") + # config = FbgemmConfig(io_dtype="bf16i4bf16") + config = FbgemmConfig(io_dtype="bf16i8bf16") quantize_(linear, config) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py index d7eb2e4d04..a3f6be62ff 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple + +import importlib.util import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -18,48 +19,11 @@ aten = torch.ops.aten -# copied from https://github.com/pytorch/FBGEMM/blob/2bf4d9aa739b3e78362ca801a72dacb16c67346f/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L60 -def int4_row_quantize( - x: torch.Tensor, - group_size: int = 128, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - n_bit = 4 # Number of target bits. - to_quant = x.reshape(-1, group_size).to(torch.float) - - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2**n_bit - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-6) / max_int - - zeros = min_val + scales * (2 ** (n_bit - 1)) - - out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) - - # Recenter output and move to int8. - out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape) - - # Cutlass expects column major layout for scale and zero point, - # so we transpose here and make them contiguous. - scales = scales.view(x.shape[0], -1).t().contiguous() - zeros = zeros.view(x.shape[0], -1).t().contiguous() - - return out, scales.to(x.dtype), zeros.to(x.dtype) - - -# copied from https://github.com/pytorch/FBGEMM/blob/2bf4d9aa739b3e78362ca801a72dacb16c67346f/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L18 -def pack_int4(x: torch.Tensor) -> torch.Tensor: - # Given int8 x, pack adjacent int4 values into a single int8. - low_x = x[:, ::2] - high_x = x[:, 1::2] - - # High bits need to left shift, this also masks off extra bits. - high_x = torch.bitwise_left_shift(high_x, 4) - # Low bits need to have sign bits removed. - low_x = torch.bitwise_and(low_x, 0xF) - - # Recombine into a single value with bitwise or. 
- return torch.bitwise_or(low_x, high_x).contiguous() +if importlib.util.find_spec("fbgemm_gpu") is None: + int4_row_quantize_zp = None + pack_int4 = None +else: + from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4 class FbgemmInt4Tensor(TorchAOBaseTensor): @@ -110,14 +74,18 @@ def __repr__(self): def from_float(cls, w: torch.Tensor, group_size: int = 128): if w.ndim >= 3: wq, scale, zero_point = zip( - *[int4_row_quantize(i, group_size) for i in w], strict=False + *[int4_row_quantize_zp(i, group_size) for i in w], strict=False ) wq = torch.stack([pack_int4(i) for i in wq], dim=0) scale = torch.stack(scale, dim=0) zero_point = torch.stack(zero_point, dim=0) else: - wq, scale, zero_point = int4_row_quantize(w, group_size) + wq, scale, zero_point = int4_row_quantize_zp(w, group_size) wq = pack_int4(wq) + + scale = scale.to(w.dtype) + zero_point = zero_point.to(w.dtype) + del w return FbgemmInt4Tensor( packed_weight=wq, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f297ee90cd..cb887fd421 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,6 +20,7 @@ import types import warnings from dataclasses import dataclass, field +from enum import Enum from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -1970,6 +1971,10 @@ def _fpx_weight_only_transform( return module +class FbgemmKernelIODtype(str, Enum): + bf16i4bf16 = "bf16i4bf16" + + @dataclass class FbgemmConfig(AOBaseConfig): """Quantization Config for fbgemm-genai kernels @@ -1980,7 +1985,7 @@ class FbgemmConfig(AOBaseConfig): group_size (int): The group size for weight """ - io_dtype: str + io_dtype: FbgemmKernelIODtype group_size: int = 128 @@ -2001,7 +2006,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: module.weight = torch.nn.Parameter(weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) else: - raise NotImplementedError(f"{config} not supported yet") + raise NotImplementedError( + f"{config} is not supported. 
supported io_dtypes are: {list(FbgemmKernelIODtype)}" + ) @dataclass From 82217a24eee78a02571f7b0c6af812bb060f3b16 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 18:35:11 -0700 Subject: [PATCH 07/12] renaming --- ...sor.py => test_fbgemm_quantized_tensor.py} | 8 +++-- torchao/dtypes/__init__.py | 4 +-- ...4_tensor.py => fbgemm_quantized_tensor.py} | 26 ++++++++++++-- torchao/quantization/quant_api.py | 34 ++++++++++++++----- 4 files changed, 56 insertions(+), 16 deletions(-) rename test/dtypes/{test_fbgemm_int4_tensor.py => test_fbgemm_quantized_tensor.py} (85%) rename torchao/dtypes/{fbgemm_int4_tensor.py => fbgemm_quantized_tensor.py} (84%) diff --git a/test/dtypes/test_fbgemm_int4_tensor.py b/test/dtypes/test_fbgemm_quantized_tensor.py similarity index 85% rename from test/dtypes/test_fbgemm_int4_tensor.py rename to test/dtypes/test_fbgemm_quantized_tensor.py index 51d996285f..fe2573530c 100644 --- a/test/dtypes/test_fbgemm_int4_tensor.py +++ b/test/dtypes/test_fbgemm_quantized_tensor.py @@ -29,8 +29,12 @@ def test_linear(self): input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) - # config = FbgemmConfig(io_dtype="bf16i4bf16") - config = FbgemmConfig(io_dtype="bf16i8bf16") + config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=(1, 128), + ) quantize_(linear, config) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 9a9b0d8fcf..1003491828 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -8,7 +8,7 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from .fbgemm_int4_tensor import to_fbgemm_int4 +from .fbgemm_quantized_tensor import to_fbgemm_quantized from .floatx import ( CutlassSemiSparseLayout, Float8Layout, @@ -62,5 +62,5 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", - "to_fbgemm_int4", + "to_fbgemm_quantized", ] diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_quantized_tensor.py similarity index 84% rename from torchao/dtypes/fbgemm_int4_tensor.py rename to torchao/dtypes/fbgemm_quantized_tensor.py index a3f6be62ff..20fbd0058d 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/dtypes/fbgemm_quantized_tensor.py @@ -6,6 +6,7 @@ import importlib.util +from typing import Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -13,7 +14,7 @@ from torchao.utils import TorchAOBaseTensor __all__ = [ - "to_fbgemm_int4", + "to_fbgemm_quantized", ] aten = torch.ops.aten @@ -71,7 +72,25 @@ def __repr__(self): ) @classmethod - def from_float(cls, w: torch.Tensor, group_size: int = 128): + def from_float( + cls, + w: torch.Tensor, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + output_dtype: torch.dtype, + block_size: Tuple[int], + ): + assert len(block_size) == w.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" + ) + group_size = block_size[-1] + + assert (input_dtype, weight_dtype, output_dtype) == ( + torch.bfloat16, + torch.int4, + torch.bfloat16, + ) + if w.ndim >= 3: wq, scale, zero_point = zip( *[int4_row_quantize_zp(i, group_size) for i in w], strict=False @@ -138,4 +157,5 @@ def _(func, types, args, kwargs): 
) -to_fbgemm_int4 = FbgemmInt4Tensor.from_float +# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later +to_fbgemm_quantized = FbgemmInt4Tensor.from_float diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index cb887fd421..20d87e9ff9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -47,7 +47,7 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, - to_fbgemm_int4, + to_fbgemm_quantized, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -1979,14 +1979,16 @@ class FbgemmKernelIODtype(str, Enum): class FbgemmConfig(AOBaseConfig): """Quantization Config for fbgemm-genai kernels Args: - io_dtype (str): The input output dtype of the input, weight and output of the kernel, - for example: bf16i4bf16 means input is bf16, weight is int4 and output is bf16. - Currently available options are ["bf16i4bf16"] + input_dtype (torch.dtype): input dtype of the kernel + weight_dtype (torch.dtype): weight dtype of the kernel + output_dtype (torch.dtype): output dtype of the kernel group_size (int): The group size for weight """ - io_dtype: FbgemmKernelIODtype - group_size: int = 128 + input_dtype: torch.dtype + weight_dtype: torch.dtype + output_dtype: torch.dtype + block_size: Tuple[int] @register_quantize_module_handler(FbgemmConfig) @@ -2001,13 +2003,27 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: if fbgemm_gpu.__version__ < "1.2.0": raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - if config.io_dtype == "bf16i4bf16": - weight = to_fbgemm_int4(module.weight, config.group_size) + _SUPPORTED_DTYPES = { + (torch.bfloat16, torch.int4, torch.bfloat16), + } + + if ( + config.input_dtype, + config.weight_dtype, + config.output_dtype, + ) in _SUPPORTED_DTYPES: + weight = to_fbgemm_quantized( + module.weight, + config.input_dtype, + config.weight_dtype, + config.output_dtype, + config.block_size, + ) module.weight = torch.nn.Parameter(weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) else: raise NotImplementedError( - f"{config} is not supported. supported io_dtypes are: {list(FbgemmKernelIODtype)}" + f"{config} is not supported. 
supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}" ) From f7dadd1d949fdbe4feee4e188c6dca0f14d281d1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 18:36:09 -0700 Subject: [PATCH 08/12] remove enum --- torchao/quantization/quant_api.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 20d87e9ff9..8e1b0e2fad 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,7 +20,6 @@ import types import warnings from dataclasses import dataclass, field -from enum import Enum from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -1971,10 +1970,6 @@ def _fpx_weight_only_transform( return module -class FbgemmKernelIODtype(str, Enum): - bf16i4bf16 = "bf16i4bf16" - - @dataclass class FbgemmConfig(AOBaseConfig): """Quantization Config for fbgemm-genai kernels From d9fdf72672ca35822d6f42087dfdeda2eb30511c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 19:01:47 -0700 Subject: [PATCH 09/12] serializability update --- test/dtypes/test_fbgemm_quantized.py | 44 +++++++++++++++++++ test/dtypes/test_fbgemm_quantized_tensor.py | 2 +- .../quantization/test_config_serialization.py | 3 ++ torchao/core/config.py | 16 ++++++- torchao/dtypes/fbgemm_quantized_tensor.py | 4 +- torchao/quantization/quant_api.py | 6 +-- 6 files changed, 67 insertions(+), 8 deletions(-) create mode 100644 test/dtypes/test_fbgemm_quantized.py diff --git a/test/dtypes/test_fbgemm_quantized.py b/test/dtypes/test_fbgemm_quantized.py new file mode 100644 index 0000000000..fe2573530c --- /dev/null +++ b/test/dtypes/test_fbgemm_quantized.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
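The "serializability update" in this commit is easiest to see as a round trip: JSON has no tuple type, so a tuple-valued field would come back as a list and no longer compare equal, which is why `block_size` becomes a list and the encoder now rejects tuples outright. A minimal sketch of that round trip, assuming fbgemm-gpu-genai is installed and the helpers are imported from `torchao.core.config` as in the serialization test:

```python
import torch
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import FbgemmConfig

# Lists survive the JSON round trip unchanged; a tuple here would be
# rejected by the encoder check added in this commit.
config = FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 128])
reloaded = config_from_dict(config_to_dict(config))
assert isinstance(reloaded, FbgemmConfig)
assert reloaded.block_size == [1, 128]
```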
+ +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.quantization import ( + FbgemmConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import is_sm_at_least_90 + + +class TestFbgemmInt4Tensor(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_linear(self): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=(1, 128), + ) + quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dtypes/test_fbgemm_quantized_tensor.py b/test/dtypes/test_fbgemm_quantized_tensor.py index fe2573530c..7e802a02a7 100644 --- a/test/dtypes/test_fbgemm_quantized_tensor.py +++ b/test/dtypes/test_fbgemm_quantized_tensor.py @@ -33,7 +33,7 @@ def test_linear(self): input_dtype=torch.bfloat16, weight_dtype=torch.int4, output_dtype=torch.bfloat16, - block_size=(1, 128), + block_size=[1, 128], ) quantize_(linear, config) quantized = linear(input) diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py index 3b0a10e915..0cf80d376e 100644 --- a/test/quantization/test_config_serialization.py +++ b/test/quantization/test_config_serialization.py @@ -20,6 +20,7 @@ config_to_dict, ) from torchao.quantization.quant_api import ( + FbgemmConfig, Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, @@ -39,6 +40,7 @@ configs = [ Float8DynamicActivationFloat8WeightConfig(), Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]), Float8WeightOnlyConfig( weight_dtype=torch.float8_e4m3fn, ), @@ -76,6 +78,7 @@ "linear2": Int8DynamicActivationInt4WeightConfig(), } ), + FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]), ] diff --git a/torchao/core/config.py b/torchao/core/config.py index d2d49981c9..7198c30a0d 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -130,7 +130,13 @@ def default(self, o): # For lists and dictionaries, recursively process their items if isinstance(o, list): - return [self.encode_value(item) for item in o] + return type(o)(self.encode_value(item) for item in o) + + if isinstance(o, tuple): + raise NotImplementedError( + "Tuples will be serialized as List in JSON, so we recommend to use " + "Lists instead to avoid surprises" + ) if isinstance(o, dict): return {k: self.encode_value(v) for k, v in o.items()} @@ -245,18 +251,24 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: # Process nested structures for dictionary obj_data processed_data = {} + print("obj data:", obj_data) for key, value in obj_data.items(): if isinstance(value, dict) and "_type" in value and "_data" in value: # Recursively handle nested configs processed_data[key] = config_from_dict(value) elif isinstance(value, list): - # Handle lists of possible configs + # Handle lists or tuples of possible configs processed_data[key] = [ config_from_dict(item) if isinstance(item, dict) and "_type" in item and "_data" in item else item 
for item in value ] + elif isinstance(value, tuple): + raise NotImplementedError( + "Tuples will be serialized as List in JSON, so we recommend to use " + "Lists instead to avoid surprises" + ) elif isinstance(value, dict): # Handle dicts of possible configs processed_data[key] = { diff --git a/torchao/dtypes/fbgemm_quantized_tensor.py b/torchao/dtypes/fbgemm_quantized_tensor.py index 20fbd0058d..fd788a73a3 100644 --- a/torchao/dtypes/fbgemm_quantized_tensor.py +++ b/torchao/dtypes/fbgemm_quantized_tensor.py @@ -6,7 +6,7 @@ import importlib.util -from typing import Tuple +from typing import List import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -78,7 +78,7 @@ def from_float( input_dtype: torch.dtype, weight_dtype: torch.dtype, output_dtype: torch.dtype, - block_size: Tuple[int], + block_size: List[int], ): assert len(block_size) == w.ndim, ( f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8e1b0e2fad..f76245c677 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,7 +20,7 @@ import types import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -1529,7 +1529,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[ - Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] + Union[FP8Granularity, List[FP8Granularity]] ] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True @@ -1983,7 +1983,7 @@ class FbgemmConfig(AOBaseConfig): input_dtype: torch.dtype weight_dtype: torch.dtype output_dtype: torch.dtype - block_size: Tuple[int] + block_size: List[int] @register_quantize_module_handler(FbgemmConfig) From 6df31fff141d2b947ccafb188e15eadfc47083d1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 19:08:21 -0700 Subject: [PATCH 10/12] format --- torchao/quantization/quant_api.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f76245c677..2bae49b1fb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1528,9 +1528,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype - granularity: Optional[ - Union[FP8Granularity, List[FP8Granularity]] - ] = None + granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True From b8437cccc9a410b37a1d3548bb1f8ed446b30891 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 21:28:05 -0700 Subject: [PATCH 11/12] fix tests --- test/dtypes/test_fbgemm_quantized_tensor.py | 6 +++++- torchao/core/config.py | 9 ++++----- torchao/quantization/quant_api.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/dtypes/test_fbgemm_quantized_tensor.py b/test/dtypes/test_fbgemm_quantized_tensor.py index 7e802a02a7..51b68dd977 100644 --- a/test/dtypes/test_fbgemm_quantized_tensor.py +++ b/test/dtypes/test_fbgemm_quantized_tensor.py @@ -17,12 +17,16 @@ quantize_, ) from 
torchao.quantization.utils import compute_error -from torchao.utils import is_sm_at_least_90 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_90, +) class TestFbgemmInt4Tensor(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6") def test_linear(self): dtype = torch.bfloat16 device = "cuda" diff --git a/torchao/core/config.py b/torchao/core/config.py index 7198c30a0d..3451b90c59 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -130,12 +130,12 @@ def default(self, o): # For lists and dictionaries, recursively process their items if isinstance(o, list): - return type(o)(self.encode_value(item) for item in o) + return [self.encode_value(item) for item in o] - if isinstance(o, tuple): + elif isinstance(o, tuple): raise NotImplementedError( "Tuples will be serialized as List in JSON, so we recommend to use " - "Lists instead to avoid surprises" + f"Lists instead to avoid surprises. got: {o}" ) if isinstance(o, dict): @@ -251,7 +251,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: # Process nested structures for dictionary obj_data processed_data = {} - print("obj data:", obj_data) for key, value in obj_data.items(): if isinstance(value, dict) and "_type" in value and "_data" in value: # Recursively handle nested configs @@ -267,7 +266,7 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: elif isinstance(value, tuple): raise NotImplementedError( "Tuples will be serialized as List in JSON, so we recommend to use " - "Lists instead to avoid surprises" + f"Lists instead to avoid surprises. got: {value}" ) elif isinstance(value, dict): # Handle dicts of possible configs diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 2bae49b1fb..ada19859bc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1539,7 +1539,7 @@ def __post_init__(self): activation_granularity, weight_granularity = _normalize_granularity( self.granularity ) - self.granularity = (activation_granularity, weight_granularity) + self.granularity = [activation_granularity, weight_granularity] # for bc From d2066dc01744adcb800a609d7f2f097b6b21fdee Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 May 2025 21:39:23 -0700 Subject: [PATCH 12/12] skip fbgemm config tests for 2.5 and below --- test/quantization/test_config_serialization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py index 0cf80d376e..71cf8e144d 100644 --- a/test/quantization/test_config_serialization.py +++ b/test/quantization/test_config_serialization.py @@ -35,6 +35,7 @@ UIntXWeightOnlyConfig, ) from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # Define test configurations as fixtures configs = [ @@ -78,9 +79,11 @@ "linear2": Int8DynamicActivationInt4WeightConfig(), } ), - FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]), ] +if TORCH_VERSION_AT_LEAST_2_6: + configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])] + # Create ids for better test naming def get_config_ids(configs):
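For reference, the 4-bit packing used throughout the series (first via the local helper, later via fbgemm's `pack_int4`) stores two signed int4 values per int8 byte: even columns go into the low nibble with the sign-extension bits masked off, odd columns are shifted into the high nibble. A small sketch of the packing plus a matching unpack; the unpack helper is written here purely for illustration and is not part of the patch:

```python
import torch

def pack_int4_reference(q: torch.Tensor) -> torch.Tensor:
    # q holds int4 values in int8 storage, range [-8, 7]; column count must be even.
    low = torch.bitwise_and(q[:, ::2], 0xF)          # even columns, drop sign bits
    high = torch.bitwise_left_shift(q[:, 1::2], 4)   # odd columns -> high nibble
    return torch.bitwise_or(low, high).contiguous()

def unpack_int4_reference(packed: torch.Tensor) -> torch.Tensor:
    low = torch.bitwise_and(packed, 0xF)
    high = torch.bitwise_right_shift(packed, 4) & 0xF

    def to_signed(n):
        # Nibbles >= 8 represent negative values in two's complement int4.
        return torch.where(n >= 8, n - 16, n)

    out = torch.empty(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    out[:, ::2], out[:, 1::2] = to_signed(low), to_signed(high)
    return out

q = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)
packed = pack_int4_reference(q)                 # shape (1, 2): two values per byte
assert torch.equal(unpack_int4_reference(packed), q)
```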
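Taken together, the series ends with the following flow; this sketch restates it in one place and assumes an sm90+ GPU with fbgemm-gpu-genai >= 1.2.0 and torch >= 2.6 available, mirroring the final test:

```python
import torch
from torchao.quantization import FbgemmConfig, quantize_

# Final API shape after the renaming commits: dtypes are torch.dtype objects
# and the per-group granularity is a block_size list whose last entry is the
# group size handed to fbgemm's int4_row_quantize_zp.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda"),
)
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
)
quantize_(model, config)

# Each Linear weight is now a FbgemmInt4Tensor; F.linear on it dispatches to
# torch.ops.fbgemm.bf16i4bf16_rowwise and reshapes the result back to the
# original activation shape before adding the bias.
x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = model(x)  # (4, 256) in bf16
```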