From 2b1c17b10d6f407017abdf72f09c1bdfb1eda4ff Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Mon, 30 Sep 2024 14:40:27 -0700
Subject: [PATCH] Change arg order in lowbit linear ops to match aten (#982)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/982

Discussed with Manuel to align on arg order between CPU/MPS ops.

Reviewed By: digantdesai, manuelcandales

Differential Revision: D63422524
---
 .../op_linear_8bit_act_xbit_weight-impl.h     | 20 ++--
 .../op_linear_8bit_act_xbit_weight_aten.cpp   | 95 +++++++++++--------
 .../w1s.cpp                                   | 29 ++++++
 .../w1sz.cpp                                  | 29 ++++++
 .../w2s.cpp                                   |  8 +-
 .../w2sz.cpp                                  |  8 +-
 .../w3s.cpp                                   |  8 +-
 .../w3sz.cpp                                  |  8 +-
 .../w4s.cpp                                   |  8 +-
 .../w4sz.cpp                                  |  8 +-
 .../w5s.cpp                                   |  8 +-
 .../w5sz.cpp                                  |  8 +-
 torchao/experimental/quant_api.py             | 18 ++--
 ...t_linear_8bit_act_xbit_weight_quantizer.py |  1 -
 14 files changed, 164 insertions(+), 92 deletions(-)
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1sz.cpp

diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
index 51a02d264c..80772c7c10 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
@@ -218,14 +218,14 @@ Tensor pack_weights_with_zeros_meta(
 #if defined(USE_ATEN) || defined(USE_EXECUTORCH)
 template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_out_cpu(
+    const Tensor& activations,
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI Currently they are tensors with size
     // equal to (0, the int they wrap)
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   int n = n_tensor.size(1);
   int k = k_tensor.size(1);
@@ -307,21 +307,21 @@ Tensor linear_out_cpu(
 #ifdef USE_ATEN
 template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_cpu(
+    const Tensor& activations,
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI Currently they are tensors with size
     // equal to (0, the int they wrap)
-    const Tensor& n_tensor,
-    const Tensor& k_tensor,
     const Tensor& group_size_tensor,
-    const Tensor& activations) {
+    const Tensor& n_tensor,
+    const Tensor& k_tensor) {
   Tensor output_tensor = torch::empty({}, torch::kFloat32);
   linear_out_cpu<weight_nbit, has_weight_zeros>(
+      activations,
       packed_weights,
+      group_size_tensor,
       n_tensor,
       k_tensor,
-      group_size_tensor,
-      activations,
       output_tensor);
   return output_tensor;
 }
@@ -330,14 +330,14 @@ Tensor linear_cpu(
 #ifdef USE_ATEN
 template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_meta(
+    const Tensor& activations,
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI
     // Currently they are tensors with size equal to (0, the int they wrap)
-    const Tensor& n_tensor,
-    const Tensor& k_tensor,
     const Tensor& group_size_tensor,
-    const Tensor& activations) {
+    const Tensor& n_tensor,
+    const Tensor& k_tensor) {
   int n = n_tensor.size(1);
   int k = k_tensor.size(1);
   CHECK_MSG(n >= 1, "n must be >= 1");
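With this change every linear op takes activations first, mirroring aten.linear's (input, weight) order, while n, k, and group_size stay wrapped in size-(0, value) tensors per the TODO(T200095131) workaround above. A minimal sketch of a call from Python against the renamed 4-bit op registered below (tensor names are placeholders; full setup appears at the end of this patch):

    # before: torch.ops.torchao._linear_a8sz_w4s(
    #     packed_weights, n_tensor, k_tensor, group_size_tensor, activations)
    # after, with activations leading:
    out = torch.ops.torchao._linear_8bit_act_4bit0zp_weight(
        activations, packed_weights, group_size_tensor, n_tensor, k_tensor
    )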
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
index e098bf4477..6c4d600e6f 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
@@ -6,53 +6,62 @@
 
 #include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
 
-#define DEFINE_OP(weight_nbit) \
-  m.def( \
-      "_pack_weights_a8sz_w" #weight_nbit \
-      "s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); \
-  m.def( \
-      "_pack_weights_a8sz_w" #weight_nbit \
-      "sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); \
-  m.def( \
-      "_linear_a8sz_w" #weight_nbit \
-      "s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); \
-  m.def( \
-      "_linear_a8sz_w" #weight_nbit \
-      "sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); \
-  m.def( \
-      "_linear_a8sz_w" #weight_nbit \
-      "s.out(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations, *, Tensor(a!) out) -> Tensor(a!)"); \
-  m.def( \
-      "_linear_a8sz_w" #weight_nbit \
-      "sz.out(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations, *, Tensor(a!) out) -> Tensor(a!)")
+#define DEFINE_OP(weight_nbit) \
+  m.def( \
+      "_pack_8bit_act_" #weight_nbit \
+      "bit0zp_weight(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); \
+  m.def( \
+      "_pack_8bit_act_" #weight_nbit \
+      "bit_weight(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); \
+  m.def( \
+      "_linear_8bit_act_" #weight_nbit \
+      "bit0zp_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
+  m.def( \
+      "_linear_8bit_act_" #weight_nbit \
+      "bit_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
+  m.def( \
+      "_linear_8bit_act_" #weight_nbit \
+      "bit0zp_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)"); \
+  m.def( \
+      "_linear_8bit_act_" #weight_nbit \
+      "bit_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)")
out) -> Tensor(a!)") -#define DEFINE_CPU_IMPL(weight_nbit) \ - m.impl( \ - "_pack_weights_a8sz_w" #weight_nbit "s", \ - &pack_weights_without_zeros_cpu); \ - m.impl( \ - "_pack_weights_a8sz_w" #weight_nbit "sz", \ - &pack_weights_with_zeros_cpu); \ - m.impl("_linear_a8sz_w" #weight_nbit "s", &linear_cpu); \ - m.impl("_linear_a8sz_w" #weight_nbit "sz", &linear_cpu); \ - m.impl( \ - "_linear_a8sz_w" #weight_nbit "s.out", \ - &linear_out_cpu); \ - m.impl( \ - "_linear_a8sz_w" #weight_nbit "sz.out", \ +#define DEFINE_CPU_IMPL(weight_nbit) \ + m.impl( \ + "_pack_8bit_act_" #weight_nbit "bit0zp_weight", \ + &pack_weights_without_zeros_cpu); \ + m.impl( \ + "_pack_8bit_act_" #weight_nbit "bit_weight", \ + &pack_weights_with_zeros_cpu); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit0zp_weight", \ + &linear_cpu); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit_weight", \ + &linear_cpu); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit0zp_weight.out", \ + &linear_out_cpu); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit_weight.out", \ &linear_out_cpu) -#define DEFINE_META_IMPL(weight_nbit) \ - m.impl( \ - "_pack_weights_a8sz_w" #weight_nbit "s", \ - &pack_weights_without_zeros_meta); \ - m.impl( \ - "_pack_weights_a8sz_w" #weight_nbit "sz", \ - &pack_weights_with_zeros_meta); \ - m.impl("_linear_a8sz_w" #weight_nbit "s", &linear_meta); \ - m.impl("_linear_a8sz_w" #weight_nbit "sz", &linear_meta); +#define DEFINE_META_IMPL(weight_nbit) \ + m.impl( \ + "_pack_8bit_act_" #weight_nbit "bit0zp_weight", \ + &pack_weights_without_zeros_meta); \ + m.impl( \ + "_pack_8bit_act_" #weight_nbit "bit_weight", \ + &pack_weights_with_zeros_meta); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit0zp_weight", \ + &linear_meta); \ + m.impl( \ + "_linear_8bit_act_" #weight_nbit "bit_weight", \ + &linear_meta); TORCH_LIBRARY(torchao, m) { + DEFINE_OP(1); DEFINE_OP(2); DEFINE_OP(3); DEFINE_OP(4); @@ -60,6 +69,7 @@ TORCH_LIBRARY(torchao, m) { } TORCH_LIBRARY_IMPL(torchao, CPU, m) { + DEFINE_CPU_IMPL(1); DEFINE_CPU_IMPL(2); DEFINE_CPU_IMPL(3); DEFINE_CPU_IMPL(4); @@ -67,6 +77,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { } TORCH_LIBRARY_IMPL(torchao, Meta, m) { + DEFINE_META_IMPL(1); DEFINE_META_IMPL(2); DEFINE_META_IMPL(3); DEFINE_META_IMPL(4); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp new file mode 100644 index 0000000000..2ce4c42b51 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp
new file mode 100644
index 0000000000..2ce4c42b51
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu<1, false>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_1bit0zp_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1sz.cpp
new file mode 100644
index 0000000000..6767def428
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1sz.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu<1, true>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_1bit_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp
index 06217c3d88..5be7e87338 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<2, false>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w2s.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_2bit0zp_weight.out", _op_out);
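Each of these ExecuTorch files registers a single `.out` overload. The same overload is exposed through the ATen schemas above and can be called with a preallocated output; a sketch (tensor setup as in the earlier sketch, and assuming the kernel resizes `out` as needed):

    out = torch.empty(0, dtype=torch.float32)
    torch.ops.torchao._linear_8bit_act_2bit0zp_weight.out(
        activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out=out
    )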
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp
index a539781904..f6774308e2 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<2, true>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w2sz.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_2bit_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3s.cpp
index 2bfe0b9386..d12043f243 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3s.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3s.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<3, false>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w3s.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_3bit0zp_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp
index 8045c15901..0bf2407dbb 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<3, true>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w3sz.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_3bit_weight.out", _op_out);
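The pack ops follow the same renaming but keep their argument order; a sketch of packing 3-bit weights with zero points (quantized inputs as produced by `_quantize` in quant_api.py; names are placeholders):

    packed_weights = torch.ops.torchao._pack_8bit_act_3bit_weight(
        weight_qvals, weight_scales, weight_zeros, group_size_tensor
    )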
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp
index 5c6947f1be..709a3eccdf 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<4, false>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w4s.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_4bit0zp_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp
index 35776bf271..c1dcd371b6 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<4, true>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w4sz.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_4bit_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5s.cpp
index e2b250dfa7..af732f4bb7 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5s.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5s.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<5, false>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w5s.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_5bit0zp_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5sz.cpp
index 6578224b00..8b61b5199c 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5sz.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w5sz.cpp
@@ -13,17 +13,17 @@
 namespace {
 Tensor _op_out(
     RuntimeContext& ctx,
+    const Tensor& activations,
     const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
     const Tensor& n_tensor,
     const Tensor& k_tensor,
-    const Tensor& group_size_tensor,
-    const Tensor& activations,
     Tensor& out) {
   (void)ctx;
   linear_out_cpu<5, true>(
-      packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
   return out;
 }
 } // namespace
 
-EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w5sz.out", _op_out);
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_5bit_weight.out", _op_out);
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 531d7efaaa..ac21c75221 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -84,7 +84,6 @@ def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros)
         self._n = torch.empty(0, n, dtype=torch.int8)
         self._k = torch.empty(0, k, dtype=torch.int8)
         self._group_size = torch.empty(0, group_size, dtype=torch.int8)
-
         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, self.has_weight_zeros
@@ -105,7 +104,11 @@ def forward(self, x):
         assert x.dim() >= 2
         if x.dim() == 2:
             return self._linear_op(
-                self.packed_weights, self._n, self._k, self._group_size, x
+                x,
+                self.packed_weights,
+                self._group_size,
+                self._n,
+                self._k,
             )
 
         assert x.dim() >= 3
@@ -116,7 +119,7 @@ def forward(self, x):
         res = [
             self._linear_op(
-                self.packed_weights, self._n, self._k, self._group_size, x[i, :, :]
+                x[i, :, :], self.packed_weights, self._group_size, self._n, self._k
             )
             for i in range(x.shape[0])
         ]
@@ -203,14 +206,15 @@ def forward(self, x):
 
 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [2, 3, 4, 5]:
-            wzp_suffix = "z" if has_weight_zeros else ""
+        if nbit in [1, 2, 3, 4, 5]:
+            wzp_suffix = "" if has_weight_zeros else "0zp"
             return _Int8DynActIntxWeightQuantizedLinearNative(
                 pack_weight_op=getattr(
-                    torch.ops.torchao, f"_pack_weights_a8sz_w{nbit}s{wzp_suffix}"
+                    torch.ops.torchao,
+                    f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight",
                 ),
                 linear_op=getattr(
-                    torch.ops.torchao, f"_linear_a8sz_w{nbit}s{wzp_suffix}"
+                    torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"
                 ),
             )
         else:
diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
index dfb999cb17..1966fd1589 100644
--- a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
+++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
@@ -30,7 +30,6 @@
 else:
     torch.ops.load_library(libs[0])
 
-
 class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
     def test_accuracy(self):
         group_size = 128
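Taken together, the Python-side flow after this change, condensed from _Int8DynActIntxWeightQuantizedLinearNative above (shapes and tensor names are hypothetical):

    import torch

    n, k, group_size = 256, 512, 128
    n_tensor = torch.empty(0, n, dtype=torch.int8)
    k_tensor = torch.empty(0, k, dtype=torch.int8)
    group_size_tensor = torch.empty(0, group_size, dtype=torch.int8)

    # pack once at quantization time (scale-only 4-bit variant)
    packed_weights = torch.ops.torchao._pack_8bit_act_4bit0zp_weight(
        weight_qvals, weight_scales, group_size_tensor
    )
    # at inference time, activations lead to match aten.linear
    out = torch.ops.torchao._linear_8bit_act_4bit0zp_weight(
        activations, packed_weights, group_size_tensor, n_tensor, k_tensor
    )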