-
Notifications
You must be signed in to change notification settings - Fork 318
Add support for KleidiAI int4 kernels on aarch64 Linux #2169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
9f49c52
Debugging ARM Neoverse-V1
vctrmn 01fe718
add comment to cmake
vctrmn 1490cb2
Merge branch 'pytorch:main' into main
vctrmn e4b0787
Debug NEOVERSE ARM
vctrmn 345055b
remove useless comments
vctrmn 73370ed
clean
vctrmn 36a90d0
clean
vctrmn 5a6e31b
debug
vctrmn 38b09f2
clean
vctrmn 239d263
load multiple potential paths
vctrmn 2ae0a6f
remove assertion
vctrmn 7a5b438
re-introduce assertion and define load_libtorchao_ops fn
vctrmn 501b080
add unit test to ensure
vctrmn 4dda54d
Ready for merge
vctrmn 4155db0
last test
vctrmn 9575890
fix
vctrmn 3a99eff
PR feedbacks
vctrmn b230845
debug
vctrmn 9b9748b
add comments
vctrmn 6209194
add ENABLE_ARM_NEON in build_torchao_ops
vctrmn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,13 @@ if (NOT CMAKE_BUILD_TYPE) | |
set(CMAKE_BUILD_TYPE Release) | ||
endif() | ||
|
||
# Platform options | ||
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF) | ||
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF) | ||
option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) | ||
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) | ||
option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF) | ||
option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF) | ||
|
||
if(NOT TORCHAO_INCLUDE_DIRS) | ||
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..) | ||
|
@@ -28,19 +31,49 @@ if(NOT DEFINED TORCHAO_PARALLEL_BACKEND) | |
set(TORCHAO_PARALLEL_BACKEND aten_openmp) | ||
endif() | ||
|
||
include(CMakePrintHelpers) | ||
|
||
# Set default compiler options | ||
add_compile_options("-Wall" "-Werror" "-Wno-deprecated") | ||
|
||
include(CMakePrintHelpers) | ||
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") | ||
include_directories(${TORCHAO_INCLUDE_DIRS}) | ||
|
||
|
||
if(TORCHAO_BUILD_CPU_AARCH64) | ||
message(STATUS "Building with cpu/aarch64") | ||
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) | ||
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) | ||
|
||
# Set aarch64 compiler options | ||
if (CMAKE_SYSTEM_NAME STREQUAL "Linux") | ||
message(STATUS "Add aarch64 linux compiler options") | ||
add_compile_options( | ||
"-fPIC" | ||
"-Wno-error=unknown-pragmas" | ||
"-Wno-array-parameter" | ||
"-Wno-maybe-uninitialized" | ||
"-Wno-sign-compare" | ||
) | ||
|
||
# Since versions are hierarchical (each includes features from prior versions): | ||
# - dotprod is included by default in armv8.4-a and later | ||
# - i8mm is included by default in armv8.6-a and later | ||
if(TORCHAO_ENABLE_ARM_I8MM) | ||
message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)") | ||
add_compile_options("-march=armv8.6-a") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be more conservative and drop the minor version to the first version allowed with i8mm i.e. |
||
elseif(TORCHAO_ENABLE_ARM_NEON_DOT) | ||
message(STATUS "Using armv8.4-a (includes '+dotprod' flag)") | ||
add_compile_options("-march=armv8.4-a") | ||
endif() | ||
endif() | ||
|
||
if(TORCHAO_ENABLE_ARM_NEON_DOT) | ||
message(STATUS "Building with ARM NEON dot product support") | ||
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) | ||
endif() | ||
|
||
if(TORCHAO_ENABLE_ARM_I8MM) | ||
message(STATUS "Building with ARM I8MM support") | ||
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) | ||
endif() | ||
|
||
# Defines torchao_kernels_aarch64 | ||
add_subdirectory(kernels/cpu/aarch64) | ||
|
@@ -51,26 +84,33 @@ if(TORCHAO_BUILD_CPU_AARCH64) | |
endif() | ||
endif() | ||
|
||
# Add quantized operation dir | ||
add_subdirectory(ops/linear_8bit_act_xbit_weight) | ||
add_subdirectory(ops/embedding_xbit) | ||
|
||
# ATen ops lib | ||
add_library(torchao_ops_aten SHARED) | ||
target_link_libraries( | ||
torchao_ops_aten PRIVATE | ||
torchao_ops_linear_8bit_act_xbit_weight_aten | ||
torchao_ops_embedding_xbit_aten | ||
) | ||
|
||
# Add MPS support if enabled | ||
if (TORCHAO_BUILD_MPS_OPS) | ||
message(STATUS "Building with MPS support") | ||
add_subdirectory(ops/mps) | ||
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten) | ||
endif() | ||
|
||
# Install ATen targets | ||
install( | ||
TARGETS torchao_ops_aten | ||
EXPORT _targets | ||
DESTINATION lib | ||
) | ||
|
||
# Build executorch lib if enabled | ||
if(TORCHAO_BUILD_EXECUTORCH_OPS) | ||
add_library(torchao_ops_executorch STATIC) | ||
target_link_libraries(torchao_ops_executorch PRIVATE | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from pathlib import Path | ||
from unittest.mock import MagicMock, patch | ||
|
||
|
||
class TestLibTorchAoOpsLoader(unittest.TestCase): | ||
def test_find_and_load_success(self): | ||
mock_paths = [Path("/test/path1")] | ||
mock_lib = MagicMock() | ||
mock_lib.__str__.return_value = "/test/path1/libtorchao_ops_aten.so" | ||
|
||
with patch("pathlib.Path.glob", return_value=[mock_lib]): | ||
with patch("torch.ops.load_library") as mock_load: | ||
from ..op_lib import find_and_load_libtorchao_ops | ||
|
||
find_and_load_libtorchao_ops(mock_paths) | ||
|
||
mock_load.assert_called_once_with("/test/path1/libtorchao_ops_aten.so") | ||
|
||
def test_no_library_found(self): | ||
mock_paths = [Path("/test/path1"), Path("/test/path2")] | ||
|
||
with patch("pathlib.Path.glob", return_value=[]): | ||
from ..op_lib import find_and_load_libtorchao_ops | ||
|
||
with self.assertRaises(FileNotFoundError): | ||
find_and_load_libtorchao_ops(mock_paths) | ||
|
||
def test_multiple_libraries_error(self): | ||
mock_paths = [Path("/test/path1")] | ||
mock_lib1 = MagicMock() | ||
mock_lib2 = MagicMock() | ||
mock_libs = [mock_lib1, mock_lib2] | ||
|
||
with patch("pathlib.Path.glob", return_value=mock_libs): | ||
from ..op_lib import find_and_load_libtorchao_ops | ||
|
||
try: | ||
find_and_load_libtorchao_ops(mock_paths) | ||
self.fail("Expected AssertionError was not raised") | ||
except AssertionError as e: | ||
expected_error_msg = f"Expected to find one libtorchao_ops_aten.* library at {mock_paths[0]}, but found 2" | ||
self.assertIn(expected_error_msg, str(e)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should TORCHAO_ENABLE_ARM_NEON_DOT on Arm mac