Skip to content

Commit ae3e7c6

Browse files
authored
Change arg order in lowbit linear ops to match aten (#982)
Change arg order in lowbit linear ops to match aten (#982) Summary: Pull Request resolved: #982 Discussed with Manuel to align on arg order between CPU/MPS ops. Reviewed By: digantdesai, manuelcandales Differential Revision: D63422524
1 parent 68e1886 commit ae3e7c6

File tree

14 files changed

+164
-92
lines changed

14 files changed

+164
-92
lines changed

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,14 @@ Tensor pack_weights_with_zeros_meta(
218218
#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
219219
template <int weight_nbit, bool has_weight_zeros>
220220
Tensor linear_out_cpu(
221+
const Tensor& activations,
221222
const Tensor& packed_weights,
222223
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
223224
// int64_t when supported by AOTI Currently they are tensors with size
224225
// equal to (0, the int they wrap)
226+
const Tensor& group_size_tensor,
225227
const Tensor& n_tensor,
226228
const Tensor& k_tensor,
227-
const Tensor& group_size_tensor,
228-
const Tensor& activations,
229229
Tensor& out) {
230230
int n = n_tensor.size(1);
231231
int k = k_tensor.size(1);
@@ -307,21 +307,21 @@ Tensor linear_out_cpu(
307307
#ifdef USE_ATEN
308308
template <int weight_nbit, bool has_weight_zeros>
309309
Tensor linear_cpu(
310+
const Tensor& activations,
310311
const Tensor& packed_weights,
311312
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
312313
// int64_t when supported by AOTI Currently they are tensors with size
313314
// equal to (0, the int they wrap)
314-
const Tensor& n_tensor,
315-
const Tensor& k_tensor,
316315
const Tensor& group_size_tensor,
317-
const Tensor& activations) {
316+
const Tensor& n_tensor,
317+
const Tensor& k_tensor) {
318318
Tensor output_tensor = torch::empty({}, torch::kFloat32);
319319
linear_out_cpu<weight_nbit, has_weight_zeros>(
320+
activations,
320321
packed_weights,
322+
group_size_tensor,
321323
n_tensor,
322324
k_tensor,
323-
group_size_tensor,
324-
activations,
325325
output_tensor);
326326
return output_tensor;
327327
}
@@ -330,14 +330,14 @@ Tensor linear_cpu(
330330
#ifdef USE_ATEN
331331
template <int weight_nbit, bool has_weight_zeros>
332332
Tensor linear_meta(
333+
const Tensor& activations,
333334
const Tensor& packed_weights,
334335
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
335336
// int64_t when supported by AOTI
336337
// Currently they are tensors with size equal to (0, the int they wrap)
337-
const Tensor& n_tensor,
338-
const Tensor& k_tensor,
339338
const Tensor& group_size_tensor,
340-
const Tensor& activations) {
339+
const Tensor& n_tensor,
340+
const Tensor& k_tensor) {
341341
int n = n_tensor.size(1);
342342
int k = k_tensor.size(1);
343343
CHECK_MSG(n >= 1, "n must be >= 1");

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,67 +6,78 @@
66

77
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
88

9-
#define DEFINE_OP(weight_nbit) \
10-
m.def( \
11-
"_pack_weights_a8sz_w" #weight_nbit \
12-
"s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); \
13-
m.def( \
14-
"_pack_weights_a8sz_w" #weight_nbit \
15-
"sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); \
16-
m.def( \
17-
"_linear_a8sz_w" #weight_nbit \
18-
"s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); \
19-
m.def( \
20-
"_linear_a8sz_w" #weight_nbit \
21-
"sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); \
22-
m.def( \
23-
"_linear_a8sz_w" #weight_nbit \
24-
"s.out(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations, *, Tensor(a!) out) -> Tensor(a!)"); \
25-
m.def( \
26-
"_linear_a8sz_w" #weight_nbit \
27-
"sz.out(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations, *, Tensor(a!) out) -> Tensor(a!)")
9+
#define DEFINE_OP(weight_nbit) \
10+
m.def( \
11+
"_pack_8bit_act_" #weight_nbit \
12+
"bit0zp_weight(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); \
13+
m.def( \
14+
"_pack_8bit_act_" #weight_nbit \
15+
"bit_weight(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); \
16+
m.def( \
17+
"_linear_8bit_act_" #weight_nbit \
18+
"bit0zp_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
19+
m.def( \
20+
"_linear_8bit_act_" #weight_nbit \
21+
"bit_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
22+
m.def( \
23+
"_linear_8bit_act_" #weight_nbit \
24+
"bit0zp_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)"); \
25+
m.def( \
26+
"_linear_8bit_act_" #weight_nbit \
27+
"bit_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)")
2828

29-
#define DEFINE_CPU_IMPL(weight_nbit) \
30-
m.impl( \
31-
"_pack_weights_a8sz_w" #weight_nbit "s", \
32-
&pack_weights_without_zeros_cpu<weight_nbit>); \
33-
m.impl( \
34-
"_pack_weights_a8sz_w" #weight_nbit "sz", \
35-
&pack_weights_with_zeros_cpu<weight_nbit>); \
36-
m.impl("_linear_a8sz_w" #weight_nbit "s", &linear_cpu<weight_nbit, false>); \
37-
m.impl("_linear_a8sz_w" #weight_nbit "sz", &linear_cpu<weight_nbit, true>); \
38-
m.impl( \
39-
"_linear_a8sz_w" #weight_nbit "s.out", \
40-
&linear_out_cpu<weight_nbit, false>); \
41-
m.impl( \
42-
"_linear_a8sz_w" #weight_nbit "sz.out", \
29+
#define DEFINE_CPU_IMPL(weight_nbit) \
30+
m.impl( \
31+
"_pack_8bit_act_" #weight_nbit "bit0zp_weight", \
32+
&pack_weights_without_zeros_cpu<weight_nbit>); \
33+
m.impl( \
34+
"_pack_8bit_act_" #weight_nbit "bit_weight", \
35+
&pack_weights_with_zeros_cpu<weight_nbit>); \
36+
m.impl( \
37+
"_linear_8bit_act_" #weight_nbit "bit0zp_weight", \
38+
&linear_cpu<weight_nbit, false>); \
39+
m.impl( \
40+
"_linear_8bit_act_" #weight_nbit "bit_weight", \
41+
&linear_cpu<weight_nbit, true>); \
42+
m.impl( \
43+
"_linear_8bit_act_" #weight_nbit "bit0zp_weight.out", \
44+
&linear_out_cpu<weight_nbit, false>); \
45+
m.impl( \
46+
"_linear_8bit_act_" #weight_nbit "bit_weight.out", \
4347
&linear_out_cpu<weight_nbit, true>)
4448

45-
#define DEFINE_META_IMPL(weight_nbit) \
46-
m.impl( \
47-
"_pack_weights_a8sz_w" #weight_nbit "s", \
48-
&pack_weights_without_zeros_meta<weight_nbit>); \
49-
m.impl( \
50-
"_pack_weights_a8sz_w" #weight_nbit "sz", \
51-
&pack_weights_with_zeros_meta<weight_nbit>); \
52-
m.impl("_linear_a8sz_w" #weight_nbit "s", &linear_meta<weight_nbit, false>); \
53-
m.impl("_linear_a8sz_w" #weight_nbit "sz", &linear_meta<weight_nbit, true>);
49+
#define DEFINE_META_IMPL(weight_nbit) \
50+
m.impl( \
51+
"_pack_8bit_act_" #weight_nbit "bit0zp_weight", \
52+
&pack_weights_without_zeros_meta<weight_nbit>); \
53+
m.impl( \
54+
"_pack_8bit_act_" #weight_nbit "bit_weight", \
55+
&pack_weights_with_zeros_meta<weight_nbit>); \
56+
m.impl( \
57+
"_linear_8bit_act_" #weight_nbit "bit0zp_weight", \
58+
&linear_meta<weight_nbit, false>); \
59+
m.impl( \
60+
"_linear_8bit_act_" #weight_nbit "bit_weight", \
61+
&linear_meta<weight_nbit, true>);
5462

5563
TORCH_LIBRARY(torchao, m) {
64+
DEFINE_OP(1);
5665
DEFINE_OP(2);
5766
DEFINE_OP(3);
5867
DEFINE_OP(4);
5968
DEFINE_OP(5);
6069
}
6170

6271
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
72+
DEFINE_CPU_IMPL(1);
6373
DEFINE_CPU_IMPL(2);
6474
DEFINE_CPU_IMPL(3);
6575
DEFINE_CPU_IMPL(4);
6676
DEFINE_CPU_IMPL(5);
6777
}
6878

6979
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
80+
DEFINE_META_IMPL(1);
7081
DEFINE_META_IMPL(2);
7182
DEFINE_META_IMPL(3);
7283
DEFINE_META_IMPL(4);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
// Unlike ATen, ExecuTorch op registration appears to only allow one
8+
// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
9+
// file is needed for each variant
10+
11+
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
12+
13+
namespace {
14+
Tensor _op_out(
15+
RuntimeContext& ctx,
16+
const Tensor& activations,
17+
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
19+
const Tensor& n_tensor,
20+
const Tensor& k_tensor,
21+
Tensor& out) {
22+
(void)ctx;
23+
linear_out_cpu</*weight_nbit*/ 1, /*has_weight_zeros*/ false>(
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
25+
return out;
26+
}
27+
} // namespace
28+
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_1bit0zp_weight.out", _op_out);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
// Unlike ATen, ExecuTorch op registration appears to only allow on
8+
// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
9+
// file is needed for each variant
10+
11+
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
12+
13+
namespace {
14+
Tensor _op_out(
15+
RuntimeContext& ctx,
16+
const Tensor& activations,
17+
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
19+
const Tensor& n_tensor,
20+
const Tensor& k_tensor,
21+
Tensor& out) {
22+
(void)ctx;
23+
linear_out_cpu</*weight_nbit*/ 1, /*has_weight_zeros*/ true>(
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
25+
return out;
26+
}
27+
} // namespace
28+
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_1bit_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 2, /*has_weight_zeros*/ false>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w2s.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_2bit0zp_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 2, /*has_weight_zeros*/ true>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w2sz.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_2bit_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3s.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 3, /*has_weight_zeros*/ false>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w3s.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_3bit0zp_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 3, /*has_weight_zeros*/ true>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w3sz.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_3bit_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 4, /*has_weight_zeros*/ false>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w4s.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_4bit0zp_weight.out", _op_out);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
namespace {
1414
Tensor _op_out(
1515
RuntimeContext& ctx,
16+
const Tensor& activations,
1617
const Tensor& packed_weights,
18+
const Tensor& group_size_tensor,
1719
const Tensor& n_tensor,
1820
const Tensor& k_tensor,
19-
const Tensor& group_size_tensor,
20-
const Tensor& activations,
2121
Tensor& out) {
2222
(void)ctx;
2323
linear_out_cpu</*weight_nbit*/ 4, /*has_weight_zeros*/ true>(
24-
packed_weights, n_tensor, k_tensor, group_size_tensor, activations, out);
24+
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
2525
return out;
2626
}
2727
} // namespace
2828

29-
EXECUTORCH_LIBRARY(torchao, "_linear_a8sz_w4sz.out", _op_out);
29+
EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_4bit_weight.out", _op_out);

0 commit comments

Comments
 (0)