diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 450c612878..2fd2f5391c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -36,7 +37,7 @@ TORCHAO_ALWAYS_INLINE void pack_uint_odd_bit_values( int variant) { constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x16_t unpacked0; uint8x16_t unpacked1; @@ -103,7 +104,7 @@ TORCHAO_ALWAYS_INLINE void unpack_uint_odd_bit_values( int variant) { constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x16_t unpacked0; uint8x16_t unpacked1; @@ -222,7 +223,7 @@ void pack_uint_values<2>( constexpr int nbit = 2; constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x8_t unpacked0_8x8; uint8x8_t unpacked1_8x8; @@ -287,7 +288,7 @@ void unpack_uint_values<2>( constexpr int nbit = 2; constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x8_t unpacked0_8x8; uint8x8_t unpacked1_8x8; @@ -394,7 +395,7 @@ void pack_uint_values<4>( constexpr int nbit = 4; constexpr int bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x16_t unpacked0; uint8x16_t unpacked1; @@ -435,7 +436,7 @@ void unpack_uint_values<4>( constexpr int nbit = 4; constexpr int 
bitsPerByte = 8; assert(unpacked_size * nbit / bitsPerByte == packed_size); - assert(packed_size % variant == 0); + assert(unpacked_size % variant == 0); uint8x16_t unpacked0; uint8x16_t unpacked1; @@ -507,6 +508,98 @@ void unpack_uint_values<5>( variant); } +// Benchmark utility to compare variants of uint6 packing +template <> +void pack_uint_values<6>( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 6; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(unpacked_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + switch (variant) { + case 4: + for (int i = 0; i < unpacked_size; i += 4) { + torchao::bitpacking::internal::pack_4_uint6_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + unpacked0 = vld1q_u8(unpacked + i); + unpacked1 = vld1q_u8(unpacked + 16 + i); + torchao::bitpacking::internal::vec_pack_32_uint6_values( + packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + unpacked0 = vld1q_u8(unpacked + i); + unpacked1 = vld1q_u8(unpacked + 16 + i); + unpacked2 = vld1q_u8(unpacked + 32 + i); + unpacked3 = vld1q_u8(unpacked + 48 + i); + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1, unpacked2, unpacked3); + } + break; + } +} + +// Benchmark utility to compare variants of uint6 unpacking +template <> +void unpack_uint_values<6>( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 6; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(unpacked_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + 
uint8x16_t unpacked2; + uint8x16_t unpacked3; + + switch (variant) { + case 4: + for (int i = 0; i < unpacked_size; i += 4) { + torchao::bitpacking::internal::unpack_4_uint6_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + torchao::bitpacking::internal::vec_unpack_32_uint6_values( + unpacked0, unpacked1, packed + ((i * nbit) / bitsPerByte)); + vst1q_u8(unpacked + i, unpacked0); + vst1q_u8(unpacked + 16 + i, unpacked1); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed + ((i * nbit) / bitsPerByte)); + vst1q_u8(unpacked + i, unpacked0); + vst1q_u8(unpacked + 16 + i, unpacked1); + vst1q_u8(unpacked + 32 + i, unpacked2); + vst1q_u8(unpacked + 48 + i, unpacked3); + } + break; + } +} + } // namespace template @@ -557,6 +650,8 @@ BENCHMARK(benchmark_pack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}}); BENCHMARK(benchmark_unpack_uint_values<4>)->ArgsProduct({{128}, {2, 16, 32}}); BENCHMARK(benchmark_pack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_unpack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_pack_uint_values<6>)->ArgsProduct({{128}, {4, 32, 64}}); +BENCHMARK(benchmark_unpack_uint_values<6>)->ArgsProduct({{128}, {4, 32, 64}}); // Run the benchmark BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp index 02a8d7ac98..dad6d67c25 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -238,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT 4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 6); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( @@ -248,6 +250,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 6); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( @@ -258,6 +262,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 6); // Run the benchmark BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 7029b7e493..b5276b94ce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -15,6 +15,7 @@ #include #include #include +#include #include namespace torchao { @@ -80,7 +81,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( // Currently supported values static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); @@ -138,6 +139,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( torchao::bitpacking::internal::pack_8_uint5_values( packed + 15, buffer5 + 24); break; + case 6: + 
torchao::bitpacking::internal::vec_pack_32_uint6_values( + packed, shifted0, shifted1); + break; default: assert(false); } @@ -153,7 +158,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( // Currently supported values static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); uint8x16_t shifted0; uint8x16_t shifted1; @@ -208,6 +213,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( shifted0 = vld1q_u8(buffer5); shifted1 = vld1q_u8(buffer5 + 16); break; + case 6: + torchao::bitpacking::internal::vec_unpack_32_uint6_values( + shifted0, shifted1, packed); + break; default: assert(false); } @@ -230,7 +239,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( // Currently supported values static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); @@ -262,6 +271,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( torchao::bitpacking::internal::vec_pack_64_uint5_values( packed, shifted0, shifted1, shifted2, shifted3); break; + case 6: + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed, shifted0, shifted1, shifted2, shifted3); + break; default: assert(false); } @@ -279,7 +292,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( // Currently supported values static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); uint8x16_t shifted0; uint8x16_t shifted1; @@ -309,6 +322,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( torchao::bitpacking::internal::vec_unpack_64_uint5_values( shifted0, shifted1, shifted2, shifted3, packed); break; + case 6: + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; default: assert(false); } @@ -337,7 +354,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( // Currently supported values 
static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); @@ -403,6 +420,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( shifted6, shifted7); break; + case 6: + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed, shifted0, shifted1, shifted2, shifted3); + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed + 48, shifted4, shifted5, shifted6, shifted7); + break; default: assert(false); } @@ -424,7 +447,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( // Currently supported values static_assert(nbit >= 1); - static_assert(nbit <= 5); + static_assert(nbit <= 6); uint8x16_t shifted0; uint8x16_t shifted1; @@ -488,6 +511,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( shifted7, packed); break; + case 6: + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + shifted0, shifted1, shifted2, shifted3, packed); + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + shifted4, shifted5, shifted6, shifted7, packed + 48); + break; default: assert(false); } diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h new file mode 100644 index 0000000000..fd7535a022 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h @@ -0,0 +1,241 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +// This file contains bitpacking and unpacking methods for uint6. +// These are not intended to be used outside of bitpacking directory. +// See bitpack.h for the interface.
+ +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Given 4 unpacked uint6 values: 01abcd, 23efgh, 45ijkl, 67mnop + // this function packs them as: + // b54: 67|45|23|01 (to hold upper 2 bits on all values) + // b3210_0: efgh|abcd (lower 4 bits for first 2 values) + // b3210_1: mnop|ijkl (lower 4 bits for last 2 values) + + // These are stored in packed as: b54, b3210_0, b3210_1 + // + // Input is 4 bytes + // Output is 6 * 4 bits/8 = 3 bytes + + // b54 + packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) | + ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2); + + // b3210_0 + packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4); + + // b3210_1 + packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4); +} + +TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks data packed by pack_4_uint6_values + // + // Input is 24 bits = 3 bytes + // Output is 4 bytes + + uint8_t b54 = packed[0]; + uint8_t b3210_0 = packed[1]; + uint8_t b3210_1 = packed[2]; + + unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15); + unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4); + + unpacked[2] = (b54 & 48) | (b3210_1 & 15); + unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1) { + // This function is a vectorized version of pack_4_uint6_values + // To understand it, please see pack_4_uint6_values first.
+ // Before each code section, there is a comment indicating the + // code in pack_4_uint6_values that is being vectorized + // + // Input is 32 bytes + // Output is 6*32= 192 bits = 24 bytes + + uint8x8_t b54; + uint8x8_t mask; + + // // b54 + // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) | + // ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2); + mask = vdup_n_u8(48); + b54 = vshr_n_u8(vand_u8(vget_low_u8(unpacked0), mask), 4); + b54 = vorr_u8(b54, vshr_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 2)); + + b54 = vorr_u8(b54, vand_u8(vget_low_u8(unpacked1), mask)); + b54 = vorr_u8(b54, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 2)); + + vst1_u8(packed, b54); + + mask = vdup_n_u8(15); + uint8x8_t b3210; + + // b3210_0 + // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4); + b3210 = vand_u8(vget_low_u8(unpacked0), mask); + b3210 = vorr_u8(b3210, vshl_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 4)); + vst1_u8(packed + 8, b3210); + + // b3210_1 + // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4); + b3210 = vand_u8(vget_low_u8(unpacked1), mask); + b3210 = vorr_u8(b3210, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 4)); + vst1_u8(packed + 16, b3210); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_32_uint6_values + // + // This function vectorizes unpack_4_uint6_values + // To understand it, please see unpack_4_uint6_values first.
+ // Before each code section, there is a comment indicating the + // code in unpack_4_uint6_values that is being vectorized + + // Input is 24 bytes + // Output is 32 bytes + + uint8x8_t b54 = vld1_u8(packed); + uint8x8_t b3210; + uint8x8_t unpacked_tmp0; + uint8x8_t unpacked_tmp1; + + // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15); + // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4); + b3210 = vld1_u8(packed + 8); + + unpacked_tmp0 = vshl_n_u8(vand_u8(b54, vdup_n_u8(3)), 4); + unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b3210, vdup_n_u8(15))); + + unpacked_tmp1 = vshl_n_u8(vand_u8(b54, vdup_n_u8(12)), 2); + unpacked_tmp1 = vorr_u8(unpacked_tmp1, vshr_n_u8(b3210, 4)); + + unpacked0 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); + + // unpacked[2] = (b54 & 48) | (b3210_1 & 15); + // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4); + b3210 = vld1_u8(packed + 16); + + unpacked_tmp0 = vand_u8(b54, vdup_n_u8(48)); + unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b3210, vdup_n_u8(15))); + + unpacked_tmp1 = vshr_n_u8(vand_u8(b54, vdup_n_u8(192)), 2); + unpacked_tmp1 = vorr_u8(unpacked_tmp1, vshr_n_u8(b3210, 4)); + + unpacked1 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + // This function is a vectorized version of pack_4_uint6_values + // To understand it, please see pack_4_uint6_values first.
+ // Before each code section, there is a comment indicating the + // code in pack_4_uint6_values that is being vectorized + // + // Input is 64 bytes + // Output is 6*64= 384 bits = 48 bytes + + uint8x16_t b54; + uint8x16_t mask; + + // b54 + // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) | + // ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2); + mask = vdupq_n_u8(48); + b54 = vshrq_n_u8(vandq_u8(unpacked0, mask), 4); + b54 = vorrq_u8(b54, vshrq_n_u8(vandq_u8(unpacked1, mask), 2)); + b54 = vorrq_u8(b54, vandq_u8(unpacked2, mask)); + b54 = vorrq_u8(b54, vshlq_n_u8(vandq_u8(unpacked3, mask), 2)); + + vst1q_u8(packed, b54); + + mask = vdupq_n_u8(15); + uint8x16_t b3210; + + // b3210_0 + // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4); + b3210 = vandq_u8(unpacked0, mask); + b3210 = vorrq_u8(b3210, vshlq_n_u8(vandq_u8(unpacked1, mask), 4)); + vst1q_u8(packed + 16, b3210); + + // b3210_1 + // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4); + b3210 = vandq_u8(unpacked2, mask); + b3210 = vorrq_u8(b3210, vshlq_n_u8(vandq_u8(unpacked3, mask), 4)); + vst1q_u8(packed + 32, b3210); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_64_uint6_values + // + // This function vectorizes unpack_4_uint6_values + // To understand it, please see unpack_4_uint6_values first.
+ // Before each code section, there is a comment indicating the + // code in unpack_4_uint6_values that is being vectorized + + // Input is 48 bytes + // Output is 64 bytes + + uint8x16_t b54 = vld1q_u8(packed); + uint8x16_t b3210; + + // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15); + // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4); + b3210 = vld1q_u8(packed + 16); + + unpacked0 = vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(3)), 4); + unpacked0 = vorrq_u8(unpacked0, vandq_u8(b3210, vdupq_n_u8(15))); + + unpacked1 = vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(12)), 2); + unpacked1 = vorrq_u8(unpacked1, vshrq_n_u8(b3210, 4)); + + // unpacked[2] = (b54 & 48) | (b3210_1 & 15); + // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4); + b3210 = vld1q_u8(packed + 32); + + unpacked2 = vandq_u8(b54, vdupq_n_u8(48)); + unpacked2 = vorrq_u8(unpacked2, vandq_u8(b3210, vdupq_n_u8(15))); + + unpacked3 = vshrq_n_u8(vandq_u8(b54, vdupq_n_u8(192)), 2); + unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 92dceb16ed..ae9c5c5344 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -487,6 +488,86 @@ TEST(test_bitpacking_128_uint5_values, PackUnpackAreSame) { } } +TEST(test_bitpacking_4_uint6_values, PackUnpackAreSame) { + int unpacked_bytes = 4; + int packed_bytes = 3; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_4_uint6_values( + packed.data(), input.data()); +
torchao::bitpacking::internal::unpack_4_uint6_values( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) { + int unpacked_bytes = 32; + int packed_bytes = 24; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + input0 = vld1q_u8(input.data()); + input1 = vld1q_u8(input.data() + 16); + torchao::bitpacking::internal::vec_pack_32_uint6_values( + packed.data(), input0, input1); + torchao::bitpacking::internal::vec_unpack_32_uint6_values( + unpacked0, unpacked1, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + } +} + +TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 48; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed.data(), + input0, + input1, + input2, + input3); + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + // Universal bitpacking tests template void test_bitpacking_32_lowbit_values() { @@ -652,17 +733,20 @@ TEST_BITPACKING_32_LOWBIT_VALUES(2); 
TEST_BITPACKING_32_LOWBIT_VALUES(3); TEST_BITPACKING_32_LOWBIT_VALUES(4); TEST_BITPACKING_32_LOWBIT_VALUES(5); +TEST_BITPACKING_32_LOWBIT_VALUES(6); TEST_BITPACKING_64_LOWBIT_VALUES(1); TEST_BITPACKING_64_LOWBIT_VALUES(2); TEST_BITPACKING_64_LOWBIT_VALUES(3); TEST_BITPACKING_64_LOWBIT_VALUES(4); TEST_BITPACKING_64_LOWBIT_VALUES(5); +TEST_BITPACKING_64_LOWBIT_VALUES(6); TEST_BITPACKING_128_LOWBIT_VALUES(1); TEST_BITPACKING_128_LOWBIT_VALUES(2); TEST_BITPACKING_128_LOWBIT_VALUES(3); TEST_BITPACKING_128_LOWBIT_VALUES(4); TEST_BITPACKING_128_LOWBIT_VALUES(5); +TEST_BITPACKING_128_LOWBIT_VALUES(6); #endif // defined(__aarch64__) || defined(__ARM_NEON)