[XNNPACK] Add support for Linear fused BatchNorm #11805
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/11805
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 3 Unrelated Failures as of commit 6c92d24 with merge base a8d7298.
NEW FAILURES - The following jobs have failed.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: none"
backends/xnnpack/_passes/__init__.py (Outdated)
@@ -64,6 +67,7 @@ def __init__(
     ConvertToSDPAPass,
     ConstPropPass,
     FuseBatchNormWithConvPass,
+    FuseBatchNormWithLinearPass,
Awesome! Do you mind fusing these two passes into something like a BatchNormFusion pass? That way we can generalize them.
Good call. Combining the two passes cleaned up a lot of the duplication. Let me know if anything else needs to be changed.
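For context, the passes generalize cleanly because the folding arithmetic is identical for both ops; only the weight layout differs. A rough sketch in plain PyTorch (not the actual pass code; the function name is illustrative):

```python
import torch

def fold_bn_scale_into_weight(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Conv weights are [out_ch, in_ch, kH, kW]; linear weights are [out, in].
    # Reshaping the per-channel scale to broadcast over every trailing dim
    # lets one code path serve both layouts.
    return weight * scale.reshape(-1, *([1] * (weight.dim() - 1)))
```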
Thanks. LGTM. I will let Max stamp it.
Love this, thank you!
Awesome, I'm glad! Should we support fusion for Linear with bias=False as well?
@keyprocedure let us know once this is good to go. @mcr229 or I can merge this. Also, you might need to rebase and rerun CI.
Force-pushed from 98650e6 to 20afaa8
@digantdesai I added support for Linear with bias=False and rebased. Everything is good to go from my end - ready for CI. Thanks for reviewing and following up earlier!
The CI failures look unrelated to this PR. I noticed they were being discussed on Discord as known trunk issues. Just wanted to share in case it's helpful. Let me know if there's anything you'd like me to update. Here's a summary of the failures:
Summary
These changes implement a fusion pass in the XNNPACK partitioner to support linear + batchnorm operations. The pass identifies each linear node in the Export IR graph whose only user is a batchnorm node. Fusion then replaces the linear node's weight and bias with a fused weight and bias computed from the linear and batchnorm parameters. For linear nodes without a bias, the fused bias is added as a new parameter on the linear node. All users of the batchnorm output are redirected to the fused linear node, and the batchnorm node is removed from the graph.
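For clarity, here is the folding arithmetic in a minimal PyTorch sketch (illustrative only, not the pass implementation; the helper name is made up for this example):

```python
import torch

def fuse_linear_bn_weights(w, b, bn_mean, bn_var, bn_eps, bn_gamma, bn_beta):
    # Inference-time batchnorm is affine: y = gamma * (x - mean) / sqrt(var + eps) + beta.
    # Folding it into linear(x) = W @ x + b yields a new weight and bias.
    scale = bn_gamma / torch.sqrt(bn_var + bn_eps)
    if b is None:
        # A bias-free linear node is treated as having a zero bias.
        b = torch.zeros_like(bn_mean)
    fused_w = w * scale.unsqueeze(-1)          # scale each output row of W
    fused_b = (b - bn_mean) * scale + bn_beta  # shift the bias to match
    return fused_w, fused_b
```

For reference, PyTorch's torch.nn.utils.fusion.fuse_linear_bn_weights implements equivalent math.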
This linear + batchnorm pass follows the existing implementation pattern of the convolution + batchnorm pass. These fusion passes fold batchnorm into the preceding convolution or linear ops during export. This allows the XNNPACK backend to run a single fused operation at inference, reducing memory usage and latency without affecting model accuracy.
Note: The current linear + batchnorm fusion implementation supports FP32 only. Quantized support is planned for a future PR to TorchAO.
Fixes #11587
Test plan
Tests were added to verify that linear + batchnorm fusion occurs for FP32 models when the linear layer has a single user, and is skipped for linear layers with multiple users. Both linear cases, with and without bias, are tested. A separate test ensures that standalone batchnorm layers are not partitioned, since XNNPACK does not currently support them.
Tests run and passed via:
python -m unittest executorch.backends.xnnpack.test.passes.test_batch_norm_fusion
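For readers reproducing the setup, a toy module with the targeted pattern might look like the sketch below (illustrative only; the actual test modules live in test_batch_norm_fusion.py):

```python
import torch

class LinearBatchNorm(torch.nn.Module):
    """Toy module exhibiting the linear + batchnorm pattern the pass fuses."""

    def __init__(self, bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, bias=bias)
        self.bn = torch.nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.linear(x))

# eval() freezes the running stats, which is what makes folding valid.
model = LinearBatchNorm(bias=False).eval()
reference = model(torch.randn(2, 4))
```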