diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index b602146ecf..0ff0ea6de8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -601,6 +602,49 @@ void unpack_uint_values<6>( } } +// Benchmark utility to compare variants of uint7 packing. +template <> +void pack_uint_values<7>( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 7; + pack_uint_odd_bit_values( + torchao::bitpacking::internal::pack_8_uint7_values, + torchao::bitpacking::internal::vec_pack_64_uint7_values, + torchao::bitpacking::internal::vec_pack_128_uint7_values, + nbit, + packed, + unpacked, + packed_size, + unpacked_size, + variant); +} + +// Benchmark utility to compare variants of uint7 unpacking. +template <> +void unpack_uint_values<7>( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 7; + unpack_uint_odd_bit_values( + torchao::bitpacking::internal::unpack_8_uint7_values, + torchao::bitpacking::internal::vec_unpack_64_uint7_values, + torchao::bitpacking::internal::vec_unpack_128_uint7_values, + nbit, + unpacked, + packed, + unpacked_size, + packed_size, + variant); +} + + } // namespace template @@ -653,6 +697,8 @@ BENCHMARK(benchmark_pack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_unpack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_pack_uint_values<6>)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_unpack_uint_values<6>)->ArgsProduct({{128}, {4, 32, 64}}); +BENCHMARK(benchmark_pack_uint_values<7>)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint_values<7>)->ArgsProduct({{128}, {8, 64, 128}}); // Run the benchmark BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp index 688abd717f..1a526ca4db 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -243,6 +243,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT 5); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 6); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 7); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( @@ -255,6 +257,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT 5); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 6); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 7); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( @@ -267,6 +271,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT 5); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( 6); 
+BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 7); // Run the benchmark BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 340f295e50..5442f8319b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -79,10 +80,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); @@ -144,6 +141,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( torchao::bitpacking::internal::vec_pack_32_uint6_values( packed, shifted0, shifted1); break; + case 7: + uint8_t buffer7[32]; + vst1q_u8(buffer7, shifted0); + vst1q_u8(buffer7 + 16, shifted1); + + torchao::bitpacking::internal::pack_8_uint7_values(packed, buffer7); + torchao::bitpacking::internal::pack_8_uint7_values(packed + 7, buffer7 + 8); + torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16); + torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24); + break; default: assert(false); } @@ -157,10 +164,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - uint8x16_t shifted0; uint8x16_t shifted1; @@ -219,6 +222,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( torchao::bitpacking::internal::vec_unpack_32_uint6_values( shifted0, shifted1, packed); break; + case 7: + uint8_t buffer7[32]; + torchao::bitpacking::internal::unpack_8_uint7_values(buffer7, packed); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 8, packed + 7); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 16, packed + 14); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 24, packed + 21); + shifted0 = vld1q_u8(buffer7); + shifted1 = vld1q_u8(buffer7 + 16); + break; default: assert(false); } @@ -239,10 +254,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); @@ -277,6 +288,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( torchao::bitpacking::internal::vec_pack_64_uint6_values( packed, shifted0, shifted1, shifted2, shifted3); break; + case 7: + torchao::bitpacking::internal::vec_pack_64_uint7_values( + packed, shifted0, shifted1, shifted2, shifted3); + break; default: assert(false); } @@ -292,10 +307,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - uint8x16_t shifted0; uint8x16_t shifted1; uint8x16_t shifted2; @@ -328,6 +339,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( 
torchao::bitpacking::internal::vec_unpack_64_uint6_values( shifted0, shifted1, shifted2, shifted3, packed); break; + case 7: + torchao::bitpacking::internal::vec_unpack_64_uint7_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; default: assert(false); } @@ -354,10 +369,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - // Shift unpacked values to nonnegative range int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); @@ -428,6 +439,18 @@ torchao::bitpacking::internal::vec_pack_64_uint6_values( packed + 48, shifted4, shifted5, shifted6, shifted7); break; + case 7: + torchao::bitpacking::internal::vec_pack_128_uint7_values( + packed, + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7); + break; default: assert(false); } @@ -447,10 +470,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( static_assert(nbit < 8); static_assert(nbit >= 1); - // Currently supported values - static_assert(nbit >= 1); - static_assert(nbit <= 6); - uint8x16_t shifted0; uint8x16_t shifted1; uint8x16_t shifted2; @@ -519,6 +538,18 @@ torchao::bitpacking::internal::vec_unpack_64_uint6_values( shifted4, shifted5, shifted6, shifted7, packed + 48); break; + case 7: + torchao::bitpacking::internal::vec_unpack_128_uint7_values( + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7, + packed); + break; default: assert(false); } diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h new file mode 100644 index 0000000000..1fc2a8d5cb --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h @@ -0,0 +1,277 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +// This file contains bitpacking and unpacking methods for uint7. +// These are not intended to be used outside of the bitpacking directory. +// See bitpack.h for the interface. + +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_8_uint7_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Given 8 unpacked uint7 values: + // aaa aaaa, bbb bbbb, ccc cccc, ddd dddd, + // eee eeee, fff ffff, ggg gggg, 123 4567. + // This function produces the following packing: + // packed[0] = 7aaa aaaa + // packed[1] = 6bbb bbbb + // packed[2] = 5ccc cccc + // packed[3] = 4ddd dddd + // packed[4] = 3eee eeee + // packed[5] = 2fff ffff + // packed[6] = 1ggg gggg + // + // Input is 8 bytes + // Output is 7 * 8 bits = 56 bits = 7 bytes + + // Split the bits of unpacked[7]. + uint8_t mask[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + uint8x8_t unpacked7 = vtst_u8(vdup_n_u8(unpacked[7]), vld1_u8(mask)); + // At this point, each byte in unpacked7 is all ones or all zeroes, depending + // on whether the corresponding bit of unpacked[7] was one or zero.
+ // The next statement combines the low 7 bits of each unpacked[i] with bit i + // of unpacked[7] (held in lane i of unpacked7). + uint8x8_t p = vsli_n_u8(vld1_u8(unpacked), unpacked7, 7); + + packed[0] = vget_lane_u8(p, 0); + packed[1] = vget_lane_u8(p, 1); + packed[2] = vget_lane_u8(p, 2); + packed[3] = vget_lane_u8(p, 3); + packed[4] = vget_lane_u8(p, 4); + packed[5] = vget_lane_u8(p, 5); + packed[6] = vget_lane_u8(p, 6); +} + +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint7_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks 8 uint7 values packed by pack_8_uint7_values. + // Load the 7 packed bytes into a vector. + uint8_t temp[8] = { + packed[0], + packed[1], + packed[2], + packed[3], + packed[4], + packed[5], + packed[6], + /*ignored*/ 0}; + uint8x8_t v = vld1_u8(temp); + int8_t shift[8] = {-7, -6, -5, -4, -3, -2, -1, 0}; + // The following AND and shift operations produce a vector with the + // 0123 4567 bits of the last uint7 value in the following positions: + // 0000 0007 + // 0000 0060 + // 0000 0500 + // 0000 4000 + // 0003 0000 + // 0020 0000 + // 0100 0000 + // These lanes are then summed to obtain unpacked[7]. + unpacked[7] = + vaddv_u8(vshl_u8(vand_u8(v, vdup_n_u8(0b1000'0000u)), vld1_s8(shift))); + // All other unpacked values are just the corresponding packed value with the + // top bit cleared. + v = vand_u8(v, vdup_n_u8(0b0111'1111u)); + unpacked[0] = vget_lane_u8(v, 0); + unpacked[1] = vget_lane_u8(v, 1); + unpacked[2] = vget_lane_u8(v, 2); + unpacked[3] = vget_lane_u8(v, 3); + unpacked[4] = vget_lane_u8(v, 4); + unpacked[5] = vget_lane_u8(v, 5); + unpacked[6] = vget_lane_u8(v, 6); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint7_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + // This function is a vectorized version of pack_8_uint7_values. + // To understand it, please see pack_8_uint7_values first. The + // main idea is to use the top bit of each packed uint8_t to + // store the bits of the last uint7 value. + // + // Input is 64 bytes + // Output is 7 * 64 = 448 bits = 56 bytes + // + // Insert one bit from each element of the last 8-value vector into the + // first 7 8-value vectors. If bits of `last` are labeled 0123 4567 then + // bit 7 is inserted into the first packed uint8_t (in a vectorized manner), + // then bit 6, and so on. + uint8x8_t last = vget_high_u8(unpacked3); + // Insert bit 7 of last into the first packed uint8_t. + vst1_u8(packed, vsli_n_u8(vget_low_u8(unpacked0), last, 7)); + + // Repeat for the i-th bit of `last` and the remaining 8-value vectors. + // Pack bit 6 from 0123 4567. + vst1_u8( + packed + 8, vsli_n_u8(vget_high_u8(unpacked0), vshr_n_u8(last, 1), 7)); + // Pack bit 5 from 0123 4567, etc. + vst1_u8( + packed + 16, vsli_n_u8(vget_low_u8(unpacked1), vshr_n_u8(last, 2), 7)); + vst1_u8( + packed + 24, vsli_n_u8(vget_high_u8(unpacked1), vshr_n_u8(last, 3), 7)); + vst1_u8( + packed + 32, vsli_n_u8(vget_low_u8(unpacked2), vshr_n_u8(last, 4), 7)); + vst1_u8( + packed + 40, vsli_n_u8(vget_high_u8(unpacked2), vshr_n_u8(last, 5), 7)); + vst1_u8( + packed + 48, vsli_n_u8(vget_low_u8(unpacked3), vshr_n_u8(last, 6), 7)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint7_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_64_uint7_values. + // Please see vec_pack_64_uint7_values first. + // + // Input is 7 * 64 = 448 bits = 56 bytes + // Output is 64 bytes.
+ const uint8x8_t mask = vdup_n_u8(0b0111'1111u); + // Starting from the last packed bytes, extract the most significant bit + // to reconstruct the last 8-value vector. If the last uint7 value + // is labeled as 0123 4567 and X are bits we don't care about, we start + // with last_high = 01XX XXXX + uint8x8_t last_low = vld1_u8(packed + 48); + uint8x8_t last_high = vshr_n_u8(last_low, 1); + + uint8x8_t high = vld1_u8(packed + 40); + uint8x8_t low = vld1_u8(packed + 32); + // last_high = 012X XXXX + last_high = vsri_n_u8(last_high, high, 2); + // last_high = 0123 XXXX + last_high = vsri_n_u8(last_high, low, 3); + unpacked2 = vcombine_u8(vand_u8(low, mask), vand_u8(high, mask)); + + high = vld1_u8(packed + 24); + low = vld1_u8(packed + 16); + // last_high = 0123 4XXX + last_high = vsri_n_u8(last_high, high, 4); + // last_high = 0123 45XX + last_high = vsri_n_u8(last_high, low, 5); + unpacked1 = vcombine_u8(vand_u8(low, mask), vand_u8(high, mask)); + + high = vld1_u8(packed + 8); + low = vld1_u8(packed); + // last_high = 0123 456X + last_high = vsri_n_u8(last_high, high, 6); + // last_high = 0123 4567 + last_high = vsri_n_u8(last_high, low, 7); + unpacked0 = vcombine_u8(vand_u8(low, mask), vand_u8(high, mask)); + + unpacked3 = vcombine_u8(vand_u8(last_low, mask), vand_u8(last_high, mask)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint7_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3, + const uint8x16_t& unpacked4, + const uint8x16_t& unpacked5, + const uint8x16_t& unpacked6, + const uint8x16_t& unpacked7) { + // This function is a vectorized version of pack_8_uint7_values. + // To understand it, please see pack_8_uint7_values first. The + // main idea is to use the top bit of each packed uint8_t to + // store the bits of the last uint7 value. + // + // Input is 128 bytes + // Output is 7 * 128 = 896 bits = 112 bytes + + // Pack a 16-byte vector, inserting one bit from each element of unpacked7. + // If those elements are labeled 0123 4567 then the following line does: + // Shift left insert by 7: 7 | low_7_bits(unpacked0) + vst1q_u8(packed, vsliq_n_u8(unpacked0, unpacked7, 7)); + // Shift right by 1: 0123 4567 -> 0012 3456 + // Shift left insert by 7 the above: 6 | low_7_bits(unpacked1) + vst1q_u8(packed + 16, vsliq_n_u8(unpacked1, vshrq_n_u8(unpacked7, 1), 7)); + // Shift right by 2: 0123 4567 -> 0001 2345 + // Shift left insert by 7 the above: 5 | low_7_bits(unpacked2) + vst1q_u8(packed + 16 * 2, vsliq_n_u8(unpacked2, vshrq_n_u8(unpacked7, 2), 7)); + // And so on and so forth... + vst1q_u8(packed + 16 * 3, vsliq_n_u8(unpacked3, vshrq_n_u8(unpacked7, 3), 7)); + vst1q_u8(packed + 16 * 4, vsliq_n_u8(unpacked4, vshrq_n_u8(unpacked7, 4), 7)); + vst1q_u8(packed + 16 * 5, vsliq_n_u8(unpacked5, vshrq_n_u8(unpacked7, 5), 7)); + vst1q_u8(packed + 16 * 6, vsliq_n_u8(unpacked6, vshrq_n_u8(unpacked7, 6), 7)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint7_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + uint8x16_t& unpacked4, + uint8x16_t& unpacked5, + uint8x16_t& unpacked6, + uint8x16_t& unpacked7, + const uint8_t* packed) { + // Unpacks data packed by vec_pack_128_uint7_values. + // Please see vec_pack_128_uint7_values first.
+ // + // Input is 7 * 128 = 896 bits = 112 bytes + // Output is 128 bytes + const uint8x16_t mask = vdupq_n_u8(0b0111'1111u); + // Starting from the last packed bytes, extract the most significant bit + // to reconstruct the last 8-value vector. If the last uint7 value + // is labeled as 0123 4567 and X are bits we don't care about, we start + // with unpacked7 = 01XX XXXX + unpacked6 = vld1q_u8(packed + 16 * 6); + unpacked7 = vshrq_n_u8(unpacked6, 1); + unpacked6 = vandq_u8(unpacked6, mask); + + unpacked5 = vld1q_u8(packed + 16 * 5); + // unpacked7 = 012X XXXX + unpacked7 = vsriq_n_u8(unpacked7, unpacked5, 2); + unpacked5 = vandq_u8(unpacked5, mask); + + unpacked4 = vld1q_u8(packed + 16 * 4); + // unpacked7 = 0123 XXXX + unpacked7 = vsriq_n_u8(unpacked7, unpacked4, 3); + unpacked4 = vandq_u8(unpacked4, mask); + + unpacked3 = vld1q_u8(packed + 16 * 3); + // unpacked7 = 0123 4XXX + unpacked7 = vsriq_n_u8(unpacked7, unpacked3, 4); + unpacked3 = vandq_u8(unpacked3, mask); + + unpacked2 = vld1q_u8(packed + 16 * 2); + // unpacked7 = 0123 45XX + unpacked7 = vsriq_n_u8(unpacked7, unpacked2, 5); + unpacked2 = vandq_u8(unpacked2, mask); + + unpacked1 = vld1q_u8(packed + 16); + // unpacked7 = 0123 456X + unpacked7 = vsriq_n_u8(unpacked7, unpacked1, 6); + unpacked1 = vandq_u8(unpacked1, mask); + + unpacked0 = vld1q_u8(packed); + // unpacked7 = 0123 4567 + unpacked7 = vsriq_n_u8(unpacked7, unpacked0, 7); + unpacked0 = vandq_u8(unpacked0, mask); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index ef51fd7d43..c891bdcef3 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -560,6 +561,114 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) { } } +TEST(test_bitpacking_8_uint7_values, PackUnpackAreSame) { + int unpacked_bytes = 8; + int packed_bytes = 7; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 7); + std::vector<uint8_t> packed(packed_bytes, 0); + std::vector<uint8_t> unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_8_uint7_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_8_uint7_values( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint7_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 56; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 7); + std::vector<uint8_t> packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint7_values( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::internal::vec_unpack_64_uint7_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} +
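// For reference (not part of this patch): a scalar sketch of the same uint7
// layout the NEON routines above implement, to make the bit placement
// explicit. The *_reference names are hypothetical, for illustration only;
// the scheme is that packed[i] keeps the low 7 bits of unpacked[i] and its
// top bit stores bit i of unpacked[7].

#include <cassert>
#include <cstdint>

void pack_8_uint7_reference(uint8_t* packed, const uint8_t* unpacked) {
  for (int i = 0; i < 7; ++i) {
    // Low 7 bits of unpacked[i], plus bit i of unpacked[7] in the MSB.
    packed[i] =
        (uint8_t)((((unpacked[7] >> i) & 1) << 7) | (unpacked[i] & 0x7F));
  }
}

void unpack_8_uint7_reference(uint8_t* unpacked, const uint8_t* packed) {
  unpacked[7] = 0;
  for (int i = 0; i < 7; ++i) {
    unpacked[i] = packed[i] & 0x7F;                   // drop the borrowed bit
    unpacked[7] |= (uint8_t)((packed[i] >> 7) << i);  // rebuild the 8th value
  }
}

int main() {
  uint8_t in[8] = {1, 2, 3, 4, 5, 6, 7, 0x55};
  uint8_t packed[7];
  uint8_t out[8];
  pack_8_uint7_reference(packed, in);
  unpack_8_uint7_reference(out, packed);
  for (int i = 0; i < 8; ++i) {
    assert(in[i] == out[i]);  // round trip is exact for values below 128
  }
  return 0;
}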
+TEST(test_bitpacking_128_uint7_values, PackUnpackAreSame) { + int unpacked_bytes = 128; + int packed_bytes = 112; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 7); + std::vector<uint8_t> packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + uint8x16_t input4; + uint8x16_t input5; + uint8x16_t input6; + uint8x16_t input7; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + input4, input5, input6, input7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint7_values( + packed.data(), + input0, + input1, + input2, + input3, + input4, + input5, + input6, + input7); + torchao::bitpacking::internal::vec_unpack_128_uint7_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + // Universal bitpacking tests template <int nbit> void test_bitpacking_32_lowbit_values() { @@ -726,6 +835,7 @@ TEST_BITPACKING_32_LOWBIT_VALUES(3); TEST_BITPACKING_32_LOWBIT_VALUES(4); TEST_BITPACKING_32_LOWBIT_VALUES(5); TEST_BITPACKING_32_LOWBIT_VALUES(6); +TEST_BITPACKING_32_LOWBIT_VALUES(7); TEST_BITPACKING_64_LOWBIT_VALUES(1); TEST_BITPACKING_64_LOWBIT_VALUES(2); @@ -733,6 +843,7 @@ TEST_BITPACKING_64_LOWBIT_VALUES(3); TEST_BITPACKING_64_LOWBIT_VALUES(4); TEST_BITPACKING_64_LOWBIT_VALUES(5); TEST_BITPACKING_64_LOWBIT_VALUES(6); +TEST_BITPACKING_64_LOWBIT_VALUES(7); TEST_BITPACKING_128_LOWBIT_VALUES(1); TEST_BITPACKING_128_LOWBIT_VALUES(2); @@ -740,5 +851,6 @@ TEST_BITPACKING_128_LOWBIT_VALUES(3); TEST_BITPACKING_128_LOWBIT_VALUES(4); TEST_BITPACKING_128_LOWBIT_VALUES(5); TEST_BITPACKING_128_LOWBIT_VALUES(6); +TEST_BITPACKING_128_LOWBIT_VALUES(7); #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp index 9e5ea0e046..dfb61eb928 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp @@ -35,6 +35,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { DEFINE_OP(4); DEFINE_OP(5); DEFINE_OP(6); + DEFINE_OP(7); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -44,6 +45,7 @@ DEFINE_CPU_IMPL(4); DEFINE_CPU_IMPL(5); DEFINE_CPU_IMPL(6); + DEFINE_CPU_IMPL(7); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { @@ -53,4 +55,5 @@ DEFINE_META_IMPL(4); DEFINE_META_IMPL(5); DEFINE_META_IMPL(6); + DEFINE_META_IMPL(7); } diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp index 486ab3685b..1b79a5e035 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp +++
b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp @@ -36,3 +36,4 @@ DEFINE_OP(3); DEFINE_OP(4); DEFINE_OP(5); DEFINE_OP(6); +DEFINE_OP(7); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp index 9c0832c5b7..f69e51e4c9 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp @@ -67,6 +67,7 @@ TORCH_LIBRARY(torchao, m) { DEFINE_OP(4); DEFINE_OP(5); DEFINE_OP(6); + DEFINE_OP(7); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -76,6 +77,7 @@ DEFINE_CPU_IMPL(4); DEFINE_CPU_IMPL(5); DEFINE_CPU_IMPL(6); + DEFINE_CPU_IMPL(7); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { @@ -85,4 +87,5 @@ DEFINE_META_IMPL(4); DEFINE_META_IMPL(5); DEFINE_META_IMPL(6); + DEFINE_META_IMPL(7); } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7s.cpp new file mode 100644 index 0000000000..27bbcf3e38 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7s.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant. + +#include + +namespace { +Tensor _op_out( + RuntimeContext& ctx, + const Tensor& activations, + const Tensor& packed_weights, + const Tensor& group_size_tensor, + const Tensor& n_tensor, + const Tensor& k_tensor, + Tensor& out) { + (void)ctx; + linear_out_cpu( + activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out); + return out; +} +} // namespace + +EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_7bit0zp_weight.out", _op_out); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7sz.cpp new file mode 100644 index 0000000000..89bb234119 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w7sz.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree.
+ +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant. + +#include + +namespace { +Tensor _op_out( + RuntimeContext& ctx, + const Tensor& activations, + const Tensor& packed_weights, + const Tensor& group_size_tensor, + const Tensor& n_tensor, + const Tensor& k_tensor, + Tensor& out) { + (void)ctx; + linear_out_cpu( + activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out); + return out; +} +} // namespace + +EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_7bit_weight.out", _op_out); diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 3cd4c84ba7..18b18357bc 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -198,7 +198,7 @@ def forward(self, x): def _maybe_get_quantized_linear_native(nbit, has_weight_zeros): try: - if nbit in [1, 2, 3, 4, 5, 6]: + if nbit in [1, 2, 3, 4, 5, 6, 7]: wzp_suffix = "" if has_weight_zeros else "0zp" return _Int8DynActIntxWeightQuantizedLinearNative( pack_weight_op=getattr(
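// Usage sketch (not part of this patch): exercising the new nbit=7 case
// through the generic lowbit API in bitpack.h above. This assumes the
// signatures torchao::bitpacking::vec_pack_32_lowbit_values<nbit>(uint8_t*,
// const int8x16_t&, const int8x16_t&) and the matching unpack, as suggested
// by the hunks above, plus an aarch64 build with the torchao include root
// on the include path.

#include <arm_neon.h>
#include <cassert>
#include <cstdint>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>

int main() {
  // 32 signed values in [-64, 63], the representable range at nbit = 7
  // (the kernels shift by 1 << (nbit - 1) = 64 into unsigned range).
  int8_t values[32];
  for (int i = 0; i < 32; ++i) {
    values[i] = (int8_t)(i - 16);
  }
  int8x16_t v0 = vld1q_s8(values);
  int8x16_t v1 = vld1q_s8(values + 16);

  // 32 values * 7 bits = 224 bits = 28 packed bytes.
  uint8_t packed[28] = {0};
  torchao::bitpacking::vec_pack_32_lowbit_values<7>(packed, v0, v1);

  int8x16_t u0;
  int8x16_t u1;
  torchao::bitpacking::vec_unpack_32_lowbit_values<7>(u0, u1, packed);

  int8_t out[32];
  vst1q_s8(out, u0);
  vst1q_s8(out + 16, u1);
  for (int i = 0; i < 32; ++i) {
    assert(out[i] == values[i]);  // pack/unpack must round-trip exactly
  }
  return 0;
}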