
Add support for KleidiAI int4 kernels on aarch64 Linux #2169


Merged: 20 commits into pytorch:main on May 14, 2025

Conversation

@vctrmn (Contributor) commented May 4, 2025

This PR adds support for using KleidiAI int4 kernels on aarch64 Linux systems. Previously, these kernels were only enabled on macOS ARM platforms, but with these changes, they can be properly built and loaded on any ARM64 Linux system with the appropriate features (NEON, dot product, etc.).

Changes

  • Modified setup.py to allow explicit building of arm kernels via the BUILD_TORCHAO_CPU environment variable
  • Updated library detection in op_lib.py to search multiple potential installation paths (see the sketch after this list)
  • Fixed compiler warnings
  • Added appropriate compiler flags for aarch64 in CMakeLists.txt
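
A hypothetical sketch of the multi-path lookup described in the op_lib.py bullet above; the actual candidate directories and library name in this PR may differ:

import glob
import os

import torch


def find_torchao_ops_lib():
    # Hypothetical helper: search a few candidate locations for the compiled ops
    package_dir = os.path.dirname(os.path.abspath(__file__))
    candidate_dirs = [
        package_dir,                       # next to the Python sources
        os.path.join(package_dir, "lib"),  # optional lib/ subdirectory
    ]
    for d in candidate_dirs:
        # Match platform-specific suffixes (.so on Linux, .dylib on macOS)
        matches = glob.glob(os.path.join(d, "libtorchao_ops_aten.*"))
        if matches:
            return matches[0]
    return None


lib_path = find_torchao_ops_lib()
if lib_path is not None:
    torch.ops.load_library(lib_path)  # registers the custom ops with torch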

How to build

Users can build torchao with KleidiAI support on aarch64 Linux using:

BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .

Testing

Scaleway Ubuntu VM (4 CPU x 16 GB RAM - COPARM1-4C-16G)

~# lscpu
Architecture:             aarch64
  CPU op-mode(s):         32-bit, 64-bit
  Byte Order:             Little Endian
CPU(s):                   4
  On-line CPU(s) list:    0-3
Vendor ID:                ARM
  BIOS Vendor ID:         QEMU
  Model name:             Neoverse-N1
    BIOS Model name:      virt-6.2  CPU @ 2.0GHz
    BIOS CPU family:      1
    Model:                1
    Thread(s) per core:   1
    Core(s) per socket:   4
    Socket(s):            1
    Stepping:             r3p1
    BogoMIPS:             50.00
    Flags:                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp ssbs

Discussion Points

  • I have a doubt about hard-coding the architecture flags to -march=armv8.4-a+dotprod in CMakeLists.txt. While this works for my Ubuntu VM, we may want a more flexible solution that detects the specific ARM features available (see the sketch after this list).
  • This is a quick implementation to get things working, but we may want to discuss the right way to define and use build environment variables in setup.py for a more robust solution.
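
A hypothetical illustration of the first discussion point (not part of this PR): on Linux one could read the kernel's feature flags, as the lscpu output below shows, and pick -march accordingly.

def linux_arm_features():
    # Parse the "Features" line from /proc/cpuinfo (aarch64 Linux)
    try:
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("Features"):
                    return set(line.split(":", 1)[1].split())
    except OSError:
        pass
    return set()


features = linux_arm_features()
# "asimddp" is the dot-product flag (it appears in the lscpu output below)
march = "armv8.4-a+dotprod" if "asimddp" in features else "armv8-a"
print(f"-march={march}")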

Related issue

#2143

pytorch-bot (bot) commented May 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2169

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 3 Pending

As of commit 6209194 with merge base 94e2e05:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label May 4, 2025
@vctrmn (Contributor Author) commented May 4, 2025

In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

      [ 48%] Built target torchao_ops_linear_8bit_act_xbit_weight_aten
      gmake[2]: *** No rule to make target '/root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so', needed by '/root/ao/build/lib.linux-aarch64-cpython-312/torchao/libtorchao_ops_aten.so'.  Stop.
      gmake[1]: *** [CMakeFiles/Makefile2:287: CMakeFiles/torchao_ops_aten.dir/all] Error 2
      gmake: *** [Makefile:136: all] Error 2
      Traceback (most recent call last):

Here is the workaround:

apt install libomp-dev -y
ln -s /usr/lib/llvm-*/lib/libomp.so [path-to-virtualenv]/lib/python*/site-packages/torch/lib/libomp.so
Full setup:
apt update
apt install gcc g++ cmake ninja-build build-essential python3-pip python3-venv google-perftools -y

git clone https://github.com/vctrmn/ao.git
cd ao

python3 -m venv venv
source venv/bin/activate
pip install wheel setuptools
pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install numpy

apt install libomp-dev -y
ln -s /usr/lib/llvm-18/lib/libomp.so /root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so

BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .

pip install transformers
huggingface-cli login
python demo.py
demo.py:
import copy
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    LlamaForCausalLM,
    LlamaTokenizer,
)

import torchao
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)
from torchao.quantization.granularity import PerGroup

model_id = "meta-llama/Llama-3.2-1B-Instruct"


def load_model_and_tokenizer() -> tuple[LlamaTokenizer, LlamaForCausalLM]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def main() -> None:
    print(f"\ntorch v{torch.__version__}")
    print(f"torchao v{torchao.__version__}")

    print("Loading tokenizer and model ...")
    tokenizer, model = load_model_and_tokenizer()
    print(f"tokenizer and model loaded on {model.device}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Can you explain quantum computing in simple terms?",
        },
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    chat_input: BatchEncoding = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device,
    )

    # Baseline inference (no optimization)
    print("\n--- Running standard inference ---")
    start_time = time.time()

    with torch.no_grad():
        chat_outputs = model.generate(**chat_input, max_new_tokens=100)

    end_time = time.time()
    inference_time = end_time - start_time

    # Decode the generated tokens, excluding the prompt
    input_ids: torch.Tensor = chat_input["input_ids"]
    prompt_length: int = input_ids.shape[1]
    response = tokenizer.decode(
        chat_outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    print(
        f"----------------------------------\n{response}\n----------------------------------",
    )
    print(f"Inference time: {inference_time:.2f} seconds")

    # KleidiAI-optimized inference
    print("\n--- Attempting KleidiAI optimization ---")
    quantized_model = copy.deepcopy(model)

    quantize_(
        quantized_model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.KLEIDIAI
            ),
        ),
    )

    start_time = time.time()

    with torch.no_grad():
        chat_outputs_quantized = quantized_model.generate(
            **chat_input, max_new_tokens=100
        )

    end_time = time.time()
    inference_time_quantized = end_time - start_time

    response_quantized = tokenizer.decode(
        chat_outputs_quantized[0][prompt_length:],
        skip_special_tokens=True,
    )

    print(
        f"----------------------------------\n{response_quantized}\n----------------------------------",
    )
    print(f"Quantized inference time: {inference_time_quantized:.2f} seconds")
    print(f"Speedup: {inference_time / inference_time_quantized:.2f}x")

    # Print speedups
    if inference_time > 0:
        print(
            f"Speedup with KleidiAI optimization: {inference_time / inference_time_quantized:.2f}x",
        )


if __name__ == "__main__":
    main()

@metascroy (Contributor) commented:

On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.

@metascroy (Contributor) commented May 6, 2025

In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

Is libomp not bundled with torch on Linux? Is there nothing in site-packages/torch/lib/libomp.so?

If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without doing the manual link?

@metascroy (Contributor) commented:

Thanks for the PR @vctrmn! It mostly looks good, but let's guard on the NEON dot flag and resolve the compile issue in CI.

@vctrmn (Contributor Author) commented May 6, 2025

In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

Is libomp not bundled with torch on Linux? Is there nothing in site-packages/torch/lib/libomp.so?

If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without doing the manual link?

@metascroy I confirm that libomp.so is indeed not bundled with torch in my aarch64 Ubuntu instance, and setting TORCHAO_PARALLEL_BACKEND=OPENMP doesn't solve the issue.

~/ao# BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
...
      [ 48%] Built target torchao_ops_linear_8bit_act_xbit_weight_aten
      gmake[2]: *** No rule to make target '/root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so', needed by '/root/ao/build/lib.linux-aarch64-cpython-312/torchao/libtorchao_ops_aten.so'.  Stop.
      gmake[1]: *** [CMakeFiles/Makefile2:287: CMakeFiles/torchao_ops_aten.dir/all] Error 2
      gmake: *** [Makefile:136: all] Error 2

Looking into the directories:

~/ao# ls /root/ao/venv/lib/python3.12/site-packages/torch/lib/
libc10.so  libshm  libshm.so  libshm_windows  libtorch.so  libtorch_cpu.so  libtorch_global_deps.so  libtorch_python.so

~/ao# ls /root/ao/venv/lib/python3.12/site-packages/torch.libs/
libarm_compute-d924ca35.so  libarm_compute_graph-17c2200a.so  libgfortran-8a9a71bc.so.5.0.0  libgomp-947d5fa1.so.1.0.0  libopenblasp-r0-0d78ce56.3.29.so

Would you recommend adding logic in the CMakeLists.txt to automatically handle this (finding libomp.so and creating the symlink)? Or should this be addressed at the pytorch repo level?

@vctrmn (Contributor Author) commented May 6, 2025

On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.

I had the same compile error in my Ubuntu instance. I solved it by installing the nightly version of torch: pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu

I will take a look at this.

@vctrmn (Contributor Author) commented May 6, 2025

Hi @metascroy! I feel I need to clarify torchao's build options before diving into setup.py and CMakeLists.txt.

Correct me if I'm wrong; for ARM64 platforms:

  • On macOS: TORCHAO_BUILD_CPU_AARCH64 is enabled by default
  • On Linux: TORCHAO_BUILD_CPU_AARCH64 must be explicitly enabled

When TORCHAO_BUILD_CPU_AARCH64 is enabled, the following additional options become available:

  • TORCHAO_BUILD_KLEIDIAI (must be enabled explicitly)
  • TORCHAO_ENABLE_ARM_NEON_DOT (this option guards -march=armv8.4-a+dotprod on Linux, but should be enabled by default on macOS)

Also, an important note: on Linux aarch64, we need BUILD_TORCHAO_CPU=1.

@metascroy (Contributor) commented:

Yes, I think that summarizes the required changes.

@metascroy (Contributor) commented May 7, 2025

Would you recommend adding logic in the CMakeLists.txt to automatically handle this (finding libomp.so and creating the symlink)? Or should this be addressed at the pytorch repo level?

No, I don't think we should create a symlink in CMakeLists.txt. When you set it, did you see a message about "Building with TORCHAO_PARALLEL_BACKEND=OPENMP" during compilation? (See torchao/experimental/Utils.cmake, line 35.)

When TORCHAO_PARALLEL_BACKEND is set to OPENMP, it shouldn't need to link against OMP in PyTorch, so that is what I'm curious about.

It's an argument on the CMakeLists, so we might need to pass it with -DTORCHAO_PARALLEL_BACKEND=OPENMP on Linux (see torchao/experimental/CMakeLists.txt, lines 27-29).

@vctrmn (Contributor Author) commented May 7, 2025

Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes:

  • Renamed build_torchao_experimental to build_macos_arm_auto for better clarity
  • Added BUILD_TORCHAO_EXPERIMENTAL to enable the experimental features on non-macOS platforms

The new build command is now:

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .

In my testing, using OPENMP, the quantized inference time is around 7.5 seconds, which is roughly a 2x speedup!

@vctrmn (Contributor Author) commented May 7, 2025

On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

On discussion point 1: looking at the GCC docs (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march), perhaps we should create a new argument in setup.py to handle the architecture specification more flexibly?

arch value    Architecture  Includes by default
‘armv8-a’     Armv8-A       ‘+fp’, ‘+simd’
‘armv8.1-a’   Armv8.1-A     ‘armv8-a’, ‘+crc’, ‘+lse’, ‘+rdma’
‘armv8.2-a’   Armv8.2-A     ‘armv8.1-a’
‘armv8.3-a’   Armv8.3-A     ‘armv8.2-a’, ‘+pauth’, ‘+fcma’, ‘+jscvt’
‘armv8.4-a’   Armv8.4-A     ‘armv8.3-a’, ‘+flagm’, ‘+fp16fml’, ‘+dotprod’, ‘+rcpc2’
‘armv8.5-a’   Armv8.5-A     ‘armv8.4-a’, ‘+sb’, ‘+ssbs’, ‘+predres’, ‘+frintts’, ‘+flagm2’
‘armv8.6-a’   Armv8.6-A     ‘armv8.5-a’, ‘+bf16’, ‘+i8mm’
‘armv8.7-a’   Armv8.7-A     ‘armv8.6-a’, ‘+wfxt’, ‘+xs’
‘armv8.8-a’   Armv8.8-A     ‘armv8.7-a’, ‘+mops’
‘armv8.9-a’   Armv8.9-A     ‘armv8.8-a’
‘armv9-a’     Armv9-A       ‘armv8.5-a’, ‘+sve’, ‘+sve2’
‘armv9.1-a’   Armv9.1-A     ‘armv9-a’, ‘+bf16’, ‘+i8mm’
‘armv9.2-a’   Armv9.2-A     ‘armv9.1-a’, ‘+wfxt’, ‘+xs’
‘armv9.3-a’   Armv9.3-A     ‘armv9.2-a’, ‘+mops’
‘armv9.4-a’   Armv9.4-A     ‘armv9.3-a’, ‘+sve2p1’
‘armv9.5-a’   Armv9.5-A     ‘armv9.4-a’, ‘+cpa’, ‘+faminmax’, ‘+lut’
‘armv8-r’     Armv8-R       ‘armv8-r’

@metascroy (Contributor) commented:

On discussion point 1: perhaps we should create a new argument in setup.py to handle the architecture specification more flexibly?

What do you think, @digantdesai?

@metascroy (Contributor) commented:

Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes [...]

Looking good! Let's wait for CI and see if @digantdesai has any feedback on the various ARM versions.

@metascroy (Contributor) commented:

It looks like cpuinfo_uarch_zen5 is still failing in CI. Let me try checking out your PR on my Mac and see if I can debug.

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")

# Find PyTorch
find_package(Torch REQUIRED)
@metascroy (Contributor) commented May 11, 2025

@vctrmn So I found that these two new lines are the source of the CI compilation error.

find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})

Specifically, include_directories(${TORCH_INCLUDE_DIRS}) is an issue because it pulls in an older version of cpu_info.h that doesn't have cpuinfo_uarch_zen5 defined. But I'm curious why these got added in the first place. They are already included in the specific targets where they are needed, e.g., here: https://github.com/pytorch/ao/blob/main/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt

Does it compile for you without them?

@vctrmn (Contributor Author) replied:

@metascroy Yes, it compiles without them! Thanks for pointing out the issue. I had initially added them while troubleshooting some build dependencies, but those issues were resolved through other changes.

@@ -42,6 +46,17 @@ if(TORCHAO_BUILD_CPU_AARCH64)
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)

if (LINUX)
@metascroy (Contributor) commented May 11, 2025

@vctrmn Can we change LINUX to if (CMAKE_SYSTEM_NAME STREQUAL "Linux") and move these right after the other compile options on line 33?

Can we guard "-march=armv8.4-a+dotprod" behind:

if (TORCHAO_BUILD_CPU_AARCH64 AND TORCHAO_ENABLE_ARM_NEON_DOT)
    add_compile_options("-march=armv8.4-a+dotprod")
endif()

It might require you to pass -DTORCHAO_ENABLE_ARM_NEON_DOT in setup.py.

@vctrmn (Contributor Author) commented May 12, 2025

@metascroy I have addressed your suggestions!

For code organization, I kept all arm64-specific code together within the TORCHAO_BUILD_CPU_AARCH64 block for better maintainability, rather than moving the Linux flags near line 33, because the compile options (-fPIC, -Wno-error=unknown-pragmas, etc.) are only needed for arm64 Linux builds when TORCHAO_BUILD_CPU_AARCH64 is enabled.

Builds OK with:

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .

and

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=0 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .

Does this suit you?

@metascroy (Contributor) commented May 12, 2025

This looks great to me! Thanks for the contribution.

Stamping, but I have one nit: can we use the value of TORCHAO_ENABLE_ARM_NEON_DOT passed from setup.py for Mac as well, like we do for Linux? (In setup.py, build_options.enable_arm_neon_dot should be set to true on Arm-based Macs because they all support NEON dot; for Linux, you can rely on the environment variable.)

I'm just thinking ahead to TORCHAO_ENABLE_ARM_I8MM (which we also support for KleidiAI). It is not supported on all Mac hardware, and I'd like TORCHAO_ENABLE_ARM_NEON_DOT and TORCHAO_ENABLE_ARM_I8MM to be handled in the same way.
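
A hypothetical sketch of the setup.py defaulting suggested here (the helper name is made up, not the PR's actual code):

import os
import platform


def default_enable_arm_neon_dot():
    # Hypothetical: always on for Arm-based Macs, environment-driven on Linux
    is_arm_mac = platform.system() == "Darwin" and platform.machine() == "arm64"
    if is_arm_mac:
        return True  # all Apple Silicon Macs support the NEON dot product
    return os.getenv("TORCHAO_ENABLE_ARM_NEON_DOT", "0") == "1"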

@vctrmn (Contributor Author) replied:

Great! @metascroy, I have updated the code and tried to prepare for the TORCHAO_ENABLE_ARM_I8MM enablement as well:

  • Made TORCHAO_ENABLE_ARM_NEON_DOT and TORCHAO_ENABLE_ARM_I8MM platform-independent (should work for both Linux and macOS)
  • Improved comments to explain the "hierarchical" architecture flags from the gcc -march documentation
  • Clarified why it uses "if-elseif" rather than two separate "if" statements
  • Fixed my previous mistake with -march=armv8.4-a+dotprod: the +dotprod extension is already included in armv8.4-a

Let me know if you think I have added too much complexity; I can simplify by removing the I8MM activation feature if preferred.

And tested again with both of the following (but not with TORCHAO_ENABLE_ARM_I8MM=1):

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=0 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install -v .

@metascroy (Contributor) replied:

Keeping I8MM is fine, but it looks like we introduced one CI failure for build_et_ops. I think you just need to add the new TORCHAO_ENABLE_ARM_NEON_DOT here: https://github.com/pytorch/ao/blob/main/torchao/experimental/build_torchao_ops.sh#L24

@metascroy (Contributor) commented:

On discussion point 1: Looking at GCC doc (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march): Perhaps we should create a new argument in setup.py to handle the architecture specification more flexibly ?

On second thought, let's not anchor too much on GCC. Mac uses clang by default, so maybe we guard specific features like "-march=armv8.4-a+dotprod" behind flags like TORCHAO_ENABLE_ARM_NEON_DOT that are set in the environment. In the future we can decide which ones to enable in setup.py based on hardware. See my comment here: #2169 (comment)

self.parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")

# TORCHAO_ENABLE_ARM_NEON_DOT enable ARM NEON dot instructions
self.enable_arm_neon_dot = self._os_bool_var(
@metascroy (Contributor) commented:

We should enable TORCHAO_ENABLE_ARM_NEON_DOT on Arm Macs by default.

@metascroy merged commit 4300079 into pytorch:main May 14, 2025 (33 of 34 checks passed)
@metascroy mentioned this pull request May 15, 2025
@digantdesai (Contributor) commented:

is indeed not bundled with torch in my aarch64 ubuntu instance

I wonder if we package it only for Mac wheels.

# - i8mm is included by default in armv8.6-a and later
if(TORCHAO_ENABLE_ARM_I8MM)
message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
add_compile_options("-march=armv8.6-a")
@digantdesai (Contributor) commented:

I would be more conservative and drop the minor version to the first one that allows i8mm, i.e. -march=armv8.2-a+i8mm here. That way, if users have an older machine with i8mm, we don't accidentally use something the user may not have. Same for dotprod.
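
A rough CMake sketch of this more conservative pinning (assuming the existing option names; the exact form in the repo may differ):

# Hypothetical: pin each feature to the earliest base architecture that
# supports it, rather than bumping the whole -march level.
if(TORCHAO_ENABLE_ARM_I8MM)
    # i8mm is an optional extension starting with Armv8.2-A
    add_compile_options("-march=armv8.2-a+i8mm")
elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
    # dotprod is an optional extension starting with Armv8.2-A
    add_compile_options("-march=armv8.2-a+dotprod")
endif()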

@digantdesai (Contributor) commented:

Looking good! Let's wait for CI and see if @digantdesai has any feedback on the various ARM versions.

Apologies for the delayed response. Replied.
