
Commit f9a68f9

[experimental][kleidi] rebase fixes with int to size_t
1 parent a014164 commit f9a68f9

4 files changed, +12 -10 lines changed

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ const Ukernel get_ukernel() {
       kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
 }
 
-int activation_data_size(int m, int k, int group_size) {
+size_t activation_data_size(int m, int k, int group_size) {
   (void)group_size; // unused
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
       get_ukernel(), m, k);
@@ -57,7 +57,7 @@ void prepare_activation_data(
       get_ukernel(), activation_data, m, k, activations);
 }
 
-int weight_data_size(int n, int k, int group_size) {
+size_t weight_data_size(int n, int k, int group_size) {
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
       get_ukernel(), n, k, group_size);
 }
@@ -115,7 +115,7 @@ void kernel(
       clamp_max);
 }
 
-size_t get_alignement() {
+size_t get_preferred_alignement() {
   return 16;
 }
 } // namespace neon_dotprod_1x4x32
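
Both size helpers in this variant now report their result as size_t instead of int, and get_alignement() is renamed to get_preferred_alignement(). A minimal caller sketch, not part of this commit (only the helper names come from the diff above; everything else is assumed), of why the size_t return type matters when the value feeds an allocation:

#include <cstddef>
#include <vector>

// Assumes the neon_dotprod_1x4x32 helpers from the header above are in scope.
std::vector<char> make_activation_buffer(int m, int k, int group_size) {
  std::size_t bytes = activation_data_size(m, k, group_size); // size_t end to end, no narrowing
  return std::vector<char>(bytes); // later filled by prepare_activation_data(...)
}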

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h

Lines changed: 3 additions & 3 deletions
@@ -39,7 +39,7 @@ const Ukernel get_ukernel() {
       kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
 }
 
-int activation_data_size(int m, int k, int group_size) {
+size_t activation_data_size(int m, int k, int group_size) {
   (void) group_size; // unused
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
 }
@@ -59,7 +59,7 @@ void prepare_activation_data(
       activations);
 }
 
-int weight_data_size(int n, int k, int group_size) {
+size_t weight_data_size(int n, int k, int group_size) {
   return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
 }
 
@@ -116,7 +116,7 @@ void kernel(
       clamp_max);
 }
 
-size_t get_alignement() {
+size_t get_preferred_alignement() {
   return 16;
 }
 } // namespace neon_dotprod_1x4x32

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-int activation_data_size(const Ukernel ukernel, int m, int k) {
+size_t activation_data_size(const Ukernel ukernel, int m, int k) {
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(
       m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
@@ -69,7 +69,7 @@ void prepare_activation_data(
       activation_data);
 }
 
-int weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
+size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
  auto rhs_pack = get_rhs_packing();
  return rhs_pack.get_rhs_packed_size(
      n,
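
This common header is where the sizes originate: they come straight from the packing helpers' get_lhs_packed_size / get_rhs_packed_size, which (as is usual for KleidiAI size getters) report a size_t, so the size_t return type lets the value flow through without a narrowing conversion. A standalone illustration, not from the patch, of the truncation an int return type risks once a packed buffer exceeds INT_MAX bytes:

#include <climits>
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical packed-buffer size larger than INT_MAX bytes (4 GiB here).
  std::size_t packed_bytes = std::size_t(1) << 32;
  int truncated = static_cast<int>(packed_bytes); // no longer representable as int
  std::printf("size_t: %zu bytes, after narrowing to int: %d (INT_MAX = %d)\n",
              packed_bytes, truncated, INT_MAX);
  return 0;
}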

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 4 additions & 2 deletions
@@ -44,7 +44,8 @@ inline std::vector<uint8_t> get_random_lowbit_vector(int size, int nbit) {
 }
 
 // TODO move these to a common utils
-uint16_t get_bf16_from_float(float f) {
+inline uint16_t
+get_bf16_from_float(float f) {
   uint16_t bf16;
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
   memcpy(&bf16, &f, sizeof(uint16_t));
@@ -56,7 +57,8 @@ uint16_t get_bf16_from_float(float f) {
   return bf16;
 }
 
-float get_float_from_bf16(uint16_t bf16) {
+inline float
+get_float_from_bf16(uint16_t bf16) {
   float f;
   const uint32_t i32 = (bf16 << 16);
   memcpy(&f, &i32, sizeof(uint32_t));
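
Marking these helpers inline keeps the header-defined functions from causing multiple-definition link errors when test_utils.h is included from more than one translation unit. A small usage sketch, not from the patch (it assumes the two helpers above are visible), showing the lossy float -> bf16 -> float round trip:

#include <cstdint>
#include <cstdio>

int main() {
  float f = 3.1415927f;
  uint16_t bf16 = get_bf16_from_float(f); // keeps sign, exponent, and top 7 mantissa bits
  float back = get_float_from_bf16(bf16); // low 16 mantissa bits come back as zero
  std::printf("%.7f -> 0x%04x -> %.7f\n", f, static_cast<unsigned>(bf16), back);
  return 0;
}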

0 commit comments