Add support for KleidiAI int4 kernels on aarch64 Linux #2169
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2169
❌ 1 New Failure, 3 Pending as of commit 6209194 with merge base 94e2e05.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
In my testing environment, I manually created a symbolic link to libomp.so inside torch/lib. Here it is:
apt install libomp-dev -y
ln -s /usr/lib/llvm-*/lib/libomp.so [path-to-virtualenv]/lib/python*/site-packages/torch/lib/libomp.so

Setting up demo.py:
import copy
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    LlamaForCausalLM,
    LlamaTokenizer,
)

import torchao
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)
from torchao.quantization.granularity import PerGroup

model_id = "meta-llama/Llama-3.2-1B-Instruct"


def load_model_and_tokenizer() -> tuple[LlamaTokenizer, LlamaForCausalLM]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def main() -> None:
    print(f"\ntorch v{torch.__version__}")
    print(f"torchao v{torchao.__version__}")
    print("Loading tokenizer and model ...")
    tokenizer, model = load_model_and_tokenizer()
    print(f"tokenizer and model loaded on {model.device}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Can you explain quantum computing in simple terms?",
        },
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    chat_input: BatchEncoding = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device,
    )

    # No optim inference
    print("\n--- Running standard inference ---")
    start_time = time.time()
    with torch.no_grad():
        chat_outputs = model.generate(**chat_input, max_new_tokens=100)
    end_time = time.time()
    inference_time = end_time - start_time

    # Decode the generated tokens, excluding the prompt
    input_ids: torch.Tensor = chat_input["input_ids"]
    prompt_length: int = input_ids.shape[1]
    response = tokenizer.decode(
        chat_outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response}\n----------------------------------",
    )
    print(f"Inference time: {inference_time:.2f} seconds")

    # KleidiAI optim
    print("\n--- Attempting KleidiAI optimization ---")
    quantized_model = copy.deepcopy(model)
    quantize_(
        quantized_model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.KLEIDIAI
            ),
        ),
    )
    start_time = time.time()
    with torch.no_grad():
        chat_outputs_quantized = quantized_model.generate(
            **chat_input, max_new_tokens=100
        )
    end_time = time.time()
    inference_time_quantized = end_time - start_time
    response_quantized = tokenizer.decode(
        chat_outputs_quantized[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response_quantized}\n----------------------------------",
    )
    print(f"Quantized inference time: {inference_time_quantized:.2f} seconds")
    print(f"Speedup: {inference_time / inference_time_quantized:.2f}x")

    # Print speedups
    if inference_time > 0:
        print(
            f"Speedup with KleidiAI optimization: {inference_time / inference_time_quantized:.2f}x",
        )


if __name__ == "__main__":
    main()
On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT? There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.
Is libomp not bundled with torch on Linux? Is there nothing in site-packages/torch/lib/libomp.so? If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without doing the manual link?
Thanks for the PR @vctrmn! It mostly looks good, but let's guard on the NEON dot flag and resolve the compile issue in CI.
@metascroy I confirm that libomp.so is not bundled with torch on Linux.
Looking into the directories:
Would you recommend adding logic in the CMakeLists.txt to create this symlink?
I had the same compile error on my Ubuntu instance. I solved it by installing the nightly version of torch. I will take a look at this.
Hi @metascroy! I feel that I need to clarify the torchao build options before diving in. Correct me if I'm wrong; for ARM64 platforms:
When
Also, an important note is that on Linux aarch64, we need
Yes, I think that summarizes the required changes.
No, I don't think we should create a symlink in CMakeLists.txt. When you set TORCHAO_PARALLEL_BACKEND=OPENMP, did you see a message about "Building with TORCHAO_PARALLEL_BACKEND=OPENMP" during compilation (pytorch/ao/torchao/experimental/Utils.cmake?lines=35)? When TORCHAO_PARALLEL_BACKEND is set to OPENMP, it shouldn't need to link against OMP in PyTorch, so that is what I'm curious about. It's an argument in the CMakeLists, so we might need to define it with -DTORCHAO_PARALLEL_BACKEND=OPENMP on Linux (pytorch/ao/torchao/experimental/CMakeLists.txt?lines=27-29).
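For reference, a minimal sketch of what the OPENMP backend amounts to on the CMake side (assumed shape only, not the actual contents of Utils.cmake; the target name is a placeholder):

# assumes the flag is passed to CMake as -DTORCHAO_PARALLEL_BACKEND=OPENMP
if(TORCHAO_PARALLEL_BACKEND STREQUAL "OPENMP")
  message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=OPENMP")
  # link the system OpenMP runtime instead of torch/lib/libomp.so
  find_package(OpenMP REQUIRED)
  # target_link_libraries(<torchao_target> PRIVATE OpenMP::OpenMP_CXX)
endif()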
Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes:
The new build command is now:
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
In my testing, using OPENMP, the quantized inference time is around 7.5 seconds, which represents a 2x speedup!
On discussion point 1: looking at the GCC docs (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march), perhaps we should create a new argument in
What do you think @digantdesai?
Looking good! Let's wait for CI, and to see if @digantdesai has any feedback on the various ARM versions.
It looks like cpuinfo_uarch_zen5 is still failing in CI. Let me try checking out your PR on my Mac and see if I can debug.
torchao/experimental/CMakeLists.txt (outdated)

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")

# Find PyTorch
find_package(Torch REQUIRED)
@vctrmn So I found these two new lines are the source of the CI compilation error.
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
Specifically include_directories(${TORCH_INCLUDE_DIRS}) is an issue because it includes an older version of cpu_info.h that doesn't have the zen5 defined in its header. But I'm curious why these got added in the first place. These are included in specific targets where they are needed, e.g., here: https://github.com/pytorch/ao/blob/main/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt
Does it compile for you without them?
@metascroy Yes, it compiles without them! Thanks for pointing out the issue. I had initially added them while troubleshooting some build dependencies, but those issues were resolved through other changes.
torchao/experimental/CMakeLists.txt (outdated)

@@ -42,6 +46,17 @@ if(TORCHAO_BUILD_CPU_AARCH64)
  add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
  add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)

  if (LINUX)
@vctrmn Can we change LINUX to if (CMAKE_SYSTEM_NAME STREQUAL "Linux") and move these right after the other compile options on line 33?
Can we guard "-march=armv8.4-a+dotprod" behind TORCHAO_ENABLE_ARM_NEON_DOT, e.g.:
if (TORCHAO_BUILD_CPU_AARCH64 AND TORCHAO_ENABLE_ARM_NEON_DOT)
  add_compile_options("-march=armv8.4-a+dotprod")
endif()
Might require you to pass -DTORCHAO_ENABLE_ARM_NEON_DOT in setup.py.
@metascroy I have addressed your suggestions!
For code organization, I have kept all the arm64-specific code together within the TORCHAO_BUILD_CPU_AARCH64 block for better maintainability, rather than moving the Linux flags near line 33, because the compile options (-fPIC, -Wno-error=unknown-pragmas, etc.) are only necessary for arm64 Linux builds when TORCHAO_BUILD_CPU_AARCH64 is enabled (roughly as sketched below).
Build OK with:
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .
and
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=0 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .
Does this suit you?
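A minimal sketch of that organization (assumed shape for illustration, not the verbatim diff; the flag names come from the discussion above):

if(TORCHAO_BUILD_CPU_AARCH64)
  add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
  # extra options needed only for aarch64 Linux builds
  if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
    add_compile_options("-fPIC" "-Wno-error=unknown-pragmas")
  endif()
  # NEON dot-product kernels are opt-in via TORCHAO_ENABLE_ARM_NEON_DOT
  if(TORCHAO_ENABLE_ARM_NEON_DOT)
    add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)
    add_compile_options("-march=armv8.4-a+dotprod")
  endif()
endif()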
This looks great to me! Thanks for the contribution.
Stamping, but I have one nit: can we use the value of TORCHAO_ENABLE_ARM_NEON_DOT passed from setup.py for Mac as well like we do for linux (in setup.py build_options.enable_arm_neon_dot should be set to true on Arm-based macs because they all support neon dot; for linux, you can rely on the environment variable).
I'm just thinking ahead to TORCHAO_ENABLE_ARM_I8MM (which we also support for KleidiAI). This is not supported on all mac hardware, and I'd like TORCHAO_ENABLE_ARM_NEON_DOT and TORCHAO_ENABLE_ARM_I8MM to be handled in the same way.
Great! @metascroy I have updated the code. I tried to prepare for the TORCHAO_ENABLE_ARM_I8MM enablement as well:
- Made TORCHAO_ENABLE_ARM_NEON_DOT and TORCHAO_ENABLE_ARM_I8MM platform-independent (should work for both Linux and macOS)
- Improved the comments to explain the "hierarchical" architecture flags from the gcc -march documentation
- Clarified why it uses "if-elseif" rather than two separate "if" statements (see the sketch after this comment)
- Fixed my previous mistake with -march=armv8.4-a+dotprod: the +dotprod is already included in armv8.4-a
Let me know if you think I have added too much complexity; I can simplify by removing the I8MM feature if preferred.
Tested again with both (but not with TORCHAO_ENABLE_ARM_I8MM=1):
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=0 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .
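A sketch of the if/elseif selection described above (assumed, not the verbatim diff; the elseif message string is illustrative, and only one -march flag ends up applied because armv8.6-a already implies dotprod):

if(TORCHAO_ENABLE_ARM_I8MM)
  # armv8.6-a includes both the i8mm and dotprod extensions
  message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
  add_compile_options("-march=armv8.6-a")
elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
  # dotprod is already included in armv8.4-a, so no explicit +dotprod is needed
  message(STATUS "Using armv8.4-a (includes 'dotprod' flag)")
  add_compile_options("-march=armv8.4-a")
endif()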
Keeping I8MM is fine, but looks like we introduced one CI failure for build_et_ops. I think you just need to add the new TORCHAO_ENABLE_ARM_NEON_DOT here: https://github.com/pytorch/ao/blob/main/torchao/experimental/build_torchao_ops.sh#L24
On second thought here, let's not anchor too much on GCC. Mac uses clang by default, so maybe we guard specific features like "-march=armv8.4-a+dotprod" behind flags like TORCHAO_ENABLE_ARM_NEON_DOT that are set in the environment. In future we can decide which ones to enable in setup.py based on hardware. See my comment here: #2169 (comment)
self.parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")

# TORCHAO_ENABLE_ARM_NEON_DOT enable ARM NEON dot instructions
self.enable_arm_neon_dot = self._os_bool_var(
We should enable TORCHAO_ENABLE_ARM_NEON_DOT on Arm Macs.
I wonder if we package it only for Mac wheels.
# - i8mm is included by default in armv8.6-a and later
if(TORCHAO_ENABLE_ARM_I8MM)
  message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
  add_compile_options("-march=armv8.6-a")
I would be more conservative and drop the minor version to the first one that allows i8mm, i.e. -march=armv8.2-a+i8mm here. That way, if users have an older machine with i8mm, we don't accidentally use something the user may not have. Same for dotprod.
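For illustration, the more conservative variant suggested here would look roughly like this (a sketch, not merged code; per the GCC docs both extensions can be requested explicitly from armv8.2-a onward):

if(TORCHAO_ENABLE_ARM_I8MM)
  # armv8.2-a is the first architecture level where +i8mm can be requested
  add_compile_options("-march=armv8.2-a+i8mm")
elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
  # likewise, +dotprod is available as an extension from armv8.2-a
  add_compile_options("-march=armv8.2-a+dotprod")
endif()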
Apologies for a delayed response. Replied.
This PR adds support for using KleidiAI int4 kernels on aarch64 Linux systems. Previously, these kernels were only enabled on macOS ARM platforms, but with these changes, they can be properly built and loaded on any ARM64 Linux system with the appropriate features (NEON, dot product, etc.).
Changes
- setup.py: allow explicit building of arm kernels via the BUILD_TORCHAO_CPU environment variable
- op_lib.py: search in multiple potential installation paths
- CMakeLists.txt
How to build
Users can build torchao with KleidiAI support on aarch64 Linux using:
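For example (the flag set used in the review discussion above; TORCHAO_ENABLE_ARM_NEON_DOT=1 was added to it later during review):
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .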
Testing
Scaleway Ubuntu VM (4 CPU x 16 GB RAM - COPARM1-4C-16G)
Discussion Points
-march=armv8.4-a+dotprod in the CMakeLists.txt: while this works for my Ubuntu VM, we may want to implement a more flexible solution that detects the specific ARM features available?

Related issue
#2143