
[XNNPACK] Add support for Linear fused BatchNorm #11805


Merged: 8 commits merged into pytorch:main on Jul 11, 2025

Conversation

@keyprocedure (Contributor) commented Jun 19, 2025:

Summary

These changes implement a fusion pass in the XNNPACK partitioner to support linear + batchnorm operations. The pass identifies each linear node in the Export IR graph whose output is consumed exclusively by a batchnorm node, then fuses the pair by replacing the linear node's weight and bias with fused values computed from the linear and batchnorm parameters. For linear nodes without a bias, the fused bias is added as a new parameter. All users of the batchnorm output are then redirected to the fused linear node, and the batchnorm node is removed from the graph.
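For reference, the folding identity behind the fused weight and bias is the standard one: applying batchnorm to y = Wx + b yields gamma * (Wx + b - mean) / sqrt(var + eps) + beta, which is itself affine and therefore collapses into a single linear op. A minimal sketch of that computation in plain PyTorch (illustrative only; the function name and tensor layout are assumptions, not the pass's actual code):

```python
import torch

def fuse_linear_bn_params(W, b, bn_mean, bn_var, bn_gamma, bn_beta, eps=1e-5):
    # Hypothetical helper, not the ExecuTorch implementation.
    # scale folds the batchnorm normalization and affine terms per output channel.
    scale = bn_gamma / torch.sqrt(bn_var + eps)
    W_fused = W * scale.unsqueeze(1)          # scale each output row of W (out x in)
    if b is None:                             # linear without bias: start from zeros
        b = torch.zeros_like(bn_mean)
    b_fused = (b - bn_mean) * scale + bn_beta
    return W_fused, b_fused
```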

This linear + batchnorm pass follows the implementation pattern of the existing convolution + batchnorm pass. Both passes fold batchnorm into the preceding convolution or linear op during export, which lets the XNNPACK backend run a single fused operation at inference, reducing memory usage and latency without affecting model accuracy.

Note: The current linear + batchnorm fusion implementation supports FP32 only. Quantized support is planned for a future PR to TorchAO.

Fixes #11587

Test plan

Tests were added to verify that linear + batchnorm fusion occurs for FP32 models when the linear layer has a single user and is skipped when the linear layer has multiple users. Both variants, with and without bias, are covered. A separate test ensures that standalone batchnorm layers are not partitioned, since XNNPACK does not currently support them.

Tests run and passed via:
python -m unittest executorch.backends.xnnpack.test.passes.test_batch_norm_fusion
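For context, the fused pattern the tests exercise looks roughly like the following toy module (a hypothetical sketch, not the actual contents of test_batch_norm_fusion; the module and shapes are assumptions):

```python
import torch

class LinearBN(torch.nn.Module):
    def __init__(self, in_features=8, out_features=16, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.bn = torch.nn.BatchNorm1d(out_features)

    def forward(self, x):
        return self.bn(self.linear(x))

# Export in eval mode so batchnorm uses its running statistics, which is what
# allows the pass to fold it into the preceding linear layer.
model = LinearBN(bias=False).eval()
exported = torch.export.export(model, (torch.randn(4, 8),))
```

A fusion test would then run the pass over the exported graph and check that no batchnorm node remains.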


pytorch-bot bot commented Jun 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/11805

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Unrelated Failures

As of commit 6c92d24 with merge base a8d7298:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jun 19, 2025
@keyprocedure (Contributor, Author) commented:
@pytorchbot label "release notes: none"

@pytorch-bot bot added the release notes: none label (do not include this in the release notes) on Jun 19, 2025
@@ -64,6 +67,7 @@ def __init__(
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormWithConvPass,
FuseBatchNormWithLinearPass,
Review comment on the diff (Contributor):
Awesome! Do you mind fusing these two passes, something like a BatchNormFusion pass? That way we can just generalize these.

@keyprocedure (Contributor, Author) replied:
Good call, combining the two passes cleaned up a lot of the duplication. Let me know if anything else needs to be changed.
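For illustration, a combined pass might dispatch on the op that precedes the batchnorm and reuse PyTorch's stock folding helpers. This is only a sketch under assumed names (the class name, op-target keys, and structure mirror the description above, not the merged code):

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights

# Map each fuseable op to the helper that computes its folded weight/bias.
# The exact aten targets seen in the Export IR may differ from these.
FUSEABLE_OPS = {
    torch.ops.aten.linear.default: fuse_linear_bn_weights,
    torch.ops.aten.convolution.default: fuse_conv_bn_weights,
}

class FuseBatchNormPass:
    def call(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in graph_module.graph.nodes:
            fuse_fn = FUSEABLE_OPS.get(node.target)
            # Only fuse when the linear/conv output feeds exactly one batchnorm node.
            if fuse_fn is None or len(node.users) != 1:
                continue
            bn = next(iter(node.users))
            if bn.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
                continue
            # Compute fused parameters with fuse_fn, swap them into `node`,
            # redirect bn's users to `node`, and erase bn (details omitted).
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return graph_module
```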

@digantdesai (Contributor) left a comment:

Thanks. LGTM. I will let Max stamp it.

@mcr229 (Contributor) left a comment:

Love this thank you!

@keyprocedure (Contributor, Author) commented Jun 27, 2025:

Awesome, I'm glad!

Should we support fusion for Linear with bias=False as well? I can add it to this PR.

@digantdesai (Contributor) commented:
@keyprocedure, let us know once this is good to go; @mcr229 or I can merge it. Also, you might need to rebase and rerun CI.

@keyprocedure force-pushed the support-linear-fused-batchnorm branch from 98650e6 to 20afaa8 on July 2, 2025 at 20:16
@keyprocedure (Contributor, Author) commented Jul 2, 2025:

@digantdesai I added support for Linear with bias=False and rebased. Everything is good to go from my end - ready for CI. Thanks for reviewing and following up earlier!

@keyprocedure (Contributor, Author) commented:

The CI failures look unrelated to this PR. I noticed they were being discussed on Discord as known trunk issues. Just wanted to share in case it's helpful. Let me know if there's anything you'd like me to update.

Here's a summary of the failures:

  • test-eval_llama-mmlu-linux: RuntimeError: Dataset scripts are no longer supported, but found mmlu_no_train.py
  • test-arm-cortex-m-size-test (bare_metal): size check failed in cmake-out/test/size_test
  • unittest-release (linux & macos): TensorPtrMakerTest.FailedCreateTensorUsingFromBlobWithIllegalStrides

@digantdesai merged commit f82c2f0 into pytorch:main on Jul 11, 2025
193 of 199 checks passed
Labels: ciflow/trunk, CLA Signed, release notes: none