Add support for KleidiAI int4 kernels on aarch64 Linux #2169

Merged: 20 commits, May 14, 2025
35 changes: 32 additions & 3 deletions setup.py
@@ -49,7 +49,7 @@ def read_version(file_path="version.txt"):

import platform

-build_torchao_experimental = (
+build_macos_arm_auto = (
use_cpp == "1"
and platform.machine().startswith("arm64")
and platform.system() == "Darwin"
@@ -121,8 +121,33 @@ def __init__(self):
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
)

# TORCHAO_PARALLEL_BACKEND specifies which parallel backend to use
# Possible values: aten_openmp, executorch, openmp, pthreadpool, single_threaded
self.parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")

# TORCHAO_ENABLE_ARM_NEON_DOT enables the ARM NEON Dot Product extension.
# Enabled by default on Apple silicon macOS.
self.enable_arm_neon_dot = self._os_bool_var(
Contributor (review comment): We should enable TORCHAO_ENABLE_ARM_NEON_DOT on Arm Mac.
"TORCHAO_ENABLE_ARM_NEON_DOT",
default=(self._is_arm64() and self._is_macos()),
)
if self.enable_arm_neon_dot:
assert self.build_cpu_aarch64, (
"TORCHAO_ENABLE_ARM_NEON_DOT requires TORCHAO_BUILD_CPU_AARCH64 be set"
)

# TORCHAO_ENABLE_ARM_I8MM enables ARM 8-bit Integer Matrix Multiply (i8mm) instructions.
# Not enabled by default on macOS, since not all Apple silicon supports it.
self.enable_arm_i8mm = self._os_bool_var(
"TORCHAO_ENABLE_ARM_I8MM", default=False
)
if self.enable_arm_i8mm:
assert self.build_cpu_aarch64, (
"TORCHAO_ENABLE_ARM_I8MM requires TORCHAO_BUILD_CPU_AARCH64 be set"
)

def _is_arm64(self) -> bool:
-return platform.machine().startswith("arm64")
+return platform.machine().startswith("arm64") or platform.machine() == "aarch64"

def _is_macos(self) -> bool:
return platform.system() == "Darwin"
@@ -498,7 +523,8 @@ def get_extensions():
)
)

-if build_torchao_experimental:
+# Build CMakeLists from /torchao/experimental - additional options become available: TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
+if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
build_options = BuildOptions()

def bool_to_on_off(value):
@@ -518,6 +544,9 @@ def bool_to_on_off(value):
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(build_options.enable_arm_neon_dot)}",
f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(build_options.enable_arm_i8mm)}",
f"-DTORCHAO_PARALLEL_BACKEND={build_options.parallel_backend}",
"-DTorch_DIR=" + torch_dir,
]
+ (
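For reference, a minimal sketch of driving this experimental build from the environment. BUILD_TORCHAO_EXPERIMENTAL and TORCHAO_PARALLEL_BACKEND come from the diff above; USE_CPP and TORCHAO_BUILD_CPU_AARCH64 as environment-variable names are assumptions, and the whole invocation is illustrative rather than part of this PR:

    import os
    import subprocess

    env = dict(os.environ)
    env.update({
        "USE_CPP": "1",                     # assumed name behind `use_cpp` in setup.py
        "BUILD_TORCHAO_EXPERIMENTAL": "1",  # opt-in outside macOS arm64 (from the diff)
        "TORCHAO_BUILD_CPU_AARCH64": "1",   # assumed env toggle for the CMake option
        "TORCHAO_ENABLE_ARM_NEON_DOT": "1",
        "TORCHAO_ENABLE_ARM_I8MM": "0",     # enable only if the target CPU has i8mm
        "TORCHAO_PARALLEL_BACKEND": "aten_openmp",
    })
    # Build torchao in editable mode with the experimental aarch64 kernels.
    subprocess.run(["python", "setup.py", "develop"], env=env, check=True)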
48 changes: 44 additions & 4 deletions torchao/experimental/CMakeLists.txt
@@ -15,10 +15,13 @@ if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()

# Platform options
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)
option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF)
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF)
option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF)

if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -28,19 +31,49 @@ if(NOT DEFINED TORCHAO_PARALLEL_BACKEND)
set(TORCHAO_PARALLEL_BACKEND aten_openmp)
endif()

include(CMakePrintHelpers)

# Set default compiler options
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")

include(CMakePrintHelpers)
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})


if(TORCHAO_BUILD_CPU_AARCH64)
message(STATUS "Building with cpu/aarch64")
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
-add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)

# Set aarch64 compiler options
if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
message(STATUS "Add aarch64 linux compiler options")
add_compile_options(
"-fPIC"
"-Wno-error=unknown-pragmas"
"-Wno-array-parameter"
"-Wno-maybe-uninitialized"
"-Wno-sign-compare"
)

# Since versions are hierarchical (each includes features from prior versions):
# - dotprod is included by default in armv8.4-a and later
# - i8mm is included by default in armv8.6-a and later
if(TORCHAO_ENABLE_ARM_I8MM)
message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
add_compile_options("-march=armv8.6-a")
Contributor (review comment): I would be more conservative and drop the minor version to the first version that allows i8mm, i.e. -march=armv8.2-a+i8mm here; that way, if users have an older machine with i8mm, we don't accidentally use something the user may not have. Same for dotprod. (A sketch of this alternative follows the endif() below.)
elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
message(STATUS "Using armv8.4-a (includes '+dotprod' flag)")
add_compile_options("-march=armv8.4-a")
endif()
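# A more conservative alternative (per the review comment above; an
# illustrative sketch, not what this diff does): target the first
# architecture revision that allows each feature and request the
# feature modifier explicitly, e.g.:
#   add_compile_options("-march=armv8.2-a+dotprod")
#   add_compile_options("-march=armv8.2-a+i8mm")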
endif()

if(TORCHAO_ENABLE_ARM_NEON_DOT)
message(STATUS "Building with ARM NEON dot product support")
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)
endif()

if(TORCHAO_ENABLE_ARM_I8MM)
message(STATUS "Building with ARM I8MM support")
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

# Defines torchao_kernels_aarch64
add_subdirectory(kernels/cpu/aarch64)
@@ -51,26 +84,33 @@ if(TORCHAO_BUILD_CPU_AARCH64)
endif()
endif()

# Add quantized operation dir
add_subdirectory(ops/linear_8bit_act_xbit_weight)
add_subdirectory(ops/embedding_xbit)

# ATen ops lib
add_library(torchao_ops_aten SHARED)
target_link_libraries(
torchao_ops_aten PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)

# Add MPS support if enabled
if (TORCHAO_BUILD_MPS_OPS)
message(STATUS "Building with MPS support")
add_subdirectory(ops/mps)
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
endif()

# Install ATen targets
install(
TARGETS torchao_ops_aten
EXPORT _targets
DESTINATION lib
)

# Build executorch lib if enabled
if(TORCHAO_BUILD_EXECUTORCH_OPS)
add_library(torchao_ops_executorch STATIC)
target_link_libraries(torchao_ops_executorch PRIVATE
1 change: 1 addition & 0 deletions torchao/experimental/build_torchao_ops.sh
@@ -22,6 +22,7 @@ cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \
-DTORCHAO_BUILD_CPU_AARCH64=ON \
-DTORCHAO_ENABLE_ARM_NEON_DOT=ON \
-S . \
-B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release
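Before passing -DTORCHAO_ENABLE_ARM_NEON_DOT=ON or -DTORCHAO_ENABLE_ARM_I8MM=ON for a particular aarch64 Linux host, the CPU's feature flags can be checked in /proc/cpuinfo. A small illustrative helper, not part of this PR ("asimddp" is the Linux kernel's name for the dot-product feature):

    # Report whether this aarch64 Linux host advertises dotprod and i8mm.
    def aarch64_cpu_features():
        try:
            with open("/proc/cpuinfo") as f:
                for line in f:
                    if line.startswith("Features"):
                        return set(line.split(":", 1)[1].split())
        except OSError:
            pass
        return set()

    features = aarch64_cpu_features()
    print("dotprod:", "asimddp" in features)  # ARM NEON dot product
    print("i8mm:", "i8mm" in features)        # 8-bit integer matrix multiply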
@@ -7,6 +7,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
#include <cassert>
#include <cstring>
#include <cstdint>

// Interleaves data across channels (row/column) and groups.
// Each channel is the same size (vals_per_channel) and is
43 changes: 36 additions & 7 deletions torchao/experimental/op_lib.py
@@ -10,15 +10,44 @@
from torch import Tensor
from torch.library import impl

-# Load C++ ops
-lib_path = Path(__file__).parent.parent
-libs = list(lib_path.glob("libtorchao_ops_aten.*"))
-assert len(libs) == 1, (
-    f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
-)
-torch.ops.load_library(str(libs[0]))
# Load C++ ops - use multiple potential paths
potential_paths = [
# Standard path from the module location
Path(__file__).parent.parent,
# Site-packages installation path
Path(torch.__file__).parent.parent / "torchao",
# For editable installs
Path(__file__).parent.parent.parent / "torchao",
]


def find_and_load_libtorchao_ops(potential_paths):
for lib_path in potential_paths:
libs = list(lib_path.glob("libtorchao_ops_aten.*"))

if not libs:
continue

assert len(libs) == 1, (
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
)

target_lib = libs[0]
print(f"Found library at: {target_lib}")

try:
torch.ops.load_library(str(target_lib))
return
except Exception as e:
print(f"Error loading library from {target_lib}: {e}")

raise FileNotFoundError(
"Could not find libtorchao_ops_aten library in any of the provided paths"
)


find_and_load_libtorchao_ops(potential_paths)

# Define meta ops. To support dynamic shapes, some meta ops need to
# be defined in python instead of C++.
torchao_lib = torch.library.Library("torchao", "IMPL")
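# As an illustration of the note above (hypothetical op name and signature,
# not from this PR): a meta kernel only computes output shapes/dtypes on the
# "meta" device, which is what lets dynamic shapes trace through without
# running the C++ kernel. A sketch:
#
#   @impl(torchao_lib, "_example_linear_op", "Meta")
#   def _example_linear_op_meta(activations: Tensor, n: int) -> Tensor:
#       return torch.empty(activations.shape[0], n, device="meta")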
@@ -190,7 +190,7 @@ struct UKernelConfig {
TORCHAO_CHECK(pack_weights != nullptr || pack_weights_with_lut != nullptr, "pack_weights or pack_weights_with_lut must be set");

bool linear_configs_set = true; // first linear config must be set
-for (int i = 0; i < linear_configs.size(); i++) {
+for (size_t i = 0; i < linear_configs.size(); i++) {
if (linear_configs_set) {
TORCHAO_CHECK(
linear_configs[i].m_step >= 1,
@@ -225,7 +225,7 @@
assert(m >= 1);
assert(linear_configs[0].m_step >= 1);

-int i = 0;
+size_t i = 0;
while (i + 1 < linear_configs.size() && linear_configs[i + 1].m_step >= 1 &&
linear_configs[i + 1].m_step <= m) {
assert(linear_configs[i].m_step < linear_configs[i + 1].m_step);
@@ -235,7 +235,7 @@
assert(i < linear_configs.size());
assert(linear_configs[i].m_step >= 1);
assert(i == 0 || linear_configs[i].m_step <= m);
-return i;
+return static_cast<int>(i);
}
};

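The loop above returns the index of the last linear config whose m_step is at least 1 and at most m (configs are ordered by increasing m_step). A minimal Python sketch of the same selection rule, with made-up m_step values:

    # Pick the index of the last config with 1 <= m_step <= m.
    def select_linear_config(m_steps, m):
        i = 0
        while i + 1 < len(m_steps) and 1 <= m_steps[i + 1] <= m:
            i += 1
        return i

    assert select_linear_config([1, 4, 8, 16], 6) == 1  # m_step 4 is the largest <= 6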
8 changes: 4 additions & 4 deletions torchao/experimental/ops/packed_weights_header.h
@@ -43,7 +43,7 @@ class PackedWeightsHeader {
auto header = reinterpret_cast<int*>(packed_weights);
header[0] = magic;
header[1] = static_cast<int>(type);
-for (int i = 0; i < params.size(); i++) {
+for (size_t i = 0; i < params.size(); i++) {
header[i + 2] = params[i];
}
}
@@ -52,7 +52,7 @@ class PackedWeightsHeader {
auto header = reinterpret_cast<const int*>(packed_weights);
assert(header[0] == PackedWeightsHeader::magic);
params_type params;
-for (int i = 0; i < params.size(); i++) {
+for (size_t i = 0; i < params.size(); i++) {
params[i] = header[i + 2];
}
return PackedWeightsHeader(
@@ -63,7 +63,7 @@
if (type != other.type) {
return false;
}
-for (int i = 0; i < params.size(); i++) {
+for (size_t i = 0; i < params.size(); i++) {
if (params[i] != other.params[i]) {
return false;
}
@@ -79,7 +79,7 @@
struct hash<torchao::ops::PackedWeightsHeader> {
std::size_t operator()(const torchao::ops::PackedWeightsHeader& f) const {
std::size_t hash = std::hash<int>()(static_cast<int>(f.type));
-for (int i = 0; i < f.params.size(); i++) {
+for (size_t i = 0; i < f.params.size(); i++) {
hash ^= std::hash<int>()(f.params[i]);
}
return hash;
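The write/read pair above lays the header out as a prefix of int words, [magic, type, params[0], ..., params[N-1]], at the start of the packed-weights buffer. A small Python sketch of that layout, assuming 32-bit little-endian int (both assumptions; the real code uses the platform's int via reinterpret_cast):

    import struct

    def parse_header(blob: bytes, n_params: int):
        # words: [magic, type, params[0], ..., params[n_params - 1]]
        words = struct.unpack_from(f"<{2 + n_params}i", blob)
        return {"magic": words[0], "type": words[1], "params": list(words[2:])}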
2 changes: 1 addition & 1 deletion torchao/experimental/ops/parallel-aten-impl.h
@@ -5,7 +5,7 @@
// LICENSE file in the root directory of this source tree.

#pragma once
-#include <Aten/Parallel.h>
+#include <ATen/Parallel.h>
#include <torch/library.h>
#include <torch/torch.h>

53 changes: 53 additions & 0 deletions torchao/experimental/tests/test_load_libtorchao_ops.py
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch


class TestLibTorchAoOpsLoader(unittest.TestCase):
def test_find_and_load_success(self):
mock_paths = [Path("/test/path1")]
mock_lib = MagicMock()
mock_lib.__str__.return_value = "/test/path1/libtorchao_ops_aten.so"

with patch("pathlib.Path.glob", return_value=[mock_lib]):
with patch("torch.ops.load_library") as mock_load:
from ..op_lib import find_and_load_libtorchao_ops

find_and_load_libtorchao_ops(mock_paths)

mock_load.assert_called_once_with("/test/path1/libtorchao_ops_aten.so")

def test_no_library_found(self):
mock_paths = [Path("/test/path1"), Path("/test/path2")]

with patch("pathlib.Path.glob", return_value=[]):
from ..op_lib import find_and_load_libtorchao_ops

with self.assertRaises(FileNotFoundError):
find_and_load_libtorchao_ops(mock_paths)

def test_multiple_libraries_error(self):
mock_paths = [Path("/test/path1")]
mock_lib1 = MagicMock()
mock_lib2 = MagicMock()
mock_libs = [mock_lib1, mock_lib2]

with patch("pathlib.Path.glob", return_value=mock_libs):
from ..op_lib import find_and_load_libtorchao_ops

try:
find_and_load_libtorchao_ops(mock_paths)
self.fail("Expected AssertionError was not raised")
except AssertionError as e:
expected_error_msg = f"Expected to find one libtorchao_ops_aten.* library at {mock_paths[0]}, but found 2"
self.assertIn(expected_error_msg, str(e))


if __name__ == "__main__":
unittest.main()