From 05814514de795f65490224e1334364e3d196d3d1 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 25 Apr 2025 03:20:39 -0700 Subject: [PATCH 01/18] [CPU] enable int8_dynamic_activation_int4_weight with Int4CPULayout --- test/quantization/test_quant_api.py | 25 +++++++++++++++++++++++++ torchao/dtypes/uintx/int4_cpu_layout.py | 5 +++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index cce8d17c19..47a645dbca 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -863,6 +863,31 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + def test_8da4w_cpu(self, dtype, x_dim): + device = "cpu" + m = ToyLinearModel().eval().to(dtype).to(device) + example_inputs = m.example_inputs(dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_( + m, + int8_dynamic_activation_int4_weight( + group_size=32, layout=Int4CPULayout() + ), + ) + # ensure the expected op is in the code + _, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + assert "_weight_int4pack_mm_for_cpu" in code[0] + assert "aten.mm.default" not in code[0] + # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 4ccfb11d23..c5598de4db 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -113,8 +113,9 @@ def from_plain( if TORCH_VERSION_AT_LEAST_2_6: assert ( - int_data.dtype == torch.int32 - ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + int_data.dtype in [torch.int32, torch.int8] + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype" + int_data = int_data.to(torch.int32) packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( int_data, 1, # TODO:remove From 9fb7f778e05ff47f73fe7ca1fce839e1e474a99f Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 25 Apr 2025 03:27:13 -0700 Subject: [PATCH 02/18] Fix format issue --- torchao/dtypes/uintx/int4_cpu_layout.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 3d3b7014a3..605e71ed6a 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -112,9 +112,9 @@ def from_plain( assert isinstance(_layout, Int4CPULayout) if TORCH_VERSION_AT_LEAST_2_6: - assert ( - int_data.dtype in [torch.int32, torch.int8] - ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype" + assert int_data.dtype in [torch.int32, torch.int8], ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype" + ) int_data = int_data.to(torch.int32) packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( int_data, From 8e80d03a68856840a55f5f95f7e29183b7d8fdc7 Mon Sep 17 00:00:00 2001 
From: "Xia, Weiwen" Date: Wed, 14 May 2025 03:00:49 -0700 Subject: [PATCH 03/18] Add Int8DynamicActInt4WeightCPULayout --- test/quantization/test_quant_api.py | 23 +++- torchao/dtypes/__init__.py | 2 + torchao/dtypes/uintx/__init__.py | 2 + torchao/dtypes/uintx/int4_cpu_layout.py | 142 +++++++++++++++++++++++- 4 files changed, 158 insertions(+), 11 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 7b44514c26..29d8b02d22 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -29,6 +29,7 @@ AffineQuantizedTensor, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, PlainLayout, QDQLayout, TensorCoreTiledLayout, @@ -881,24 +882,36 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): def test_8da4w_cpu(self, dtype, x_dim): device = "cpu" m = ToyLinearModel().eval().to(dtype).to(device) + m2 = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=dtype, device=device) if x_dim == 3: example_inputs = (example_inputs[0].unsqueeze(0),) with torch.no_grad(): + # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout + # is that the former packs two int4 weights into one int8, while the latter does not. quantize_( m, int8_dynamic_activation_int4_weight( - group_size=32, layout=Int4CPULayout() + group_size=32, layout=Int8DynamicActInt4WeightCPULayout() ), ) - # ensure the expected op is in the code - _, code = torch._inductor.utils.run_and_get_code( + y, code = torch._inductor.utils.run_and_get_code( torch.compile(m, fullgraph=True, dynamic=True), *example_inputs, ) - assert "_weight_int4pack_mm_for_cpu" in code[0] - assert "aten.mm.default" not in code[0] + # ensure the expected op is in the code + assert "shift" in code[0] # unpacking int4 values + assert "extern_kernels.mm" in code[0] + quantize_( + m2, + int8_dynamic_activation_int4_weight( + group_size=32, layout=PlainLayout() + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + assert torch.allclose(y, y2) # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index eb253c11bc..a9ea914f62 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -18,6 +18,7 @@ CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -61,4 +62,5 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index fee6141164..7c681fd52c 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -6,6 +6,7 @@ ) from .int4_cpu_layout import ( Int4CPULayout, + Int8DynamicActInt4WeightCPULayout, ) from .int4_xpu_layout import ( Int4XPULayout, @@ -48,4 +49,5 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "QDQLayout", "Int4XPULayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 605e71ed6a..cb0d27a42b 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -112,10 +112,9 @@ def from_plain( assert isinstance(_layout, Int4CPULayout) if 
TORCH_VERSION_AT_LEAST_2_6: - assert int_data.dtype in [torch.int32, torch.int8], ( - "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` or `int8` dtype" + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" ) - int_data = int_data.to(torch.int32) packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( int_data, 1, # TODO:remove @@ -148,7 +147,7 @@ def to(self, *args, **kwargs): device = kwargs["device"] if not is_device(torch.device(self.device).type, device): raise ValueError( - f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" + f"{self.__class__.__name__} does not support conversion from {self.device} to {device}" ) return self.__class__( self.packed_weight.to(device), @@ -215,11 +214,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, sliced) else: raise NotImplementedError( - f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported" ) raise NotImplementedError( - f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"{cls.__name__} dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -353,3 +352,134 @@ def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): if bias is not None: y += bias return y.to(orig_dtype) + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros = ( + 
tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.int8, "DA8W4 CPU: expects int8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + weight_int4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) + return cls(weight_int4, scale, zero_point, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + plain_weight = torch.stack( + ((self.packed_weight << 4) >> 4, self.packed_weight >> 4), dim=-1 + ).view(self.packed_weight.shape[:-1] + (2 * self.packed_weight.shape[-1],)) + return plain_weight, self.scales, self.qzeros From 3e201729298a5c856fa88053a5c20adae7138132 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 15 May 2025 22:52:40 -0700 Subject: [PATCH 04/18] remove dispatch for t() --- torchao/dtypes/uintx/int4_cpu_layout.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index cb0d27a42b..60a81cdb33 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -178,18 +178,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = Int4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scale_and_zero, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim in [0, 1]: From 4feac3f24ddb6464b47fda4384ad7656c5fe279b Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 23 May 2025 00:30:48 -0700 Subject: [PATCH 05/18] Add cpp kernel for weight packing and GEMM --- test/quantization/test_quant_api.py | 33 +- torchao/csrc/cpu/da8w4_linear.cpp | 570 ++++++++++++++++++++++++ 
torchao/dtypes/uintx/int4_cpu_layout.py | 74 ++- torchao/ops.py | 82 ++++ torchao/quantization/quant_api.py | 6 + 5 files changed, 747 insertions(+), 18 deletions(-) create mode 100644 torchao/csrc/cpu/da8w4_linear.cpp diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 29d8b02d22..d613633f83 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -880,6 +880,13 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) def test_8da4w_cpu(self, dtype, x_dim): + print( + "========================= dtype =", + dtype, + ", x_dim =", + x_dim, + "=========================", + ) device = "cpu" m = ToyLinearModel().eval().to(dtype).to(device) m2 = copy.deepcopy(m) @@ -890,27 +897,35 @@ def test_8da4w_cpu(self, dtype, x_dim): with torch.no_grad(): # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout # is that the former packs two int4 weights into one int8, while the latter does not. + print(">>> quantize with Int8DynamicActInt4WeightCPULayout") quantize_( m, int8_dynamic_activation_int4_weight( group_size=32, layout=Int8DynamicActInt4WeightCPULayout() ), ) - y, code = torch._inductor.utils.run_and_get_code( - torch.compile(m, fullgraph=True, dynamic=True), - *example_inputs, - ) - # ensure the expected op is in the code - assert "shift" in code[0] # unpacking int4 values - assert "extern_kernels.mm" in code[0] + # y, code = torch._inductor.utils.run_and_get_code( + # torch.compile(m, fullgraph=True, dynamic=True), + # *example_inputs, + # ) + # # ensure the expected op is in the code + # assert "shift" in code[0] # unpacking int4 values + # assert "extern_kernels.mm" in code[0] + print(">>> run with Int8DynamicActInt4WeightCPULayout") + y = m(*example_inputs) + print(">>> quantize with PlainLayout") quantize_( m2, int8_dynamic_activation_int4_weight( - group_size=32, layout=PlainLayout() + group_size=32, + # mapping_type=MappingType.ASYMMETRIC, + layout=PlainLayout(), ), ) torch._dynamo.reset() # may segfault without this - y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + print(">>> run with PlainLayout") + # y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + y2 = m2(*example_inputs) assert torch.allclose(y, y2) # TODO(#1690): move to new config names diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp new file mode 100644 index 0000000000..404bc0857e --- /dev/null +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -0,0 +1,570 @@ +#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include + +namespace torchao { + +namespace { + +static bool use_cpublas_checked = false; +static bool use_cpublas = false; + +bool da8w4_can_pack_weight() { +#if defined(CPU_CAPABILITY_AVX512) + if (use_cpublas_checked) { + return use_cpublas; + } + use_cpublas = at::native::cpublas::could_pack(at::kByte); + use_cpublas_checked = true; + return use_cpublas; +#else + return false; +#endif +} + +/* +return: packed_weight, packed_scales, packed_qzeros, compensation +*/ +std::tuple +da8w4_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales, + const at::Tensor& qzeros) { + // weight shape = [N, K] + // scales shape = [N, G] + // qzeros shape = [N, G] 
+ TORCH_CHECK(weight.dim() == 2, + "DA8W4 CPU: Weight should be a 2D tensor for packing"); + TORCH_CHECK(weight.size(1) % 2 == 0, + "DA8W4 CPU: Weight should have even number of columns for packing"); + + auto new_scales = scales; + auto new_qzeros = qzeros; + if (new_scales.dim() == 1) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + if (new_qzeros.dim() == 1) { + new_qzeros.unsqueeze_(1); + } + new_qzeros = new_qzeros.to(at::kChar); + int N = weight.size(0); + int K = weight.size(1); + int G = scales.size(1); + int group_size = K / G; + int block_k = group_size > 128 ? 128 : group_size; + constexpr int block_n = 32; + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales/qzeros to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + at::Tensor blocked_qzeros = new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + at::Tensor compensation = weight_view.sum(-1).permute({0, 2, 1}).contiguous().to(at::kInt); + + if (da8w4_can_pack_weight()) { + blocked_weight = at::empty({Nc, Kc, block_k, block_n / 2}, weight.options()); + auto weight_ptr = weight_reordered.data_ptr(); + auto blocked_weight_ptr = blocked_weight.data_ptr(); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n / 2; + + // Reorder weight block to VNNI4 and pack two lanes along N + // N=16 viewed as two lanes: a0, ...a7, b0, ...b7 + // pack two lanes: [a0, b0], ..., [a7, b7] + // plain shape = [block_k, block_n] + // packed shape = [block_k / 4, block_n / 2, 4] viewed as [block_k, block_n / 2] + constexpr int n_group_size = 8; + constexpr int vnni_size = 4; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; nb += 2) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n; + int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size + k * block_n / 2 + ki; + uint8_t src_1 = *(in_ptr + src_idx_1); + uint8_t src_2 = *(in_ptr + src_idx_2); + uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4); + *(out_ptr + dst_idx) = dst; + } + } + } + } + } + }); + } else { + // Pack weight: two int4 -> one int8 + using namespace at::indexing; + at::Tensor even_columns = + weight_reordered.index({Slice(), Slice(), Slice(), Slice(1, None, 2)}); + even_columns = even_columns.bitwise_left_shift(4); + at::Tensor odd_columns = + weight_reordered.index({Slice(), Slice(), Slice(), Slice(None, None, 2)}); + blocked_weight = even_columns.bitwise_or(odd_columns); + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales), std::move(blocked_qzeros), std::move(compensation)); +} + +#if defined(CPU_CAPABILITY_AVX512) +inline std::array<__m256i, 2> load_zps_4vnni(const int8_t* __restrict__ zps) { + // broadcast 01234567 to + // 01234567012345670123456701234567 + __m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast(zps)); + 
__m256i vzps_high = _mm256_set1_epi64x(*reinterpret_cast(zps + 8)); + // shuffle from + // 01234567012345670123456701234567 + // to + // 00001111222233334444555566667777 + __m256i shuffle_mask = _mm256_set_epi8( + 7, + 7, + 7, + 7, + 6, + 6, + 6, + 6, + 5, + 5, + 5, + 5, + 4, + 4, + 4, + 4, + 3, + 3, + 3, + 3, + 2, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0); + vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask); + vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask); + return {vzps_low, vzps_high}; +} + +inline std::array<__m256i, 2> load_uint4_as_int8(const uint8_t* __restrict__ qB) { + __m256i packed = _mm256_loadu_si256(reinterpret_cast(qB)); + const __m256i low_mask = _mm256_set1_epi8(0x0f); + __m256i high = _mm256_srli_epi16(packed, 4); + high = _mm256_and_si256(high, low_mask); + __m256i low = _mm256_and_si256(packed, low_mask); + return {low, high}; +} + +void _dequant_weight_zp_only( + const uint8_t* __restrict__ B, + int8_t* dqB, + const int8_t* __restrict__ qzeros, + int64_t N, + int64_t K, + int64_t ldb) { + // unpack weight int8 -> two int4 + // subtract zero point + // B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4] + // dqB shape = [K, N], actual shape = [K / 4, N, 4] + for (int n = 0; n < N; n += 16) { + auto [zps_low, zps_high] = load_zps_4vnni(&qzeros[n]); + for (int k = 0; k < K; k += 4) { + auto [vb_low, vb_high] = load_uint4_as_int8(B + ldb * k + n / 2 * 4); + vb_high = _mm256_sub_epi8(vb_high, zps_high); + vb_low = _mm256_sub_epi8(vb_low, zps_low); + // store vb to B + _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4), vb_low); + _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high); + } + } +} + +template +void _dequant_and_store( + float* __restrict__ output, + const int32_t* __restrict__ input, + const float* __restrict__ scale_a, + const int8_t* __restrict__ zp_a, + const float* __restrict__ scale_b, + const int32_t* __restrict__ comp_b, + int M, + int N, + int ldi, + int ldo, + int ldsa = 1) { +#pragma GCC unroll 2 + for (int m = 0; m < M; ++m) { + float a_scale = *(scale_a + m * ldsa); + int32_t a_zp = (int32_t)(*(zp_a + m * ldsa)); + __m512 va_scale = _mm512_set1_ps(a_scale); + __m512i va_zp = _mm512_set1_epi32(a_zp); + int n = 0; + for (; n < N; n += 16) { + __m512i va = _mm512_loadu_si512(input + m * ldi + n); + __m512i vb_comp = _mm512_loadu_si512(comp_b + n); + __m512i vc = _mm512_sub_epi32(va, _mm512_mullo_epi32(vb_comp, va_zp)); + __m512 vc_f = _mm512_cvtepi32_ps(vc); + __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); + __m512 vb_s = _mm512_loadu_ps(scale_b + n); + vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s); + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f_mul); + } + } + for (; n < N; ++n) { + float dq_val = + (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale * scale_b[n]; + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +#else +void _dequant_weight_zp_only( + const uint8_t* B, + int8_t* dqB, + const int8_t* qzeros, + int64_t N, + int64_t K, + int64_t ldb) { + // B shape = [K, N / 2] + // dqB shape = [K, N] + for (int k = 0; k < K; ++k) { + for (int n = 0; n < N / 2; ++n) { + int32_t b = (int32_t)B[k * ldb + n]; + dqB[k * N + n * 2] = (b & 0xf) - qzeros[n]; + dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n]; + } + } +} +#endif + +template +void 
_dequant_gemm_accum( + float* C, + const uint8_t* A, + const float* scales_a, + const int8_t* qzeros_a, + const uint8_t* B, + const float* scales_b, + const int8_t* qzeros_b, + const int32_t* compensation, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + // Compute GEMM int8 * int8 -> int32 + // dequant result to float by applying scales/qzeros + + int8_t dqB[K * N]; + _dequant_weight_zp_only(B, dqB, qzeros_b, N, K, ldb); +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (use_cpublas) { + int32_t C_i32[M * N]; + at::native::cpublas::brgemm( + M, + N, + K, + lda, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + A, + dqB, + C_i32, + true /* is_vnni */); + _dequant_and_store( + C, + C_i32, + scales_a, + qzeros_a, + scales_b, + compensation, + M, + N, + N /*ldi*/, + ldc, + 1 /*ldsa*/); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += ((int32_t)A[i * lda + k] - qzeros_a[i]) * (int32_t)dqB[k * N + j]; + } + C[i * ldc + j] += sum * scales_a[i] * scales_b[j]; + } + } + } +} + +inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m, int64_t n) { + if (bias_ptr) { + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) + for (; j < n; j += 16) { + __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); + _mm512_storeu_ps(y_buf + i * n + j, bias_vec); + } +#endif + for (; j < n; ++j) { + y_buf[i * n + j] = bias_ptr[j]; + } + } + } else { // initialize to zero + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) + for (; j < n; j += 16) { + __m512 zero_vec = _mm512_setzero_ps(); + _mm512_storeu_ps(y_buf + i * n + j, zero_vec); + } +#endif + for (; j < n; ++j) { + y_buf[i * n + j] = 0; + } + } + } +} + +template +inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, int64_t n, int64_t lda) { + for (int i = 0; i < m; ++i) { + int j = 0; + if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) + for (; j < n; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } +#endif + for (; j < n; ++j) { + c_ptr[i * lda + j] = y_buf[i * n + j]; + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) + for (; j < n; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } +#endif + for (; j < n; ++j) { + c_ptr[i * lda + j] = at::BFloat16(y_buf[i * n + j]); + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) + for (; j < n; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } +#endif + for (; j < n; ++j) { + c_ptr[i * lda + j] = at::Half(y_buf[i * n + j]); + } + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + } +} + +template +void _da8w4_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& input_qzeros, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const at::Tensor& weight_qzeros, + const at::Tensor& compensation, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = 
input.view({-1, K}); + int64_t M = input_view.size(0); + TORCH_CHECK(input_scales.numel() == M, "DA8W4: unexpected input scales shape"); + TORCH_CHECK(input_scales.sizes() == input_qzeros.sizes(), "DA8W4: unexpected input qzeros shape"); + + // weight shape = [Nc, Kc, block_k, block_n/2] + // scales/qzeros shape = [Nc, G, block_n] + // compensation shape = [Nc, Kc, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + int64_t block_n = weight.size(3) * 2; + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "DA8W4: weight and input shapes mismatch"); + int64_t block_m = [&]() -> long { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 48; + } else { + return 64; + } + }(); + int64_t Mc = (M + block_m - 1) / block_m; + bool parallel_on_M = M > 128; + int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; + + // scales/qzeros shape = [Nc, G, block_n] + int64_t num_groups = weight_scales.size(1); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "DA8W4 CPU: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + + const uint8_t* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const int8_t* a_qzeros_ptr = input_qzeros.data_ptr(); + const uint8_t* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr(); + const int32_t* compensation_ptr = compensation.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + int64_t mc = parallel_on_M ? i / Nc : 0; + int64_t nc = parallel_on_M ? i % Nc : i; + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + alignas(64) float y_buf[m_size][block_n]; + // copy bias to y_buf if bias is not None + auto bias_data = bias_ptr ? 
bias_ptr + nc * block_n : nullptr; + copy_bias(bias_data, y_buf[0], m_size, block_n); + for (int kci = 0; kci < Kc; ++kci) { + _dequant_gemm_accum( + y_buf[0] /*C*/, + a_ptr + mci * block_m * K + kci * block_k /*A*/, + a_scales_ptr + mci * block_m /*scakes_a*/, + a_qzeros_ptr + mci * block_m /*qzeros_a*/, + b_ptr + (nc * Kc + kci) * block_n * block_k / 2 /*B*/, + b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*scales_b*/, + b_qzeros_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*qzeros_b*/, + compensation_ptr + nc * block_n * Kc + kci * block_n /*compensation*/, + m_size /*M*/, + block_n /*N*/, + block_k /*K*/, + K /*lda*/, + block_n / 2 /*ldb*/, + block_n /*ldc*/); + } + // store y_buf to output + store_out(y_buf[0], c_ptr + mci * block_m * N + nc * block_n, m_size, block_n, N); + } + } + }); +} + +at::Tensor da8w4_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& input_qzeros, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const at::Tensor& weight_qzeros, + const at::Tensor& compensation, + const std::optional& bias, + at::ScalarType output_dtype) { + static bool use_cpublas = da8w4_can_pack_weight(); + auto out_sizes = input.sizes().vec(); + int64_t N = weight.size(0) * weight.size(-1) * 2; + out_sizes.back() = N; + auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); + if (use_cpublas) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { + _da8w4_linear_impl( + input, + input_scales, + input_qzeros, + weight, + weight_scales, + weight_qzeros, + compensation, + bias, + output); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { + _da8w4_linear_impl( + input, + input_scales, + input_qzeros, + weight, + weight_scales, + weight_qzeros, + compensation, + bias, + output); + }); + } + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::da8w4_linear_prepack_cpu", &da8w4_linear_prepack_impl); + m.impl("torchao::da8w4_linear_cpu", &da8w4_linear_impl); +} + +} // namespace torchao diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 60a81cdb33..8b61bba670 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -366,6 +366,7 @@ def __new__( packed_weight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + compensation: torch.Tensor, transposed: bool, _layout: Layout, ): @@ -386,17 +387,22 @@ def __init__( packed_weight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + compensation: torch.Tensor, transposed: bool, _layout: Layout, ): self.packed_weight = packed_weight self.scales = scales self.qzeros = qzeros + self.compensation = compensation self.transposed = transposed self._layout = _layout def __tensor_flatten__(self): - return ["packed_weight", "scales", "qzeros"], [self.transposed, self._layout] + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] @classmethod def __tensor_unflatten__( @@ -406,6 +412,7 @@ def __tensor_unflatten__( tensor_data_dict["packed_weight"], tensor_data_dict["scales"], tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], ) ( transposed, @@ -418,20 +425,32 @@ def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: 
Optional[torch.Tensor], + zero_point: torch.Tensor, _layout: Layout, ): assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) - assert int_data.dtype == torch.int8, "DA8W4 CPU: expects int8 weight" + assert int_data.dtype == torch.int8, "DA8W4 CPU: expects uint8 weight" assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" - weight_int4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) - return cls(weight_int4, scale, zero_point, False, _layout) + # int8 -> uint8 + int_data = (int_data + 8).to(torch.uint8) + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + zero_point = zero_point.to(torch.int8) + 8 + + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) def _apply_fn_to_data(self, fn): return self.__class__( fn(self.packed_weight), fn(self.scales), fn(self.qzeros), + fn(self.compensation), self.transposed, self._layout, ) @@ -447,6 +466,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): args[0].packed_weight, args[0].scales, args[0].qzeros, + args[0].compensation, not args[0].transposed, args[0]._layout, ) @@ -467,7 +487,43 @@ def block_size(self): return (1, group_size) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - plain_weight = torch.stack( - ((self.packed_weight << 4) >> 4, self.packed_weight >> 4), dim=-1 - ).view(self.packed_weight.shape[:-1] + (2 * self.packed_weight.shape[-1],)) - return plain_weight, self.scales, self.qzeros + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int8) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.int8) + + if self.scales.dim() == 2: + assert self.qzeros.dim() == 2 + plain_scales = self.scales + plain_qzeros = self.qzeros + else: + assert self.scales.dim() == 3 and self.qzeros.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + plain_qzeros = ( + self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, plain_qzeros diff --git a/torchao/ops.py b/torchao/ops.py index faebdbd5d1..23c6f41985 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -61,6 +61,12 @@ lib.define( "qscaled_dot_product(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float? 
scale=None, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor" ) +lib.define( + "da8w4_linear_prepack_cpu(Tensor weight, Tensor scales, Tensor qzeros) -> (Tensor, Tensor, Tensor, Tensor)" +) +lib.define( + "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor" +) def register_custom_op(name): @@ -959,3 +965,79 @@ def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): """Meta impl for mx_fp4_bf16""" # Assume that the contraction happens in the K dim thus M,N are perserved post bit pack return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device) + + +def da8w4_linear_prepack_cpu( + weight: Tensor, + scales: Tensor, + qzeros: Tensor, +) -> Tensor: + """ + Prepack weights for DA8W4 linear operator on CPU. + Args: + weight: weight tensor. + scales: scales for weight tensor. + qzeros: zero points for weight tensor. + Returns: + packed weight, scales, and zero points. + """ + return torch.ops.torchao.da8w4_linear_prepack_cpu.default(weight, scales, qzeros) + + +@register_custom_op("torchao::da8w4_linear_prepack_cpu") +def _(weight: Tensor, scales: Tensor, qzeros: Tensor) -> Tensor: + return weight, scales, qzeros, torch.Tensor() + + +def da8w4_linear_cpu( + input: Tensor, + input_scales: Tensor, + input_qzeros: Tensor, + weight: Tensor, + weight_scales: Tensor, + weight_qzeros: Tensor, + compensation: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +): + """ + DA8W4 linear operator on CPU. + Args: + input: input tensor. + input_scales: scales for input tensor. + input_qzeros: zero points for input tensor. + weight: weight tensor. + weight_scales: scales for weight tensor. + weight_qzeros: zero points for weight tensor. + compensation: compensation tensor for weight. + bias: optional bias tensor. + out_dtype: output data type. + Returns: + output tensor in out_dtype. 
+ """ + return torch.ops.torchao.da8w4_linear_cpu.default( + input, + input_scales, + input_qzeros, + weight, + weight_scales, + weight_qzeros, + compensation, + bias, + out_dtype, + ) + + +@register_custom_op("torchao::da8w4_linear_cpu") +def _( + input: Tensor, + input_scales: Tensor, + input_qzeros: Tensor, + weight: Tensor, + weight_scales: Tensor, + weight_qzeros: Tensor, + compensation: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +) -> Tensor: + return input.new_empty(*input.shape[:-1], weight.shape[0], dtype=out_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 15e3b20fc8..e74bcf5a60 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -34,6 +34,7 @@ Float8Layout, Int4CPULayout, Int4XPULayout, + Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, @@ -715,6 +716,11 @@ def _int8_dynamic_activation_int4_weight_transform( quant_min = -8 quant_max = 7 + if isinstance(layout, Int8DynamicActInt4WeightCPULayout): + # Int8DynamicActInt4WeightCPULayout requires bias to be in float32 + if module.bias is not None: + module.bias = torch.nn.Parameter(module.bias.float(), requires_grad=False) + # input settings if act_mapping_type == MappingType.ASYMMETRIC: input_quant_func = _int8_asymm_per_token_quant From 0d85183dee85002d0c09334339e87c2d265c5667 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 25 May 2025 07:53:42 -0700 Subject: [PATCH 06/18] Register ATQ linear dispatch for da8w4 linear --- test/quantization/test_quant_api.py | 26 +++--- torchao/csrc/cpu/da8w4_linear.cpp | 14 +-- torchao/dtypes/affine_quantized_tensor_ops.py | 6 ++ torchao/dtypes/uintx/int4_cpu_layout.py | 90 ++++++++++++++++++- torchao/quantization/quant_api.py | 37 +++++++- 5 files changed, 150 insertions(+), 23 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index d613633f83..2daaf0bd6c 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -71,6 +71,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, @@ -876,19 +877,13 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+") @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) - def test_8da4w_cpu(self, dtype, x_dim): - print( - "========================= dtype =", - dtype, - ", x_dim =", - x_dim, - "=========================", - ) + @common_utils.parametrize("bias", [True, False]) + def test_8da4w_cpu(self, dtype, x_dim, bias): device = "cpu" - m = ToyLinearModel().eval().to(dtype).to(device) + m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) m2 = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=dtype, device=device) if x_dim == 3: @@ -897,7 +892,6 @@ def test_8da4w_cpu(self, dtype, x_dim): with torch.no_grad(): # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout # is that the former packs two int4 weights into one int8, while the latter does not. 
- print(">>> quantize with Int8DynamicActInt4WeightCPULayout") quantize_( m, int8_dynamic_activation_int4_weight( @@ -911,9 +905,7 @@ def test_8da4w_cpu(self, dtype, x_dim): # # ensure the expected op is in the code # assert "shift" in code[0] # unpacking int4 values # assert "extern_kernels.mm" in code[0] - print(">>> run with Int8DynamicActInt4WeightCPULayout") y = m(*example_inputs) - print(">>> quantize with PlainLayout") quantize_( m2, int8_dynamic_activation_int4_weight( @@ -923,10 +915,14 @@ def test_8da4w_cpu(self, dtype, x_dim): ), ) torch._dynamo.reset() # may segfault without this - print(">>> run with PlainLayout") # y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) y2 = m2(*example_inputs) - assert torch.allclose(y, y2) + atol, rtol = 1e-7, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 0.01, 1e-3 + elif dtype == torch.half: + atol, rtol = 0.005, 1e-3 + assert torch.allclose(y, y2, atol=atol, rtol=rtol) # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index 404bc0857e..65cbb2cd64 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -83,7 +83,9 @@ da8w4_linear_prepack_impl( at::Tensor blocked_weight; at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); at::Tensor blocked_qzeros = new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); - at::Tensor compensation = weight_view.sum(-1).permute({0, 2, 1}).contiguous().to(at::kInt); + // weight was increased by 8 during quantization, so we need to subtract 8 + at::Tensor compensation = weight_view.to(at::kInt).sub(8).sum(-1); + compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt); if (da8w4_can_pack_weight()) { blocked_weight = at::empty({Nc, Kc, block_k, block_n / 2}, weight.options()); @@ -220,7 +222,7 @@ void _dequant_and_store( float* __restrict__ output, const int32_t* __restrict__ input, const float* __restrict__ scale_a, - const int8_t* __restrict__ zp_a, + const int32_t* __restrict__ zp_a, const float* __restrict__ scale_b, const int32_t* __restrict__ comp_b, int M, @@ -231,7 +233,7 @@ void _dequant_and_store( #pragma GCC unroll 2 for (int m = 0; m < M; ++m) { float a_scale = *(scale_a + m * ldsa); - int32_t a_zp = (int32_t)(*(zp_a + m * ldsa)); + int32_t a_zp = *(zp_a + m * ldsa); __m512 va_scale = _mm512_set1_ps(a_scale); __m512i va_zp = _mm512_set1_epi32(a_zp); int n = 0; @@ -287,7 +289,7 @@ void _dequant_gemm_accum( float* C, const uint8_t* A, const float* scales_a, - const int8_t* qzeros_a, + const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, const int8_t* qzeros_b, @@ -469,7 +471,7 @@ void _da8w4_linear_impl( const uint8_t* a_ptr = input_view.data_ptr(); const float* a_scales_ptr = input_scales.data_ptr(); - const int8_t* a_qzeros_ptr = input_qzeros.data_ptr(); + const int32_t* a_qzeros_ptr = input_qzeros.data_ptr(); const uint8_t* b_ptr = weight.data_ptr(); const float* b_scales_ptr = weight_scales.data_ptr(); const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr(); @@ -493,7 +495,7 @@ void _da8w4_linear_impl( _dequant_gemm_accum( y_buf[0] /*C*/, a_ptr + mci * block_m * K + kci * block_k /*A*/, - a_scales_ptr + mci * block_m /*scakes_a*/, + a_scales_ptr + mci * block_m /*scales_a*/, a_qzeros_ptr + mci * block_m /*qzeros_a*/, b_ptr + (nc * Kc + kci) * block_n * block_k / 2 /*B*/, b_scales_ptr + nc * block_n * num_groups + kci 
/ block_per_group * block_n /*scales_b*/, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 1d70f5c7f3..fa82d92f9f 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -42,6 +42,8 @@ from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, ) from torchao.dtypes.uintx.int4_xpu_layout import ( _linear_bf16_act_uint4_weight_float_zero_check, @@ -242,6 +244,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_bf16_act_uint4_weight_float_zero_check, _linear_bf16_act_uint4_weight_float_zero_impl, ), + ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 8b61bba670..ad6d7d1d8d 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -16,11 +16,12 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout, is_device from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, fill_defaults, ) @@ -527,3 +528,90 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) return plain_weight, plain_scales, plain_qzeros + + +def _aqt_is_uint8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 255 + ) + + +def _aqt_is_int4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == -8 + and aqt.quant_max == 7 + ) + + +def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): + ret = ( + TORCH_VERSION_AT_LEAST_2_7 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_uint8(input_tensor) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) + ) + return ret + + +def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert TORCH_VERSION_AT_LEAST_2_7, ( + f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + act_qzeros 
= act_mat.tensor_impl.zero_point + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + wei_qzeros = weight_tensor.tensor_impl.qzeros + compensation = weight_tensor.tensor_impl.compensation + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.da8w4_linear_cpu.default( + act.contiguous(), + act_scales, + act_qzeros, + packed_weight, + wei_scales, + wei_qzeros, + compensation, + bias, + orig_dtype, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e74bcf5a60..835d0afec9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -647,6 +647,38 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) +def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.uint8 + scale_dtype = torch.float32 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + quant_min = 0 + quant_max = 255 + if TORCH_VERSION_AT_LEAST_2_6: + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) + else: + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + return out + + def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -723,7 +755,10 @@ def _int8_dynamic_activation_int4_weight_transform( # input settings if act_mapping_type == MappingType.ASYMMETRIC: - input_quant_func = _int8_asymm_per_token_quant + if isinstance(layout, Int8DynamicActInt4WeightCPULayout): + input_quant_func = _uint8_asymm_per_token_quant + else: + input_quant_func = _int8_asymm_per_token_quant elif act_mapping_type == MappingType.SYMMETRIC: if isinstance(layout, MarlinQQQLayout): input_quant_func = _int8_symm_per_token_quant From c42abdbc3c49af733c2df72a60fc9b2658cd69ed Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 25 May 2025 19:21:39 -0700 Subject: [PATCH 07/18] Fix issues with torch.compile --- test/quantization/test_quant_api.py | 18 +++++++----------- torchao/dtypes/uintx/int4_cpu_layout.py | 4 ++-- torchao/ops.py | 4 +++- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 2daaf0bd6c..d3d9409bf2 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -898,25 +898,21 @@ def test_8da4w_cpu(self, dtype, x_dim, bias): group_size=32, layout=Int8DynamicActInt4WeightCPULayout() ), ) - # y, code = torch._inductor.utils.run_and_get_code( - # torch.compile(m, fullgraph=True, dynamic=True), - # *example_inputs, - # ) - # # ensure the expected op is in the code - # assert "shift" in code[0] # unpacking int4 values - # assert "extern_kernels.mm" in code[0] - y = m(*example_inputs) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + 
assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] quantize_( m2, int8_dynamic_activation_int4_weight( group_size=32, - # mapping_type=MappingType.ASYMMETRIC, layout=PlainLayout(), ), ) torch._dynamo.reset() # may segfault without this - # y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - y2 = m2(*example_inputs) + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) atol, rtol = 1e-7, 1e-5 if dtype == torch.bfloat16: atol, rtol = 0.01, 1e-3 diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index ad6d7d1d8d..73f992b9e7 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -409,7 +409,7 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scales, qzeros = ( + packed_weight, scales, qzeros, compensation = ( tensor_data_dict["packed_weight"], tensor_data_dict["scales"], tensor_data_dict["qzeros"], @@ -419,7 +419,7 @@ def __tensor_unflatten__( transposed, _layout, ) = tensor_attributes - return cls(packed_weight, scales, qzeros, transposed, _layout) + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) @classmethod def from_plain( diff --git a/torchao/ops.py b/torchao/ops.py index 23c6f41985..c0bd556bb4 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1040,4 +1040,6 @@ def _( bias: Optional[Tensor], out_dtype: torch.dtype, ) -> Tensor: - return input.new_empty(*input.shape[:-1], weight.shape[0], dtype=out_dtype) + assert weight.dim() == 4 + N = weight.size(0) * weight.size(3) * 2 + return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) From 8c5eebb5df1c7ca4b5dc7e110ce3118a1af4c3e1 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 25 May 2025 22:43:22 -0700 Subject: [PATCH 08/18] Fix DA8W4CPUAQTTensorImpl.get_plain --- torchao/dtypes/uintx/int4_cpu_layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 5bc7e5be77..ed4cab423f 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -494,7 +494,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: K = packed_w_shape[1] x = torch.eye(K).to(torch.uint8) x_scale = torch.ones(K).float() - x_qzero = torch.zeros(K).to(torch.int8) + x_qzero = torch.zeros(K).to(torch.int32) w_scale = torch.ones_like(self.scales).float() w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( From 2a26e15c7e10180d7e9e28de89f83494f9982669 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 25 May 2025 22:58:14 -0700 Subject: [PATCH 09/18] Test DA8W4CPUAQTTensorImpl.get_plain in UT --- test/quantization/test_quant_api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 88abf349dd..721a4b041d 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -919,6 +919,13 @@ def test_8da4w_cpu(self, dtype, x_dim, bias): elif dtype == torch.half: atol, rtol = 0.005, 1e-3 assert torch.allclose(y, y2, atol=atol, rtol=rtol) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref 
= m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") From 369000f0a0d492e00fc9eca4d157eda95fb008ec Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 26 May 2025 23:05:09 +0000 Subject: [PATCH 10/18] Skip UT if CPP kernel not built --- test/quantization/test_quant_api.py | 4 ++++ torchao/csrc/cpu/da8w4_linear.cpp | 3 +++ 2 files changed, 7 insertions(+) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 721a4b041d..ba9cbe5d41 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -877,6 +877,10 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+") @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index 65cbb2cd64..8d2e340407 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -513,6 +513,9 @@ void _da8w4_linear_impl( } } }); + if constexpr (use_cpublas) { + at::native::cpublas::brgemm_release(); + } } at::Tensor da8w4_linear_impl( From f6e87ba06cd4e4ebafeb3a740e3f583de97645a6 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 27 May 2025 19:35:06 +0000 Subject: [PATCH 11/18] Add AVX512_VNNI implementation for small M --- setup.py | 6 ++ test/quantization/test_quant_api.py | 5 +- torchao/csrc/cpu/da8w4_linear.cpp | 145 +++++++++++++++++++++++++++- 3 files changed, 149 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 0915f6ae1e..26810be4b3 100644 --- a/setup.py +++ b/setup.py @@ -325,6 +325,12 @@ def get_extensions(): "-fopenmp", ] ) + if torch._C._cpu._is_avx512_vnni_supported(): + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512_VNNI", + ] + ) if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ba9cbe5d41..81a7ce783c 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -885,11 +885,12 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("bias", [True, False]) - def test_8da4w_cpu(self, dtype, x_dim, bias): + @common_utils.parametrize("bs", [1, 160]) + def test_8da4w_cpu(self, dtype, x_dim, bias, bs): device = "cpu" m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) m2 = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=dtype, device=device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) if x_dim == 3: example_inputs = (example_inputs[0].unsqueeze(0),) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index 8d2e340407..f8a2dc7394 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -25,6 +26,8 @@ namespace 
torchao { namespace { +#define BLOCK_N 32 + static bool use_cpublas_checked = false; static bool use_cpublas = false; @@ -72,7 +75,7 @@ da8w4_linear_prepack_impl( int G = scales.size(1); int group_size = K / G; int block_k = group_size > 128 ? 128 : group_size; - constexpr int block_n = 32; + constexpr int block_n = BLOCK_N; int Nc = N / block_n; int Kc = K / block_k; @@ -284,6 +287,116 @@ void _dequant_weight_zp_only( } #endif +#if defined(CPU_CAPABILITY_AVX512_VNNI) +inline __m512i combine_m256i(__m256i a, __m256i b) { + __m512i c = _mm512_castsi256_si512(a); + return _mm512_inserti64x4(c, b, 1); +} + +inline __m512i combine_m256i(std::array<__m256i, 2> two_256) { + return combine_m256i(two_256[0], two_256[1]); +} + +template +void _dequant_gemm_accum_small_M( + float* C, + const uint8_t* A, + const float* scales_a, + const int32_t* qzeros_a, + const uint8_t* B, + const float* scales_b, + const int8_t* qzeros_b, + const int32_t* compensation, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + constexpr int N = BLOCK_N; + constexpr int COLS = N / 16; + __m512i va; + __m512i vb[COLS]; + __m512i vc[M * COLS]; + __m512 vscales[COLS]; + __m512i vzps[COLS]; + __m512i vcompensate[COLS]; + + // Load scales and zps + c10::ForcedUnroll{}([&](auto i) { + vscales[i] = _mm512_loadu_ps(scales_b + i * 16); + vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16)); + }); + c10::ForcedUnroll{}( + [&](auto i) { vc[i] = _mm512_setzero_epi32(); }); + + auto compute = [&](auto i, int k) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k)); + } + + if constexpr (row == 0) { + vb[col] = combine_m256i(load_uint4_as_int8(B + k * ldb + col * 16 * 2)); + vb[col] = _mm512_sub_epi8(vb[col], vzps[col]); + vcompensate[col] = _mm512_loadu_epi32(compensation + col * 16); + _mm_prefetch(B + (k + 32) * ldb + col * 16 * 2, _MM_HINT_T0); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + + // Accumulate along k + constexpr const int unroll = 4; + int k = 0; + for (; k < K / 4 / unroll; k++) { + c10::ForcedUnroll{}([&](auto i) { + c10::ForcedUnroll{}(compute, 4 * (k * unroll + i)); + }); + } + k *= 4 * unroll; + for (; k < K; k += 4) { + c10::ForcedUnroll{}(compute, k); + } + + // Store to C + auto store = [&](auto i) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + // compute (qC - compensate * zp_a) * scale_a * scale_b + __m512 vc_float; + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32( + vcompensate[col], _mm512_set1_epi32(*(qzeros_a + row)))); + vc_float = _mm512_cvtepi32_ps(vc[i]); + vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row))); + + vc_float = _mm512_mul_ps(vc_float, vscales[col]); + auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16); + vc_float = _mm512_add_ps(vc_float, vc_old); + _mm512_storeu_ps(C + row * ldc + col * 16, vc_float); + }; + c10::ForcedUnroll{}(store); + +} + +#define call_dequant_gemm_accum_small_M(M) \ + _dequant_gemm_accum_small_M( \ + C, \ + A, \ + scales_a, \ + qzeros_a, \ + B, \ + scales_b, \ + qzeros_b, \ + compensation, \ + K, \ + lda, \ + ldb, \ + ldc); +#endif + template void _dequant_gemm_accum( float* C, @@ -302,6 +415,25 @@ void _dequant_gemm_accum( int64_t ldc) { // Compute GEMM int8 * int8 -> int32 // dequant result to float by applying scales/qzeros + auto tid = at::get_thread_num(); +#if defined(CPU_CAPABILITY_AVX512_VNNI) + if (M <= 4 && N == BLOCK_N) { + switch (M) { + case 1: + 
call_dequant_gemm_accum_small_M(1); + return; + case 2: + call_dequant_gemm_accum_small_M(2); + return; + case 3: + call_dequant_gemm_accum_small_M(3); + return; + case 4: + call_dequant_gemm_accum_small_M(4); + return; + } + } +#endif int8_t dqB[K * N]; _dequant_weight_zp_only(B, dqB, qzeros_b, N, K, ldb); @@ -320,6 +452,8 @@ void _dequant_gemm_accum( dqB, C_i32, true /* is_vnni */); + _mm_prefetch(B + N * K / 2, _MM_HINT_T0); + _mm_prefetch(A + K, _MM_HINT_T0); _dequant_and_store( C, C_i32, @@ -444,7 +578,8 @@ void _da8w4_linear_impl( int64_t Nc = weight.size(0); int64_t Kc = weight.size(1); int64_t block_k = weight.size(2); - int64_t block_n = weight.size(3) * 2; + constexpr int64_t block_n = BLOCK_N; + TORCH_CHECK(weight.size(3) * 2 == block_n, "DA8W4: unexpected weight shape"); int64_t N = Nc * block_n; TORCH_CHECK(K == Kc * block_k, "DA8W4: weight and input shapes mismatch"); int64_t block_m = [&]() -> long { @@ -512,10 +647,10 @@ void _da8w4_linear_impl( store_out(y_buf[0], c_ptr + mci * block_m * N + nc * block_n, m_size, block_n, N); } } + if constexpr (use_cpublas) { + at::native::cpublas::brgemm_release(); + } }); - if constexpr (use_cpublas) { - at::native::cpublas::brgemm_release(); - } } at::Tensor da8w4_linear_impl( From 0a87ef06e2e7063f15b043655adcca6b9f30600a Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 3 Jun 2025 16:55:01 +0000 Subject: [PATCH 12/18] improve performance --- torchao/csrc/cpu/da8w4_linear.cpp | 128 +++++++++++++++--------------- 1 file changed, 66 insertions(+), 62 deletions(-) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index f8a2dc7394..f8f125695e 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -196,17 +196,17 @@ inline std::array<__m256i, 2> load_uint4_as_int8(const uint8_t* __restrict__ qB) return {low, high}; } +template void _dequant_weight_zp_only( const uint8_t* __restrict__ B, int8_t* dqB, const int8_t* __restrict__ qzeros, - int64_t N, - int64_t K, - int64_t ldb) { + int64_t K) { // unpack weight int8 -> two int4 // subtract zero point // B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4] // dqB shape = [K, N], actual shape = [K / 4, N, 4] +#pragma GCC unroll 2 for (int n = 0; n < N; n += 16) { auto [zps_low, zps_high] = load_zps_4vnni(&qzeros[n]); for (int k = 0; k < K; k += 4) { @@ -220,7 +220,7 @@ void _dequant_weight_zp_only( } } -template +template void _dequant_and_store( float* __restrict__ output, const int32_t* __restrict__ input, @@ -229,17 +229,16 @@ void _dequant_and_store( const float* __restrict__ scale_b, const int32_t* __restrict__ comp_b, int M, - int N, int ldi, int ldo, int ldsa = 1) { -#pragma GCC unroll 2 for (int m = 0; m < M; ++m) { float a_scale = *(scale_a + m * ldsa); int32_t a_zp = *(zp_a + m * ldsa); __m512 va_scale = _mm512_set1_ps(a_scale); __m512i va_zp = _mm512_set1_epi32(a_zp); int n = 0; +#pragma GCC unroll 2 for (; n < N; n += 16) { __m512i va = _mm512_loadu_si512(input + m * ldi + n); __m512i vb_comp = _mm512_loadu_si512(comp_b + n); @@ -268,13 +267,12 @@ void _dequant_and_store( } #else +template void _dequant_weight_zp_only( const uint8_t* B, int8_t* dqB, const int8_t* qzeros, - int64_t N, - int64_t K, - int64_t ldb) { + int64_t K) { // B shape = [K, N / 2] // dqB shape = [K, N] for (int k = 0; k < K; ++k) { @@ -297,23 +295,23 @@ inline __m512i combine_m256i(std::array<__m256i, 2> two_256) { return combine_m256i(two_256[0], two_256[1]); } -template +template void _dequant_gemm_accum_small_M( - float* C, + 
float* __restrict__ C, const uint8_t* A, const float* scales_a, const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, const int8_t* qzeros_b, - const int32_t* compensation, int64_t K, int64_t lda, - int64_t ldb, int64_t ldc) { - constexpr int N = BLOCK_N; constexpr int COLS = N / 16; + // Computing compensation is faster than loading it for small M + // because it's memory bound. + __m512i ones = _mm512_set1_epi8(1); // used for computing compensation __m512i va; __m512i vb[COLS]; __m512i vc[M * COLS]; @@ -325,6 +323,7 @@ void _dequant_gemm_accum_small_M( c10::ForcedUnroll{}([&](auto i) { vscales[i] = _mm512_loadu_ps(scales_b + i * 16); vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16)); + vcompensate[i] = _mm512_setzero_epi32(); }); c10::ForcedUnroll{}( [&](auto i) { vc[i] = _mm512_setzero_epi32(); }); @@ -338,10 +337,12 @@ void _dequant_gemm_accum_small_M( } if constexpr (row == 0) { - vb[col] = combine_m256i(load_uint4_as_int8(B + k * ldb + col * 16 * 2)); + int B_offset = k * ldb + col * 16 * 2; + vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset)); vb[col] = _mm512_sub_epi8(vb[col], vzps[col]); - vcompensate[col] = _mm512_loadu_epi32(compensation + col * 16); - _mm_prefetch(B + (k + 32) * ldb + col * 16 * 2, _MM_HINT_T0); + vcompensate[col] = + _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]); + _mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0); } vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); }; @@ -381,8 +382,8 @@ void _dequant_gemm_accum_small_M( } -#define call_dequant_gemm_accum_small_M(M) \ - _dequant_gemm_accum_small_M( \ +#define call_dequant_gemm_accum_small_M(M, N, ldb) \ + _dequant_gemm_accum_small_M( \ C, \ A, \ scales_a, \ @@ -390,14 +391,12 @@ void _dequant_gemm_accum_small_M( B, \ scales_b, \ qzeros_b, \ - compensation, \ K, \ lda, \ - ldb, \ ldc); #endif -template +template void _dequant_gemm_accum( float* C, const uint8_t* A, @@ -408,35 +407,33 @@ void _dequant_gemm_accum( const int8_t* qzeros_b, const int32_t* compensation, int64_t M, - int64_t N, int64_t K, int64_t lda, - int64_t ldb, int64_t ldc) { // Compute GEMM int8 * int8 -> int32 // dequant result to float by applying scales/qzeros - auto tid = at::get_thread_num(); #if defined(CPU_CAPABILITY_AVX512_VNNI) - if (M <= 4 && N == BLOCK_N) { + if (M <= 4) { switch (M) { case 1: - call_dequant_gemm_accum_small_M(1); + call_dequant_gemm_accum_small_M(1, N, ldb); + // CALL_MICRO_GEMM_KERNEL(1); return; case 2: - call_dequant_gemm_accum_small_M(2); + call_dequant_gemm_accum_small_M(2, N, ldb); return; case 3: - call_dequant_gemm_accum_small_M(3); + call_dequant_gemm_accum_small_M(3, N, ldb); return; case 4: - call_dequant_gemm_accum_small_M(4); + call_dequant_gemm_accum_small_M(4, N, ldb); return; } } #endif int8_t dqB[K * N]; - _dequant_weight_zp_only(B, dqB, qzeros_b, N, K, ldb); + _dequant_weight_zp_only(B, dqB, qzeros_b, K); #if defined(CPU_CAPABILITY_AVX512) if constexpr (use_cpublas) { int32_t C_i32[M * N]; @@ -454,7 +451,7 @@ void _dequant_gemm_accum( true /* is_vnni */); _mm_prefetch(B + N * K / 2, _MM_HINT_T0); _mm_prefetch(A + K, _MM_HINT_T0); - _dequant_and_store( + _dequant_and_store( C, C_i32, scales_a, @@ -462,7 +459,6 @@ void _dequant_gemm_accum( scales_b, compensation, M, - N, N /*ldi*/, ldc, 1 /*ldsa*/); @@ -481,71 +477,77 @@ void _dequant_gemm_accum( } } -inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m, int64_t n) { +template +inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) { if (bias_ptr) { for (int i = 0; i < m; ++i) { int 
j = 0; #if defined(CPU_CAPABILITY_AVX512) - for (; j < n; j += 16) { +#pragma GCC unroll 2 + for (; j < N; j += 16) { __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); - _mm512_storeu_ps(y_buf + i * n + j, bias_vec); + _mm512_storeu_ps(y_buf + i * N + j, bias_vec); } #endif - for (; j < n; ++j) { - y_buf[i * n + j] = bias_ptr[j]; + for (; j < N; ++j) { + y_buf[i * N + j] = bias_ptr[j]; } } } else { // initialize to zero for (int i = 0; i < m; ++i) { int j = 0; #if defined(CPU_CAPABILITY_AVX512) - for (; j < n; j += 16) { +#pragma GCC unroll 2 + for (; j < N; j += 16) { __m512 zero_vec = _mm512_setzero_ps(); - _mm512_storeu_ps(y_buf + i * n + j, zero_vec); + _mm512_storeu_ps(y_buf + i * N + j, zero_vec); } #endif - for (; j < n; ++j) { - y_buf[i * n + j] = 0; + for (; j < N; ++j) { + y_buf[i * N + j] = 0; } } } } -template -inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, int64_t n, int64_t lda) { +template +inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_t n, */ int64_t lda) { for (int i = 0; i < m; ++i) { int j = 0; if constexpr (std::is_same::value) { #if defined(CPU_CAPABILITY_AVX512) - for (; j < n; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); } #endif - for (; j < n; ++j) { - c_ptr[i * lda + j] = y_buf[i * n + j]; + for (; j < N; ++j) { + c_ptr[i * lda + j] = y_buf[i * N + j]; } } else if constexpr (std::is_same::value) { #if defined(CPU_CAPABILITY_AVX512) - for (; j < n; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); } #endif - for (; j < n; ++j) { - c_ptr[i * lda + j] = at::BFloat16(y_buf[i * n + j]); + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]); } } else if constexpr (std::is_same::value) { #if defined(CPU_CAPABILITY_AVX512) - for (; j < n; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * n + j); +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); } #endif - for (; j < n; ++j) { - c_ptr[i * lda + j] = at::Half(y_buf[i * n + j]); + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]); } } else { TORCH_CHECK(false, "Unsupported output dtype"); @@ -625,9 +627,9 @@ void _da8w4_linear_impl( alignas(64) float y_buf[m_size][block_n]; // copy bias to y_buf if bias is not None auto bias_data = bias_ptr ? 
bias_ptr + nc * block_n : nullptr; - copy_bias(bias_data, y_buf[0], m_size, block_n); + copy_bias(bias_data, y_buf[0], m_size); for (int kci = 0; kci < Kc; ++kci) { - _dequant_gemm_accum( + _dequant_gemm_accum( y_buf[0] /*C*/, a_ptr + mci * block_m * K + kci * block_k /*A*/, a_scales_ptr + mci * block_m /*scales_a*/, @@ -637,14 +639,16 @@ void _da8w4_linear_impl( b_qzeros_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*qzeros_b*/, compensation_ptr + nc * block_n * Kc + kci * block_n /*compensation*/, m_size /*M*/, - block_n /*N*/, block_k /*K*/, K /*lda*/, - block_n / 2 /*ldb*/, block_n /*ldc*/); } - // store y_buf to output - store_out(y_buf[0], c_ptr + mci * block_m * N + nc * block_n, m_size, block_n, N); + // store y_buf to output with dtype conversion + store_out( + y_buf[0], + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/); } } if constexpr (use_cpublas) { From e05b96a4bd532acf6404cd7906990ae1e0ff81be Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 4 Jun 2025 14:02:14 +0000 Subject: [PATCH 13/18] Support symmetric quantization of activation --- test/quantization/test_quant_api.py | 21 ++- torchao/csrc/cpu/da8w4_linear.cpp | 165 ++++++++++++++++-------- torchao/dtypes/uintx/int4_cpu_layout.py | 16 ++- 3 files changed, 139 insertions(+), 63 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 81a7ce783c..b4a8cef5c0 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -886,7 +886,11 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("bias", [True, False]) @common_utils.parametrize("bs", [1, 160]) - def test_8da4w_cpu(self, dtype, x_dim, bias, bs): + @common_utils.parametrize("sym_quant_a", [True, False]) + def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): + if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8: + # not supported until PT 2.8 + return device = "cpu" m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) m2 = copy.deepcopy(m) @@ -900,7 +904,11 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs): quantize_( m, int8_dynamic_activation_int4_weight( - group_size=32, layout=Int8DynamicActInt4WeightCPULayout() + group_size=32, + layout=Int8DynamicActInt4WeightCPULayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, ), ) y, code = torch._inductor.utils.run_and_get_code( @@ -914,15 +922,18 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs): int8_dynamic_activation_int4_weight( group_size=32, layout=PlainLayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, ), ) torch._dynamo.reset() # may segfault without this y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - atol, rtol = 1e-7, 1e-5 + atol, rtol = 4e-7, 1e-5 if dtype == torch.bfloat16: - atol, rtol = 0.01, 1e-3 + atol, rtol = 1e-2, 3e-3 elif dtype == torch.half: - atol, rtol = 0.005, 1e-3 + atol, rtol = 6e-3, 2e-3 assert torch.allclose(y, y2, atol=atol, rtol=rtol) # Test get_plain by dequantize() dqw1 = m.linear1.weight.original_weight_tensor.dequantize() diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index f8f125695e..d21c8f3d93 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -139,6 +139,19 @@ da8w4_linear_prepack_impl( return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales), std::move(blocked_qzeros), 
std::move(compensation)); } +template +struct ActDtype; +template<> +struct ActDtype { + using type = int8_t; +}; + +template<> +struct ActDtype { + using type = uint8_t; +}; + + #if defined(CPU_CAPABILITY_AVX512) inline std::array<__m256i, 2> load_zps_4vnni(const int8_t* __restrict__ zps) { // broadcast 01234567 to @@ -220,7 +233,7 @@ void _dequant_weight_zp_only( } } -template +template void _dequant_and_store( float* __restrict__ output, const int32_t* __restrict__ input, @@ -234,15 +247,21 @@ void _dequant_and_store( int ldsa = 1) { for (int m = 0; m < M; ++m) { float a_scale = *(scale_a + m * ldsa); - int32_t a_zp = *(zp_a + m * ldsa); __m512 va_scale = _mm512_set1_ps(a_scale); - __m512i va_zp = _mm512_set1_epi32(a_zp); + int32_t a_zp; + __m512i va_zp; + if constexpr (!sym_quant_a) { + a_zp = *(zp_a + m * ldsa); + va_zp = _mm512_set1_epi32(a_zp); + } int n = 0; #pragma GCC unroll 2 for (; n < N; n += 16) { - __m512i va = _mm512_loadu_si512(input + m * ldi + n); - __m512i vb_comp = _mm512_loadu_si512(comp_b + n); - __m512i vc = _mm512_sub_epi32(va, _mm512_mullo_epi32(vb_comp, va_zp)); + __m512i vc = _mm512_loadu_si512(input + m * ldi + n); + if constexpr (!sym_quant_a) { + __m512i vb_comp = _mm512_loadu_si512(comp_b + n); + vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp)); + } __m512 vc_f = _mm512_cvtepi32_ps(vc); __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); __m512 vb_s = _mm512_loadu_ps(scale_b + n); @@ -255,8 +274,13 @@ void _dequant_and_store( } } for (; n < N; ++n) { - float dq_val = + float dq_val; + if constexpr (sym_quant_a) { + dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n]; + } else { + dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale * scale_b[n]; + } if constexpr (accum) { output[m * ldo + n] += dq_val; } else { @@ -295,7 +319,14 @@ inline __m512i combine_m256i(std::array<__m256i, 2> two_256) { return combine_m256i(two_256[0], two_256[1]); } -template +// negate elements in a according to b's sign +static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) { + __m512i zero = _mm512_setzero_si512(); + __mmask64 blt0 = _mm512_movepi8_mask(b); + return _mm512_mask_sub_epi8(a, blt0, zero, a); +} + +template void _dequant_gemm_accum_small_M( float* __restrict__ C, const uint8_t* A, @@ -307,6 +338,7 @@ void _dequant_gemm_accum_small_M( int64_t K, int64_t lda, int64_t ldc) { + // if sym_quant_a is true, A pointer type is passed in as uint8_t* but actually int8_t*. 
constexpr int COLS = N / 16; // Computing compensation is faster than loading it for small M @@ -323,7 +355,9 @@ void _dequant_gemm_accum_small_M( c10::ForcedUnroll{}([&](auto i) { vscales[i] = _mm512_loadu_ps(scales_b + i * 16); vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16)); - vcompensate[i] = _mm512_setzero_epi32(); + if constexpr (!sym_quant_a) { + vcompensate[i] = _mm512_setzero_epi32(); + } }); c10::ForcedUnroll{}( [&](auto i) { vc[i] = _mm512_setzero_epi32(); }); @@ -340,11 +374,19 @@ void _dequant_gemm_accum_small_M( int B_offset = k * ldb + col * 16 * 2; vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset)); vb[col] = _mm512_sub_epi8(vb[col], vzps[col]); + if constexpr (!sym_quant_a) { vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]); + } _mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0); } - vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + if constexpr (sym_quant_a) { + auto vsb = _mm512_sign_epi8(vb[col], va); + auto vabsa = _mm512_sign_epi8(va, va); + vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb); + } else { + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + } }; // Accumulate along k @@ -366,10 +408,12 @@ void _dequant_gemm_accum_small_M( constexpr const int col = i % COLS; // compute (qC - compensate * zp_a) * scale_a * scale_b __m512 vc_float; - vc[i] = _mm512_sub_epi32( - vc[i], - _mm512_mullo_epi32( - vcompensate[col], _mm512_set1_epi32(*(qzeros_a + row)))); + if constexpr (!sym_quant_a) { + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32( + vcompensate[col], _mm512_set1_epi32(*(qzeros_a + row)))); + } vc_float = _mm512_cvtepi32_ps(vc[i]); vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row))); @@ -382,8 +426,8 @@ void _dequant_gemm_accum_small_M( } -#define call_dequant_gemm_accum_small_M(M, N, ldb) \ - _dequant_gemm_accum_small_M( \ +#define call_dequant_gemm_accum_small_M(M) \ + _dequant_gemm_accum_small_M( \ C, \ A, \ scales_a, \ @@ -396,7 +440,7 @@ void _dequant_gemm_accum_small_M( ldc); #endif -template +template void _dequant_gemm_accum( float* C, const uint8_t* A, @@ -416,17 +460,16 @@ void _dequant_gemm_accum( if (M <= 4) { switch (M) { case 1: - call_dequant_gemm_accum_small_M(1, N, ldb); - // CALL_MICRO_GEMM_KERNEL(1); + call_dequant_gemm_accum_small_M(1); return; case 2: - call_dequant_gemm_accum_small_M(2, N, ldb); + call_dequant_gemm_accum_small_M(2); return; case 3: - call_dequant_gemm_accum_small_M(3, N, ldb); + call_dequant_gemm_accum_small_M(3); return; case 4: - call_dequant_gemm_accum_small_M(4, N, ldb); + call_dequant_gemm_accum_small_M(4); return; } } @@ -434,6 +477,8 @@ void _dequant_gemm_accum( int8_t dqB[K * N]; _dequant_weight_zp_only(B, dqB, qzeros_b, K); + using Tin = typename ActDtype::type; + Tin* A_ptr = (Tin*)A; #if defined(CPU_CAPABILITY_AVX512) if constexpr (use_cpublas) { int32_t C_i32[M * N]; @@ -445,13 +490,13 @@ void _dequant_gemm_accum( N /*ldb*/, N /*ldc*/, false /* add_C */, - A, + A_ptr, dqB, C_i32, true /* is_vnni */); _mm_prefetch(B + N * K / 2, _MM_HINT_T0); _mm_prefetch(A + K, _MM_HINT_T0); - _dequant_and_store( + _dequant_and_store( C, C_i32, scales_a, @@ -469,7 +514,11 @@ void _dequant_gemm_accum( for (int64_t j = 0; j < N; ++j) { float sum = 0; for (int64_t k = 0; k < K; ++k) { - sum += ((int32_t)A[i * lda + k] - qzeros_a[i]) * (int32_t)dqB[k * N + j]; + if constexpr (sym_quant_a) { + sum += ((int32_t)A_ptr[i * lda + k] * dqB[k * N + j]); + } else { + sum += ((int32_t)A_ptr[i * lda + k] - qzeros_a[i]) * (int32_t)dqB[k * N + j]; + } } C[i * ldc + j] += sum 
* scales_a[i] * scales_b[j]; } @@ -555,7 +604,7 @@ inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_ } } -template +template void _da8w4_linear_impl( const at::Tensor& input, const at::Tensor& input_scales, @@ -606,13 +655,14 @@ void _da8w4_linear_impl( "DA8W4 CPU: group_size should be divisible by block_k"); int64_t block_per_group = group_size / block_k; - const uint8_t* a_ptr = input_view.data_ptr(); + using Tin = typename ActDtype::type; + const Tin* a_ptr = input_view.data_ptr(); const float* a_scales_ptr = input_scales.data_ptr(); - const int32_t* a_qzeros_ptr = input_qzeros.data_ptr(); + const int32_t* a_qzeros_ptr = sym_quant_a ? nullptr : input_qzeros.data_ptr(); const uint8_t* b_ptr = weight.data_ptr(); const float* b_scales_ptr = weight_scales.data_ptr(); const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr(); - const int32_t* compensation_ptr = compensation.data_ptr(); + const int32_t* compensation_ptr = sym_quant_a ? nullptr : compensation.data_ptr(); out_dtype* c_ptr = output.data_ptr(); const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; @@ -629,9 +679,9 @@ void _da8w4_linear_impl( auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; copy_bias(bias_data, y_buf[0], m_size); for (int kci = 0; kci < Kc; ++kci) { - _dequant_gemm_accum( + _dequant_gemm_accum( y_buf[0] /*C*/, - a_ptr + mci * block_m * K + kci * block_k /*A*/, + (uint8_t*)a_ptr + mci * block_m * K + kci * block_k /*A*/, a_scales_ptr + mci * block_m /*scales_a*/, a_qzeros_ptr + mci * block_m /*qzeros_a*/, b_ptr + (nc * Kc + kci) * block_n * block_k / 2 /*B*/, @@ -668,38 +718,39 @@ at::Tensor da8w4_linear_impl( const std::optional& bias, at::ScalarType output_dtype) { static bool use_cpublas = da8w4_can_pack_weight(); + bool sym_quant_a = input.scalar_type() == c10::kChar; auto out_sizes = input.sizes().vec(); int64_t N = weight.size(0) * weight.size(-1) * 2; out_sizes.back() = N; auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); - if (use_cpublas) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { - _da8w4_linear_impl( - input, - input_scales, - input_qzeros, - weight, - weight_scales, - weight_qzeros, - compensation, - bias, - output); + +#define call__da8w4_linear_impl(use_cpublas, sym_quant_act) \ + AT_DISPATCH_FLOATING_TYPES_AND2( \ + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { \ + _da8w4_linear_impl( \ + input, \ + input_scales, \ + input_qzeros, \ + weight, \ + weight_scales, \ + weight_qzeros, \ + compensation, \ + bias, \ + output); \ }); + + if (use_cpublas) { + if (sym_quant_a) { + call__da8w4_linear_impl(true, true); + } else { + call__da8w4_linear_impl(true, false); + } } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { - _da8w4_linear_impl( - input, - input_scales, - input_qzeros, - weight, - weight_scales, - weight_qzeros, - compensation, - bias, - output); - }); + if (sym_quant_a) { + call__da8w4_linear_impl(false, true); + } else { + call__da8w4_linear_impl(false, false); + } } return output; } diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index ed4cab423f..ae8f3ef893 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -25,6 +25,7 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, 
+ TORCH_VERSION_AT_LEAST_2_8, fill_defaults, ) @@ -537,6 +538,15 @@ def _aqt_is_uint8(aqt): ) +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and aqt.quant_max == 127 + ) + + def _aqt_is_int4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" return ( @@ -553,7 +563,7 @@ def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): and is_device(weight_tensor.device.type, "cpu") and (bias is None or is_device(bias.device.type, "cpu")) and isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_uint8(input_tensor) + and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) and _is_float(input_tensor.dtype) and isinstance(input_tensor._layout, PlainLayout) and isinstance(weight_tensor, AffineQuantizedTensor) @@ -568,6 +578,10 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): assert TORCH_VERSION_AT_LEAST_2_7, ( f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" ) + if _aqt_is_int8(input_tensor): + assert TORCH_VERSION_AT_LEAST_2_8, ( + f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" + ) assert is_device(input_tensor.device.type, "cpu"), ( f"For CPU device only but got: {input_tensor.device}" ) From 18335c6b8913c15cd0fa9601143350624e70fa94 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 4 Jun 2025 14:34:59 +0000 Subject: [PATCH 14/18] Refine code --- torchao/csrc/cpu/da8w4_linear.cpp | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index d21c8f3d93..d2395a1090 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -1,27 +1,8 @@ -#include -// #include -#include -#include +#include #include -#include -#include #include -#include -#include -#include #include -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - -#include -#include -#include -#include - namespace torchao { namespace { From 66ab77ff7125aa025d006c5917420d461fa074b3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 5 Jun 2025 14:53:50 +0000 Subject: [PATCH 15/18] Refine code --- torchao/csrc/cpu/da8w4_linear.cpp | 43 ++++++++++++++----------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index d2395a1090..1df2c13fdf 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -9,20 +9,17 @@ namespace { #define BLOCK_N 32 -static bool use_cpublas_checked = false; -static bool use_cpublas = false; +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; -bool da8w4_can_pack_weight() { -#if defined(CPU_CAPABILITY_AVX512) - if (use_cpublas_checked) { - return use_cpublas; +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; } - use_cpublas = at::native::cpublas::could_pack(at::kByte); - use_cpublas_checked = true; - return use_cpublas; -#else - return false; -#endif + cpublas_can_pack = at::native::cpublas::could_pack(at::kByte); + cpublas_checked = true; + return cpublas_can_pack; } /* @@ -71,7 +68,7 @@ da8w4_linear_prepack_impl( at::Tensor compensation = weight_view.to(at::kInt).sub(8).sum(-1); compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt); - if (da8w4_can_pack_weight()) 
{ + if (cpublas_could_pack()) { blocked_weight = at::empty({Nc, Kc, block_k, block_n / 2}, weight.options()); auto weight_ptr = weight_reordered.data_ptr(); auto blocked_weight_ptr = blocked_weight.data_ptr(); @@ -421,7 +418,7 @@ void _dequant_gemm_accum_small_M( ldc); #endif -template +template void _dequant_gemm_accum( float* C, const uint8_t* A, @@ -438,7 +435,7 @@ void _dequant_gemm_accum( // Compute GEMM int8 * int8 -> int32 // dequant result to float by applying scales/qzeros #if defined(CPU_CAPABILITY_AVX512_VNNI) - if (M <= 4) { + if (M <= 4 && cpublas_can_pack) { switch (M) { case 1: call_dequant_gemm_accum_small_M(1); @@ -461,7 +458,7 @@ void _dequant_gemm_accum( using Tin = typename ActDtype::type; Tin* A_ptr = (Tin*)A; #if defined(CPU_CAPABILITY_AVX512) - if constexpr (use_cpublas) { + if constexpr (cpublas_can_pack) { int32_t C_i32[M * N]; at::native::cpublas::brgemm( M, @@ -585,7 +582,7 @@ inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_ } } -template +template void _da8w4_linear_impl( const at::Tensor& input, const at::Tensor& input_scales, @@ -660,7 +657,7 @@ void _da8w4_linear_impl( auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; copy_bias(bias_data, y_buf[0], m_size); for (int kci = 0; kci < Kc; ++kci) { - _dequant_gemm_accum( + _dequant_gemm_accum( y_buf[0] /*C*/, (uint8_t*)a_ptr + mci * block_m * K + kci * block_k /*A*/, a_scales_ptr + mci * block_m /*scales_a*/, @@ -682,7 +679,7 @@ void _da8w4_linear_impl( N /*lda*/); } } - if constexpr (use_cpublas) { + if constexpr (cpublas_can_pack) { at::native::cpublas::brgemm_release(); } }); @@ -698,17 +695,17 @@ at::Tensor da8w4_linear_impl( const at::Tensor& compensation, const std::optional& bias, at::ScalarType output_dtype) { - static bool use_cpublas = da8w4_can_pack_weight(); + static bool cpublas_can_pack = cpublas_could_pack(); bool sym_quant_a = input.scalar_type() == c10::kChar; auto out_sizes = input.sizes().vec(); int64_t N = weight.size(0) * weight.size(-1) * 2; out_sizes.back() = N; auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); -#define call__da8w4_linear_impl(use_cpublas, sym_quant_act) \ +#define call__da8w4_linear_impl(cpublas_can_pack, sym_quant_act) \ AT_DISPATCH_FLOATING_TYPES_AND2( \ at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "da8w4_linear_cpu", [&] { \ - _da8w4_linear_impl( \ + _da8w4_linear_impl( \ input, \ input_scales, \ input_qzeros, \ @@ -720,7 +717,7 @@ at::Tensor da8w4_linear_impl( output); \ }); - if (use_cpublas) { + if (cpublas_can_pack) { if (sym_quant_a) { call__da8w4_linear_impl(true, true); } else { From 75fbd6f87d2b22a71e6ac275c32442d248952615 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 14 Jun 2025 17:05:28 +0000 Subject: [PATCH 16/18] Put in a separate file --- torchao/dtypes/affine_quantized_tensor_ops.py | 6 +- torchao/dtypes/uintx/__init__.py | 4 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 312 ++++++++++++++++++ torchao/dtypes/uintx/int4_cpu_layout.py | 291 +--------------- torchao/quantization/quant_api.py | 10 + 5 files changed, 330 insertions(+), 293 deletions(-) create mode 100644 torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 5f0da0ef08..a248dfffef 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -35,6 +35,10 @@ _linear_int8_act_int4_weight_cutlass_check, 
_linear_int8_act_int4_weight_cutlass_impl, ) +from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, +) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -42,8 +46,6 @@ from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, - _linear_int8_act_int4_weight_cpu_check, - _linear_int8_act_int4_weight_cpu_impl, ) from torchao.dtypes.uintx.int4_xpu_layout import ( _linear_bf16_act_uint4_weight_float_zero_check, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 7c681fd52c..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -4,9 +4,11 @@ from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) +from .dyn_int8_act_int4_wei_cpu_layout import ( + Int8DynamicActInt4WeightCPULayout, +) from .int4_cpu_layout import ( Int4CPULayout, - Int8DynamicActInt4WeightCPULayout, ) from .int4_xpu_layout import ( Int4XPULayout, diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py new file mode 100644 index 0000000000..9fb7079c6c --- /dev/null +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, +) + +from .int4_cpu_layout import ( + Int4CPUAQTTensorImpl, + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # 
type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + self.compensation = compensation + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros, compensation = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + fn(self.compensation), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + args[0].compensation, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int32) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = 
plain_weight.to(torch.int8) + + if self.scales.dim() == 2: + assert self.qzeros.dim() == 2 + plain_scales = self.scales + plain_qzeros = self.qzeros + else: + assert self.scales.dim() == 3 and self.qzeros.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + plain_qzeros = ( + self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, plain_qzeros + + +def _aqt_is_uint8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 255 + ) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and aqt.quant_max == 127 + ) + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + TORCH_VERSION_AT_LEAST_2_7 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) + ) + + +def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert TORCH_VERSION_AT_LEAST_2_7, ( + f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" + ) + if _aqt_is_int8(input_tensor): + assert TORCH_VERSION_AT_LEAST_2_8, ( + f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + act_qzeros = act_mat.tensor_impl.zero_point + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + wei_qzeros = weight_tensor.tensor_impl.qzeros + compensation = weight_tensor.tensor_impl.compensation + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.da8w4_linear_cpu.default( + act.contiguous(), + act_scales, + act_qzeros, + packed_weight, + wei_scales, + wei_qzeros, + compensation, + bias, + orig_dtype, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py 
b/torchao/dtypes/uintx/int4_cpu_layout.py index e32791fa37..be72d0e205 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -16,7 +16,7 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout, is_device +from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ( ZeroPointDomain, quantize_affine_tinygemm, @@ -24,8 +24,6 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_7, - TORCH_VERSION_AT_LEAST_2_8, fill_defaults, ) @@ -340,290 +338,3 @@ def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): if bias is not None: y += bias return y.to(orig_dtype) - - -@dataclass(frozen=True) -class Int8DynamicActInt4WeightCPULayout(Layout): - """Layout class for da8w4 CPU layout for affine quantized tensor""" - - pass - - -@register_layout(Int8DynamicActInt4WeightCPULayout) -class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): - """TensorImpl for da8w4 CPU layout for affine quantized tensor - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of - dimension: [n][k / 2] (uint8 dtype) - It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data - fields: - packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout - scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor - qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scales = scales - self.qzeros = qzeros - self.compensation = compensation - self.transposed = transposed - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scales", "qzeros", "compensation"], [ - self.transposed, - self._layout, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scales, qzeros, compensation = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scales"], - tensor_data_dict["qzeros"], - tensor_data_dict["compensation"], - ) - ( - transposed, - _layout, - ) = tensor_attributes - return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) - assert int_data.dtype == torch.int8, "DA8W4 CPU: expects uint8 weight" - assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" - # int8 -> uint8 - int_data = (int_data + 
8).to(torch.uint8) - if scale.dim() == 1: - scale.unsqueeze_(-1) - scale = scale.to(torch.float) - if zero_point.dim() == 1: - zero_point.unsqueeze_(-1) - zero_point = zero_point.to(torch.int8) + 8 - - weight_int4, scales, qzeros, compensation = ( - torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) - ) - return cls(weight_int4, scales, qzeros, compensation, False, _layout) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scales), - fn(self.qzeros), - fn(self.compensation), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = DA8W4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scales, - args[0].qzeros, - args[0].compensation, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - else: - return super().__torch_dispatch__(func, types, args, kwargs) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @property - def block_size(self): - assert len(self.packed_weight.shape) == 2 - weight_shape = self.packed_weight.shape - N = weight_shape[0] - K = weight_shape[1] * 2 - groups = self.scales.numel() // N - group_size = K // groups - return (1, group_size) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Unpack weight by linear(eye(K), packed_weight).t() - packed_w_shape = self.packed_weight.shape - if len(packed_w_shape) == 4: - K = packed_w_shape[1] * packed_w_shape[2] - else: - K = packed_w_shape[1] - x = torch.eye(K).to(torch.uint8) - x_scale = torch.ones(K).float() - x_qzero = torch.zeros(K).to(torch.int32) - w_scale = torch.ones_like(self.scales).float() - w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) - plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( - x, - x_scale, - x_qzero, - self.packed_weight, - w_scale, - w_qzero, - self.compensation, - None, # bias - torch.float, # out_dtype - ) - plain_weight = plain_weight.t().contiguous() - plain_weight = plain_weight.to(torch.int8) - - if self.scales.dim() == 2: - assert self.qzeros.dim() == 2 - plain_scales = self.scales - plain_qzeros = self.qzeros - else: - assert self.scales.dim() == 3 and self.qzeros.dim() == 3 - packed_shape = self.scales.shape # [Nc, G, block_n] - plain_scales = ( - self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - plain_qzeros = ( - self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - - return plain_weight, plain_scales, plain_qzeros - - -def _aqt_is_uint8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 255 - ) - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -127 - and aqt.quant_max == 127 - ) - - -def _aqt_is_int4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == -8 - and aqt.quant_max == 7 - ) - - -def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): - ret = ( - TORCH_VERSION_AT_LEAST_2_7 - and 
is_device(input_tensor.device.type, "cpu") - and is_device(weight_tensor.device.type, "cpu") - and (bias is None or is_device(bias.device.type, "cpu")) - and isinstance(input_tensor, AffineQuantizedTensor) - and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) - and _is_float(input_tensor.dtype) - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_int4(weight_tensor) - and _is_float(weight_tensor.dtype) - and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) - ) - return ret - - -def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_7, ( - f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" - ) - if _aqt_is_int8(input_tensor): - assert TORCH_VERSION_AT_LEAST_2_8, ( - f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" - ) - assert is_device(input_tensor.device.type, "cpu"), ( - f"For CPU device only but got: {input_tensor.device}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - act_mat = input_tensor - act = act_mat.tensor_impl.int_data - act_scales = act_mat.tensor_impl.scale - act_qzeros = act_mat.tensor_impl.zero_point - - packed_weight = weight_tensor.tensor_impl.packed_weight - wei_scales = weight_tensor.tensor_impl.scales - wei_qzeros = weight_tensor.tensor_impl.qzeros - compensation = weight_tensor.tensor_impl.compensation - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape to 2D - act = act.reshape(-1, act.shape[-1]) - - y = torch.ops.torchao.da8w4_linear_cpu.default( - act.contiguous(), - act_scales, - act_qzeros, - packed_weight, - wei_scales, - wei_qzeros, - compensation, - bias, - orig_dtype, # out_dtype - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index d68827773d..18b6b06e08 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -788,6 +788,16 @@ def _int8_dynamic_activation_int4_weight_transform( ) elif isinstance(layout, CutlassInt4PackedLayout): weight = _int4_symm_cutlass_quant(weight) + elif isinstance(layout, Int8DynamicActInt4WeightCPULayout): + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype=torch.uint8, + quant_min=0, + quant_max=15, + _layout=layout, + ) else: weight = to_affine_quantized_intx( weight, From 4c0a7396bf28cead1ba82da414b3d681b6659256 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 25 Jun 2025 13:38:58 +0000 Subject: [PATCH 17/18] Bug fix --- torchao/csrc/cpu/da8w4_linear.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/da8w4_linear.cpp index 1df2c13fdf..537aa0fce9 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/da8w4_linear.cpp @@ -64,8 +64,10 @@ da8w4_linear_prepack_impl( at::Tensor blocked_weight; at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); at::Tensor blocked_qzeros = new_qzeros.view({Nc, 
block_n, G}).permute({0, 2, 1}).contiguous(); - // weight was increased by 8 during quantization, so we need to subtract 8 - at::Tensor compensation = weight_view.to(at::kInt).sub(8).sum(-1); + // Compensation = Σ(k)(W[k][n] - ZP[n]) for each block. + auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) - new_qzeros.view({Nc, block_n, G, -1}); + weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, block_k}); + at::Tensor compensation = weight_sub_qzero.sum(-1); compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt); if (cpublas_could_pack()) { From e3731f720f2dd7da50f6ef37bbdbb53895fa5b6b Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 25 Jun 2025 14:47:49 +0000 Subject: [PATCH 18/18] refine code --- test/quantization/test_quant_api.py | 2 +- torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py | 2 +- torchao/quantization/quant_api.py | 5 ----- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index a6d46d10cb..2bb20d5afd 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -723,7 +723,7 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): # is that the former packs two int4 weights into one int8, while the latter does not. quantize_( m, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=32, layout=Int8DynamicActInt4WeightCPULayout(), act_mapping_type=MappingType.SYMMETRIC diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 9fb7079c6c..ced7ec0dd8 100644 --- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -300,7 +300,7 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): wei_scales, wei_qzeros, compensation, - bias, + bias.float() if bias is not None else bias, # requires bias to be float orig_dtype, # out_dtype ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f3824796b1..7287ae2bc0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -762,11 +762,6 @@ def _int8_dynamic_activation_int4_weight_transform( quant_min = -8 quant_max = 7 - if isinstance(layout, Int8DynamicActInt4WeightCPULayout): - # Int8DynamicActInt4WeightCPULayout requires bias to be in float32 - if module.bias is not None: - module.bias = torch.nn.Parameter(module.bias.float(), requires_grad=False) - # input settings if act_mapping_type == MappingType.ASYMMETRIC: if isinstance(layout, Int8DynamicActInt4WeightCPULayout):
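For reference, a minimal usage sketch of the DA8W4 CPU path added by this series, assuming torchao is built with the da8w4 CPU kernels and PyTorch 2.7+ (2.8+ for symmetric activation quantization). The layer sizes, group_size=32, and import paths follow the test file in the patches above and are illustrative assumptions, not requirements of the kernel.

    # Sketch only: quantize a small float model with the new CPU layout.
    # Assumes the torchao::da8w4_linear_cpu C++ kernel was built (see setup.py changes above).
    import torch
    from torchao.dtypes import Int8DynamicActInt4WeightCPULayout
    from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
    from torchao.quantization.quant_primitives import MappingType

    # Feature sizes chosen so K is divisible by group_size (32) and N by block_n (32).
    model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=True)).eval().to(torch.bfloat16)

    quantize_(
        model,
        Int8DynamicActivationInt4WeightConfig(
            group_size=32,
            layout=Int8DynamicActInt4WeightCPULayout(),
            # MappingType.SYMMETRIC needs PyTorch 2.8+ per patch 13; ASYMMETRIC otherwise.
            act_mapping_type=MappingType.ASYMMETRIC,
        ),
    )

    with torch.no_grad():
        x = torch.randn(4, 128, dtype=torch.bfloat16)
        # Compiled path should dispatch to torch.ops.torchao.da8w4_linear_cpu,
        # as asserted in test_8da4w_cpu above.
        y = torch.compile(model, dynamic=True)(x)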