From 02cb64e92d445b02141983b2dfae1106ed618c7f Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Sun, 15 Jun 2025 13:50:44 -0700 Subject: [PATCH 1/7] Add pass to fuse Linear and BatchNorm layers --- backends/xnnpack/_passes/__init__.py | 4 + .../_passes/fuse_batch_norm_with_linear.py | 227 ++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 backends/xnnpack/_passes/fuse_batch_norm_with_linear.py diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 4bf5bdfb079..7cf0ce8af79 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -24,6 +24,9 @@ from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( FuseBatchNormWithConvPass, ) +from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( + FuseBatchNormWithLinearPass, +) from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( TagImplicitQDqPass, @@ -64,6 +67,7 @@ def __init__( ConvertToSDPAPass, ConstPropPass, FuseBatchNormWithConvPass, + FuseBatchNormWithLinearPass, FuseActivationPass, DecomposeConcatenate, RemoveGetItemPass, diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py new file mode 100644 index 00000000000..7398b78653c --- /dev/null +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +import torch +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass + +from executorch.backends.xnnpack.utils.utils import ( + get_param_tensor, + get_tensor_name, + is_param_node, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind + +from torch.nn.utils.fusion import fuse_linear_bn_weights + + +class FuseBatchNormWithLinearPass(XNNPACKPass): + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + constant_placeholders_to_delete = set() + for linear in graph.nodes: + # We want to discover a chain of linear -> batch_norm. + # Only proceed if the current node is a linear node, and has a single + # user/successor. 
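+            # The folding itself is delegated to torch.nn.utils.fusion.fuse_linear_bn_weights
+            # below, which (roughly) scales each output row of the linear weight by
+            #   scale = bn_weight / sqrt(running_var + eps)
+            # and replaces the bias with (linear_bias - running_mean) * scale + bn_bias.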
+ if ( + linear.target != exir_ops.edge.aten.linear.default + and linear.target != exir_ops.edge.aten.addmm.default + or len(linear.users) != 1 + ): + continue + + # Single user of the linear op must be batch_norm + bn = list(linear.users.keys())[0] + if ( + bn.target != exir_ops.edge.aten.native_batch_norm.default + and bn.target + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default + ): + continue + + # Get the parameters + assert len(linear.args) == 3 + + if linear.target == exir_ops.edge.aten.addmm.default: + # addmm.args = (bias, input, weight) + linear_bias_arg = linear.args[0] + linear_input_arg = linear.args[1] + linear_weight_arg = FuseBatchNormWithLinearPass.unwrap_to_param_node( + linear.args[2] + ) + else: + # linear.args = (input, weight, bias) + linear_input_arg = linear.args[0] + linear_weight_arg = linear.args[1] + linear_bias_arg = linear.args[2] + + if not self.can_fuse(linear, linear_weight_arg, bn, self.exported_program): + continue + + linear_weight = get_param_tensor(self.exported_program, linear_weight_arg) + linear_weight_name = get_tensor_name( + self.exported_program, linear_weight_arg + ) + assert linear_weight is not None + + linear_bias = get_param_tensor(self.exported_program, linear_bias_arg) + linear_bias_name = get_tensor_name(self.exported_program, linear_bias_arg) + + # Get the parameters from the batchnorm op + assert ( + bn.target == exir_ops.edge.aten.native_batch_norm.default + and len(bn.args) == 8 + ) or ( + bn.target + == exir_ops.edge.aten._native_batch_norm_legit_no_training.default + and len(bn.args) == 7 + ) + bn_weight = get_param_tensor(self.exported_program, bn.args[1]) + bn_bias = get_param_tensor(self.exported_program, bn.args[2]) + + running_mean = get_param_tensor(self.exported_program, bn.args[3]) + assert running_mean is not None + + running_var = get_param_tensor(self.exported_program, bn.args[4]) + assert running_var is not None + + # args[7] for native_batch_norm, but args[6] for + # _native_batch_norm_legit_no_training (which doesn't have training + # as an arg) + eps = bn.args[-1] + + fused_weight, fused_bias = fuse_linear_bn_weights( + linear_weight, + linear_bias, + running_mean, + running_var, + eps, + bn_weight, + bn_bias, + ) + + if linear.target == exir_ops.edge.aten.addmm.default: + # permute_copy node was removed, so weight must be transposed to (in × out) + fused_weight = fused_weight.t() + + fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_") + if linear_bias_name == "": + fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace( + ".", "_" + ) + else: + fused_bias_name = (linear_bias_name + "_fused_bn").replace(".", "_") + + # Modify the graph by updating the weight and bias of the linear op + # with the fused weight and bias params, and replacing all the users + # of getitem(batchnorm) with the linear op. 
+ + with graph.inserting_before(linear_weight_arg): + fused_linear_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_weight_name, + data=fused_weight, + ) + if fused_bias is not None: + fused_linear_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_bias_name, + data=fused_bias, + ) + else: + fused_linear_bias_node = None + + if linear.target == exir_ops.edge.aten.addmm.default: + # addmm.args = (bias, input, weight) + linear.args = ( + fused_linear_bias_node, + linear_input_arg, + fused_linear_weight_node, + ) + else: + # linear.args = (input, weight, bias) + linear.args = ( + linear_input_arg, + fused_linear_weight_node, + fused_linear_bias_node, + ) + + # Remove any use of batchnorm from the graph + for user in bn.users.copy(): + assert user.target == operator.getitem + user.replace_all_uses_with(linear) + graph.erase_node(user) + + graph.erase_node(bn) + constant_placeholders_to_delete.update(linear.args[1:3] + bn.args[1:5]) + + if len(constant_placeholders_to_delete) > 0: + graph_module.graph.eliminate_dead_code() + for node in constant_placeholders_to_delete: + if (node is not None) and (len(node.users) == 0): + delete_constant_placeholder(self.exported_program, node) + + graph_module.recompile() + # To Regenerate metadata and shape information, retrace module + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) + + @staticmethod + def can_fuse( + linear: torch.fx.Node, + linear_weights: torch.fx.Node, + bn: torch.fx.Node, + program: ExportedProgram, + ) -> bool: + """ + Determine whether a batch norm node can be fused with a preceding linear node. + """ + + # All the users of the batchnorm node must be getitem ops. batchnorm + # returns a 3-element tuple. Each user must only access the first + # element of the tuple. + if [ + (user.target == operator.getitem and user.args[1] == 0) for user in bn.users + ].count(False): + return False + + bn_weights = bn.args[1] + + # Check that the weights for linear and batchnorm are both params + if not isinstance(linear_weights, torch.fx.Node) or not isinstance( + bn_weights, torch.fx.Node + ): + return False + + if [ + is_param_node(program, node) for node in {linear_weights, bn_weights} + ].count(False): + return False + + return True + + @staticmethod + def unwrap_to_param_node(node: torch.fx.Node) -> torch.fx.Node: + while node.op == "call_function" and node.target in { + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + }: + node = node.args[0] + return node From 49343671513cf464df018cf328e60368dd0f827e Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Sun, 15 Jun 2025 13:51:46 -0700 Subject: [PATCH 2/7] Add tests for Linear and BatchNorm fusion --- .../passes/test_batch_norm_linear_fusion.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py diff --git a/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py new file mode 100644 index 00000000000..677280ce52f --- /dev/null +++ b/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( + FuseBatchNormWithLinearPass, +) +from executorch.backends.xnnpack.test.tester import RunPasses, Tester + + +class TestBatchNormLinearFusion(unittest.TestCase): + PassStage = RunPasses([FuseBatchNormWithLinearPass]) + bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + + def setUp(self): + torch._dynamo.reset() + + class ModelLinearBN(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + op = torch.nn.Linear + self.linear = op(in_features, out_features) + self.bn = torch.nn.BatchNorm1d(out_features) + self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats + + def forward(self, x): + y = self.linear(x) + y = self.bn(y) + y = self.linear(y) + y = y + y + return self.bn(y) + + def test_fp32_batch_norm_fusion(self): + ( + Tester( + self.ModelLinearBN(2, 2).eval(), + (torch.randn(2, 2),), + ) + .export() + .to_edge() + .run_passes(self.PassStage) + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) + + def test_q8_batch_norm_fusion(self): + ( + Tester( + self.ModelLinearBN(2, 2).eval(), + (torch.randn(2, 2),), + ) + .quantize() + .export() + .to_edge() + .run_passes(self.PassStage) + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) + + def test_fp32_batch_norm_no_fusion_doesnt_partition(self): + """ + We do not currently support standalone batch norms (i.e. batch norms that are + not fused with a linear). This is planned, but until implemented, this test ensures + that we do not partition the standalone batch norm and then fail to lower. + """ + + class BN(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(2) + + def forward(self, x): + return self.bn(x) + + ( + Tester(BN(), (torch.randn(2, 2),)) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .partition() + .check_count({self.bn_name: 1}) + ) From 2aefc750984ae05520af3bbb47f4d703159e2f2d Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Mon, 16 Jun 2025 15:08:39 -0700 Subject: [PATCH 3/7] Add comments --- .../_passes/fuse_batch_norm_with_linear.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py index 7398b78653c..867a95cdde1 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py @@ -32,8 +32,8 @@ def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph constant_placeholders_to_delete = set() for linear in graph.nodes: - # We want to discover a chain of linear -> batch_norm. - # Only proceed if the current node is a linear node, and has a single + # We want to discover a chain of linear -> batch_norm or addmm -> batch_norm. + # Only proceed if the current node is a linear or addmm node, and has a single # user/successor. 
if ( linear.target != exir_ops.edge.aten.linear.default @@ -58,7 +58,8 @@ def call(self, graph_module: torch.fx.GraphModule): # addmm.args = (bias, input, weight) linear_bias_arg = linear.args[0] linear_input_arg = linear.args[1] - linear_weight_arg = FuseBatchNormWithLinearPass.unwrap_to_param_node( + # Unwrap permute_copy to access weight parameter node + linear_weight_arg = FuseBatchNormWithLinearPass._unwrap_node( linear.args[2] ) else: @@ -67,7 +68,7 @@ def call(self, graph_module: torch.fx.GraphModule): linear_weight_arg = linear.args[1] linear_bias_arg = linear.args[2] - if not self.can_fuse(linear, linear_weight_arg, bn, self.exported_program): + if not self.can_fuse(linear_weight_arg, bn, self.exported_program): continue linear_weight = get_param_tensor(self.exported_program, linear_weight_arg) @@ -113,7 +114,8 @@ def call(self, graph_module: torch.fx.GraphModule): ) if linear.target == exir_ops.edge.aten.addmm.default: - # permute_copy node was removed, so weight must be transposed to (in × out) + # fuse_linear_bn_weights returns weight [out, in]; + # permute_copy node was removed, so weight must be transposed to [in, out] for addmm fused_weight = fused_weight.t() fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_") @@ -185,7 +187,6 @@ def call(self, graph_module: torch.fx.GraphModule): @staticmethod def can_fuse( - linear: torch.fx.Node, linear_weights: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram, @@ -218,7 +219,7 @@ def can_fuse( return True @staticmethod - def unwrap_to_param_node(node: torch.fx.Node) -> torch.fx.Node: + def _unwrap_node(node: torch.fx.Node) -> torch.fx.Node: while node.op == "call_function" and node.target in { exir_ops.edge.aten.permute.default, exir_ops.edge.aten.permute_copy.default, From 98a954c37c4ab95cb210e5009befd28dc8f1de72 Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 18 Jun 2025 14:19:52 -0700 Subject: [PATCH 4/7] Remove addmm node check from linear pass and combine conv/linear fusion test files --- .../_passes/fuse_batch_norm_with_linear.py | 80 ++++------------- .../test/passes/test_batch_norm_fusion.py | 67 ++++++++++++-- .../passes/test_batch_norm_linear_fusion.py | 87 ------------------- 3 files changed, 79 insertions(+), 155 deletions(-) delete mode 100644 backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py index 867a95cdde1..9ccf84105bf 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py @@ -32,12 +32,11 @@ def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph constant_placeholders_to_delete = set() for linear in graph.nodes: - # We want to discover a chain of linear -> batch_norm or addmm -> batch_norm. - # Only proceed if the current node is a linear or addmm node, and has a single + # We want to discover a chain of linear -> batch_norm. + # Only proceed if the current node is a linear node, and has a single # user/successor. 
if ( linear.target != exir_ops.edge.aten.linear.default - and linear.target != exir_ops.edge.aten.addmm.default or len(linear.users) != 1 ): continue @@ -51,34 +50,18 @@ def call(self, graph_module: torch.fx.GraphModule): ): continue + if not self.can_fuse(linear, bn, self.exported_program): + continue + # Get the parameters assert len(linear.args) == 3 - if linear.target == exir_ops.edge.aten.addmm.default: - # addmm.args = (bias, input, weight) - linear_bias_arg = linear.args[0] - linear_input_arg = linear.args[1] - # Unwrap permute_copy to access weight parameter node - linear_weight_arg = FuseBatchNormWithLinearPass._unwrap_node( - linear.args[2] - ) - else: - # linear.args = (input, weight, bias) - linear_input_arg = linear.args[0] - linear_weight_arg = linear.args[1] - linear_bias_arg = linear.args[2] - - if not self.can_fuse(linear_weight_arg, bn, self.exported_program): - continue - - linear_weight = get_param_tensor(self.exported_program, linear_weight_arg) - linear_weight_name = get_tensor_name( - self.exported_program, linear_weight_arg - ) + linear_weight = get_param_tensor(self.exported_program, linear.args[1]) + linear_weight_name = get_tensor_name(self.exported_program, linear.args[1]) assert linear_weight is not None - linear_bias = get_param_tensor(self.exported_program, linear_bias_arg) - linear_bias_name = get_tensor_name(self.exported_program, linear_bias_arg) + linear_bias = get_param_tensor(self.exported_program, linear.args[2]) + linear_bias_name = get_tensor_name(self.exported_program, linear.args[2]) # Get the parameters from the batchnorm op assert ( @@ -112,12 +95,6 @@ def call(self, graph_module: torch.fx.GraphModule): bn_weight, bn_bias, ) - - if linear.target == exir_ops.edge.aten.addmm.default: - # fuse_linear_bn_weights returns weight [out, in]; - # permute_copy node was removed, so weight must be transposed to [in, out] for addmm - fused_weight = fused_weight.t() - fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_") if linear_bias_name == "": fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace( @@ -130,7 +107,7 @@ def call(self, graph_module: torch.fx.GraphModule): # with the fused weight and bias params, and replacing all the users # of getitem(batchnorm) with the linear op. 
- with graph.inserting_before(linear_weight_arg): + with graph.inserting_before(linear.args[1]): fused_linear_weight_node = create_constant_placeholder( exp_program=self.exported_program, graph=graph_module.graph, @@ -149,20 +126,11 @@ def call(self, graph_module: torch.fx.GraphModule): else: fused_linear_bias_node = None - if linear.target == exir_ops.edge.aten.addmm.default: - # addmm.args = (bias, input, weight) - linear.args = ( - fused_linear_bias_node, - linear_input_arg, - fused_linear_weight_node, - ) - else: - # linear.args = (input, weight, bias) - linear.args = ( - linear_input_arg, - fused_linear_weight_node, - fused_linear_bias_node, - ) + linear.args = ( + linear.args[0], + fused_linear_weight_node, + fused_linear_bias_node, + ) # Remove any use of batchnorm from the graph for user in bn.users.copy(): @@ -187,7 +155,7 @@ def call(self, graph_module: torch.fx.GraphModule): @staticmethod def can_fuse( - linear_weights: torch.fx.Node, + linear: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram, ) -> bool: @@ -206,23 +174,11 @@ def can_fuse( bn_weights = bn.args[1] # Check that the weights for linear and batchnorm are both params - if not isinstance(linear_weights, torch.fx.Node) or not isinstance( + if not isinstance(linear, torch.fx.Node) or not isinstance( bn_weights, torch.fx.Node ): return False - if [ - is_param_node(program, node) for node in {linear_weights, bn_weights} - ].count(False): + if [is_param_node(program, node) for node in {linear, bn_weights}].count(False): return False - return True - - @staticmethod - def _unwrap_node(node: torch.fx.Node) -> torch.fx.Node: - while node.op == "call_function" and node.target in { - exir_ops.edge.aten.permute.default, - exir_ops.edge.aten.permute_copy.default, - }: - node = node.args[0] - return node diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index 70c93c3751b..d380ba02961 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -11,11 +11,15 @@ from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( FuseBatchNormWithConvPass, ) +from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( + FuseBatchNormWithLinearPass, +) from executorch.backends.xnnpack.test.tester import RunPasses, Tester class TestBatchNormFusion(unittest.TestCase): - PassStage = RunPasses([FuseBatchNormWithConvPass]) + ConvPassStage = RunPasses([FuseBatchNormWithConvPass]) + LinearPassStage = RunPasses([FuseBatchNormWithLinearPass]) bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" def setUp(self): @@ -42,7 +46,22 @@ def forward(self, x): y = y + y return self.bn(y) - def test_fp32_batch_norm_fusion(self): + class ModelLinearBN(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + op = torch.nn.Linear + self.linear = op(in_features, out_features) + self.bn = torch.nn.BatchNorm1d(out_features) + self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats + + def forward(self, x): + y = self.linear(x) + y = self.bn(y) + y = self.linear(y) + y = y + y + return self.bn(y) + + def test_fp32_conv_batch_norm_fusion(self): for transpose in [False, True]: ( Tester( @@ -51,12 +70,12 @@ def test_fp32_batch_norm_fusion(self): ) .export() .to_edge() - .run_passes(self.PassStage) + .run_passes(self.ConvPassStage) .check_count({self.bn_name: 1}) .run_method_and_compare_outputs() ) - def 
test_q8_batch_norm_fusion(self): + def test_q8_conv_batch_norm_fusion(self): for transpose in [False, True]: ( Tester( @@ -66,12 +85,12 @@ def test_q8_batch_norm_fusion(self): .quantize() .export() .to_edge() - .run_passes(self.PassStage) + .run_passes(self.ConvPassStage) .check_count({self.bn_name: 1}) .run_method_and_compare_outputs() ) - def test_fp32_batch_norm_no_fusion_doesnt_partition(self): + def test_fp32_conv_batch_norm_no_fusion_doesnt_partition(self): """ We do not currently support standalone batch norms (i.e. batch norms that are not fused with a conv). This is planned, but until implemented, this test ensures @@ -94,3 +113,39 @@ def forward(self, x): .partition() .check_count({self.bn_name: 1}) ) + + def test_fp32_linear_batch_norm_fusion(self): + ( + Tester( + self.ModelLinearBN(2, 2).eval(), + (torch.randn(2, 2),), + ) + .export() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) + + # def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): + # """ + # We do not currently support standalone batch norms (i.e. batch norms that are + # not fused with a linear). This is planned, but until implemented, this test ensures + # that we do not partition the standalone batch norm and then fail to lower. + # """ + # + # class BN(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.bn = torch.nn.BatchNorm1d(2) + # + # def forward(self, x): + # return self.bn(x) + # + # ( + # Tester(BN(), (torch.randn(2, 2),)) + # .export() + # .to_edge() + # .check_count({self.bn_name: 1}) + # .partition() + # .check_count({self.bn_name: 1}) + # ) diff --git a/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py deleted file mode 100644 index 677280ce52f..00000000000 --- a/backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( - FuseBatchNormWithLinearPass, -) -from executorch.backends.xnnpack.test.tester import RunPasses, Tester - - -class TestBatchNormLinearFusion(unittest.TestCase): - PassStage = RunPasses([FuseBatchNormWithLinearPass]) - bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" - - def setUp(self): - torch._dynamo.reset() - - class ModelLinearBN(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - op = torch.nn.Linear - self.linear = op(in_features, out_features) - self.bn = torch.nn.BatchNorm1d(out_features) - self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats - - def forward(self, x): - y = self.linear(x) - y = self.bn(y) - y = self.linear(y) - y = y + y - return self.bn(y) - - def test_fp32_batch_norm_fusion(self): - ( - Tester( - self.ModelLinearBN(2, 2).eval(), - (torch.randn(2, 2),), - ) - .export() - .to_edge() - .run_passes(self.PassStage) - .check_count({self.bn_name: 1}) - .run_method_and_compare_outputs() - ) - - def test_q8_batch_norm_fusion(self): - ( - Tester( - self.ModelLinearBN(2, 2).eval(), - (torch.randn(2, 2),), - ) - .quantize() - .export() - .to_edge() - .run_passes(self.PassStage) - .check_count({self.bn_name: 1}) - .run_method_and_compare_outputs() - ) - - def test_fp32_batch_norm_no_fusion_doesnt_partition(self): - """ - We do not currently support standalone batch norms (i.e. batch norms that are - not fused with a linear). This is planned, but until implemented, this test ensures - that we do not partition the standalone batch norm and then fail to lower. - """ - - class BN(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm1d(2) - - def forward(self, x): - return self.bn(x) - - ( - Tester(BN(), (torch.randn(2, 2),)) - .export() - .to_edge() - .check_count({self.bn_name: 1}) - .partition() - .check_count({self.bn_name: 1}) - ) From 8003ba35f03f4850fdf3cf070b97cf601e5b4683 Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 18 Jun 2025 17:41:09 -0700 Subject: [PATCH 5/7] Register Linear fusion pass in BatchNormConfig --- .../_passes/fuse_batch_norm_with_linear.py | 11 +++-- .../xnnpack/partition/config/node_configs.py | 19 +++++--- .../test/passes/test_batch_norm_fusion.py | 44 +++++++++---------- 3 files changed, 40 insertions(+), 34 deletions(-) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py index 9ccf84105bf..c3810f22fbd 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py @@ -41,7 +41,7 @@ def call(self, graph_module: torch.fx.GraphModule): ): continue - # Single user of the linear op must be batch_norm + # Single user of the linear op must be batch_norm. If not, bail. 
bn = list(linear.users.keys())[0] if ( bn.target != exir_ops.edge.aten.native_batch_norm.default @@ -53,7 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule): if not self.can_fuse(linear, bn, self.exported_program): continue - # Get the parameters + # Get the parameters from linear op assert len(linear.args) == 3 linear_weight = get_param_tensor(self.exported_program, linear.args[1]) @@ -171,14 +171,17 @@ def can_fuse( ].count(False): return False + linear_weights = linear.args[1] bn_weights = bn.args[1] # Check that the weights for linear and batchnorm are both params - if not isinstance(linear, torch.fx.Node) or not isinstance( + if not isinstance(linear_weights, torch.fx.Node) or not isinstance( bn_weights, torch.fx.Node ): return False - if [is_param_node(program, node) for node in {linear, bn_weights}].count(False): + if [ + is_param_node(program, node) for node in {linear_weights, bn_weights} + ].count(False): return False return True diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py index 23acfbfb8c4..1bfc0f80eb8 100644 --- a/backends/xnnpack/partition/config/node_configs.py +++ b/backends/xnnpack/partition/config/node_configs.py @@ -12,6 +12,9 @@ from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( FuseBatchNormWithConvPass, ) +from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( + FuseBatchNormWithLinearPass, +) from executorch.backends.xnnpack.partition.config.xnnpack_config import ( ConfigPrecisionType, XNNPartitionerConfig, @@ -35,20 +38,22 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False bn = node - conv = node.all_input_nodes[0] + input_node = node.all_input_nodes[0] - if conv.op != "call_function": + if input_node.op != "call_function": return False - conv_name = format_target_name(conv.target.__name__) # pyre-ignore + input_name = format_target_name(input_node.target.__name__) # pyre-ignore - if conv_name not in ["convolution.default"]: - why(node, f"Invalid conv target {conv_name}") + if input_name not in ["convolution.default", "linear.default"]: + why(node, f"Invalid input target {input_name.split('.')[0]}") return False - can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep) + can_fuse = FuseBatchNormWithConvPass.can_fuse( + input_node, bn, ep + ) or FuseBatchNormWithLinearPass.can_fuse(input_node, bn, ep) if not can_fuse: - why(node, "BatchNorm cannot be fused with Convolution") + why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}") return False return True diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index d380ba02961..b8362b2ebfd 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -126,26 +126,24 @@ def test_fp32_linear_batch_norm_fusion(self): .run_method_and_compare_outputs() ) - # def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): - # """ - # We do not currently support standalone batch norms (i.e. batch norms that are - # not fused with a linear). This is planned, but until implemented, this test ensures - # that we do not partition the standalone batch norm and then fail to lower. 
- # """ - # - # class BN(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.bn = torch.nn.BatchNorm1d(2) - # - # def forward(self, x): - # return self.bn(x) - # - # ( - # Tester(BN(), (torch.randn(2, 2),)) - # .export() - # .to_edge() - # .check_count({self.bn_name: 1}) - # .partition() - # .check_count({self.bn_name: 1}) - # ) + def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): + """ + We do not currently support standalone batch norms (i.e. batch norms that are + not fused with a linear). This is planned, but until implemented, this test ensures + that we do not partition the standalone batch norm and then fail to lower. + """ + + class BN(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(2) + + def forward(self, x): + return self.bn(x) + + ( + Tester(BN(), (torch.randn(2, 2),)) + .export() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 1}) + ) From 9bd8064c3b536c710ca99c83fdf824e156e05b40 Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 25 Jun 2025 17:40:02 -0700 Subject: [PATCH 6/7] Merge conv and linear fusion passes into FuseBatchNormPass --- backends/xnnpack/_passes/__init__.py | 10 +- backends/xnnpack/_passes/fuse_batch_norm.py | 232 ++++++++++++++++++ .../_passes/fuse_batch_norm_with_conv.py | 197 --------------- .../_passes/fuse_batch_norm_with_linear.py | 187 -------------- .../xnnpack/partition/config/node_configs.py | 11 +- .../test/passes/test_batch_norm_fusion.py | 14 +- 6 files changed, 240 insertions(+), 411 deletions(-) create mode 100644 backends/xnnpack/_passes/fuse_batch_norm.py delete mode 100644 backends/xnnpack/_passes/fuse_batch_norm_with_conv.py delete mode 100644 backends/xnnpack/_passes/fuse_batch_norm_with_linear.py diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 7cf0ce8af79..5d8526f326a 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -21,12 +21,7 @@ ) from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( - FuseBatchNormWithLinearPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( TagImplicitQDqPass, @@ -66,8 +61,7 @@ def __init__( ConvertToLinearPass, ConvertToSDPAPass, ConstPropPass, - FuseBatchNormWithConvPass, - FuseBatchNormWithLinearPass, + FuseBatchNormPass, FuseActivationPass, DecomposeConcatenate, RemoveGetItemPass, diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py new file mode 100644 index 00000000000..88a912670a5 --- /dev/null +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator + +import torch +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass + +from executorch.backends.xnnpack.utils.utils import ( + get_param_tensor, + get_tensor_name, + is_param_node, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind + +from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights + + +class FuseBatchNormPass(XNNPACKPass): + """ + BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase + memory usage since we serialize new weights to represent the convolution. In most cases, + BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused + with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer. + """ + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + constant_placeholders_to_delete = set() + for node in graph.nodes: + # We want to discover a chain of conv -> batch_norm or linear -> batch_norm. + # Only proceed if the current node is a conv or linear node, and has a single + # user/successor. + is_conv = node.target == exir_ops.edge.aten.convolution.default + is_linear = node.target == exir_ops.edge.aten.linear.default + + if not (is_conv or is_linear): + continue + if len(node.users) != 1: + continue + + # Conv or linear op to fuse. + target_op = node + + # The single user of the op must be batch_norm. If not, bail. + bn = list(target_op.users.keys())[0] + if ( + bn.target != exir_ops.edge.aten.native_batch_norm.default + and bn.target + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default + ): + continue + + if not self.can_fuse(target_op, bn, self.exported_program): + continue + + self._fuse_ops( + graph_module, + graph, + target_op, + bn, + is_conv, + constant_placeholders_to_delete, + ) + + if len(constant_placeholders_to_delete) > 0: + graph_module.graph.eliminate_dead_code() + for node in constant_placeholders_to_delete: + if (node is not None) and (len(node.users) == 0): + delete_constant_placeholder(self.exported_program, node) + + graph_module.recompile() + # To Regenerate metadata and shape information, retrace module. + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) + + @staticmethod + def can_fuse( + target_op: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + ) -> bool: + """ + Determine whether a batchnorm node can be fused with a preceding conv or linear node. + """ + + # All the users of batchnorm node must be getitem ops. batchnorm + # returns a 3-element tuple. Each user must only access the first + # element of the tuple. + if [ + (user.target == operator.getitem and user.args[1] == 0) for user in bn.users + ].count(False): + return False + + target_op_weights = target_op.args[1] + bn_weights = bn.args[1] + + # Check that the weights for conv or linear and batchnorm are both params. 
+ if not isinstance(target_op_weights, torch.fx.Node) or not isinstance( + bn_weights, torch.fx.Node + ): + return False + + if [ + is_param_node(program, node) for node in {target_op_weights, bn_weights} + ].count(False): + return False + + return True + + def _fuse_ops( + self, + graph_module: torch.fx.GraphModule, + graph: torch.fx.Graph, + target_op: torch.fx.Node, + bn: torch.fx.Node, + is_conv: bool, + constant_placeholders_to_delete: set, + ) -> None: + """ + Fuse a BatchNorm into the preceding conv or linear op. + Update the fused op's weight and bias, rewire users of the BatchNorm's output, and remove the BatchNorm node. + """ + + if is_conv: + assert len(target_op.args) == 9 + else: # Linear path: (input, weight, bias). + assert len(target_op.args) == 3 + + # Get the weight and bias parameters from the conv or linear op. + target_op_weight = get_param_tensor(self.exported_program, target_op.args[1]) + target_op_weight_name = get_tensor_name( + self.exported_program, target_op.args[1] + ) + assert target_op_weight is not None + + target_op_bias = get_param_tensor(self.exported_program, target_op.args[2]) + target_op_bias_name = get_tensor_name(self.exported_program, target_op.args[2]) + + # Get the parameters from the batchnorm op. + assert ( + bn.target == exir_ops.edge.aten.native_batch_norm.default + and len(bn.args) == 8 + ) or ( + bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default + and len(bn.args) == 7 + ) + bn_weight = get_param_tensor(self.exported_program, bn.args[1]) + bn_bias = get_param_tensor(self.exported_program, bn.args[2]) + + running_mean = get_param_tensor(self.exported_program, bn.args[3]) + assert running_mean is not None + + running_var = get_param_tensor(self.exported_program, bn.args[4]) + assert running_var is not None + + # args[7] for native_batch_norm, but args[6] for + # _native_batch_norm_legit_no_training (which doesn't have training + # as an arg). + eps = bn.args[-1] + + # Compute the updated weight and bias after fusing conv or linear op with batchnorm op. + fuse_args = ( + target_op_weight, + target_op_bias, + running_mean, + running_var, + eps, + bn_weight, + bn_bias, + ) + + if is_conv: + is_transpose = target_op.args[6] + fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose) + else: # Linear path. + fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args) + + fused_weight_name = (target_op_weight_name + "_fused_bn").replace(".", "_") + if target_op_bias_name == "": + fused_bias_name = (target_op_weight_name + "_bias_fused_bn").replace( + ".", "_" + ) + else: + fused_bias_name = (target_op_bias_name + "_fused_bn").replace(".", "_") + + # Modify the graph by updating the weight and bias of conv or linear op + # with the fused weight and bias params, and replacing all the users + # of getitem(batchnorm) with the conv or linear op. + with graph.inserting_before(target_op.args[1]): + fused_op_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_weight_name, + data=fused_weight, + ) + if fused_bias is not None: + fused_op_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_bias_name, + data=fused_bias, + ) + else: + fused_op_bias_node = None + + # Replace weight and bias with the fused batchnorm values. 
+ args = list(target_op.args) + args[1] = fused_op_weight_node + args[2] = fused_op_bias_node + target_op.args = tuple(args) + + # Remove any use of batchnorm from the graph + for user in bn.users.copy(): + assert user.target == operator.getitem + user.replace_all_uses_with(target_op) + graph.erase_node(user) + + graph.erase_node(bn) + constant_placeholders_to_delete.update(target_op.args[1:3] + bn.args[1:5]) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py deleted file mode 100644 index 6f31fe698ba..00000000000 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import operator - -import torch -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - delete_constant_placeholder, -) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass - -from executorch.backends.xnnpack.utils.utils import ( - get_param_tensor, - get_tensor_name, - is_param_node, -) -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult -from torch.export.graph_signature import InputKind - -from torch.nn.utils.fusion import fuse_conv_bn_weights - - -class FuseBatchNormWithConvPass(XNNPACKPass): - """ - Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase - memory usage since we serialize new weights to represent the convolution. In most cases, - Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused - with the previous convolution - """ - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - constant_placeholders_to_delete = set() - for conv in graph.nodes: - # We want to discover a chain of conv -> batch_norm. - # Only proceed if the current node is a conv node, and has a single - # user/successor. - if ( - conv.target != exir_ops.edge.aten.convolution.default - or len(conv.users) != 1 - ): - continue - - # The single user of conv op must be batch_norm. If not, bail. 
- bn = list(conv.users.keys())[0] - if ( - bn.target != exir_ops.edge.aten.native_batch_norm.default - and bn.target - != exir_ops.edge.aten._native_batch_norm_legit_no_training.default - ): - continue - - if not self.can_fuse(conv, bn, self.exported_program): - continue - - # Get the parameters from conv op - assert len(conv.args) == 9 - - conv_weight = get_param_tensor(self.exported_program, conv.args[1]) - conv_weight_name = get_tensor_name(self.exported_program, conv.args[1]) - assert conv_weight is not None - - conv_bias = get_param_tensor(self.exported_program, conv.args[2]) - conv_bias_name = get_tensor_name(self.exported_program, conv.args[2]) - - # Get the parameters from the batchnorm op - assert ( - bn.target == exir_ops.edge.aten.native_batch_norm.default - and len(bn.args) == 8 - ) or ( - bn.target - == exir_ops.edge.aten._native_batch_norm_legit_no_training.default - and len(bn.args) == 7 - ) - bn_weight = get_param_tensor(self.exported_program, bn.args[1]) - bn_bias = get_param_tensor(self.exported_program, bn.args[2]) - - running_mean = get_param_tensor(self.exported_program, bn.args[3]) - assert running_mean is not None - - running_var = get_param_tensor(self.exported_program, bn.args[4]) - assert running_var is not None - - # args[7] for native_batch_norm, but args[6] for - # _native_batch_norm_legit_no_training (which doesn't have training - # as an arg) - eps = bn.args[-1] - - is_transpose = conv.args[6] - # Compute the updated weight and bias after fusing conv op - # with batchnorm op. - fused_weight, fused_bias = fuse_conv_bn_weights( - conv_weight, - conv_bias, - running_mean, - running_var, - eps, - bn_weight, - bn_bias, - is_transpose, - ) - fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_") - if conv_bias_name == "": - fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace( - ".", "_" - ) - else: - fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_") - - # Modify the graph by updating the weight and bias of conv op - # with the fused weight and bias params, and replacing all the users - # of getitem(batchnorm) with the conv op. 
- with graph.inserting_before(conv.args[1]): - fused_conv_weight_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_weight_name, - data=fused_weight, - ) - if fused_bias is not None: - fused_conv_bias_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_bias_name, - data=fused_bias, - ) - else: - fused_conv_bias_node = None - - conv.args = ( - conv.args[0], - fused_conv_weight_node, - fused_conv_bias_node, - *conv.args[3:], - ) - - # Remove any use of batchnorm from the graph - for user in bn.users.copy(): - assert user.target == operator.getitem - user.replace_all_uses_with(conv) - graph.erase_node(user) - - graph.erase_node(bn) - constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5]) - - if len(constant_placeholders_to_delete) > 0: - graph_module.graph.eliminate_dead_code() - for node in constant_placeholders_to_delete: - if (node is not None) and (len(node.users) == 0): - delete_constant_placeholder(self.exported_program, node) - - graph_module.recompile() - # To Regenerate meta data and shape information, retrace module - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) - - @staticmethod - def can_fuse( - conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram - ) -> bool: - """ - Determine whether a batch norm node can be fused with a preceding conv node. - """ - - # All the users of batchnorm node must be getitem ops. batchnorm - # returns a 3-element tuple. Each user must only access the first - # element of the tuple. - if [ - (user.target == operator.getitem and user.args[1] == 0) for user in bn.users - ].count(False): - return False - - conv_weights = conv.args[1] - bn_weights = bn.args[1] - - # Check that the weights for conv and batchnorm are both params - if not isinstance(conv_weights, torch.fx.Node) or not isinstance( - bn_weights, torch.fx.Node - ): - return False - - if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count( - False - ): - return False - - return True diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py b/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py deleted file mode 100644 index c3810f22fbd..00000000000 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_linear.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import operator - -import torch -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - delete_constant_placeholder, -) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass - -from executorch.backends.xnnpack.utils.utils import ( - get_param_tensor, - get_tensor_name, - is_param_node, -) -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult -from torch.export.graph_signature import InputKind - -from torch.nn.utils.fusion import fuse_linear_bn_weights - - -class FuseBatchNormWithLinearPass(XNNPACKPass): - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - constant_placeholders_to_delete = set() - for linear in graph.nodes: - # We want to discover a chain of linear -> batch_norm. 
- # Only proceed if the current node is a linear node, and has a single - # user/successor. - if ( - linear.target != exir_ops.edge.aten.linear.default - or len(linear.users) != 1 - ): - continue - - # Single user of the linear op must be batch_norm. If not, bail. - bn = list(linear.users.keys())[0] - if ( - bn.target != exir_ops.edge.aten.native_batch_norm.default - and bn.target - != exir_ops.edge.aten._native_batch_norm_legit_no_training.default - ): - continue - - if not self.can_fuse(linear, bn, self.exported_program): - continue - - # Get the parameters from linear op - assert len(linear.args) == 3 - - linear_weight = get_param_tensor(self.exported_program, linear.args[1]) - linear_weight_name = get_tensor_name(self.exported_program, linear.args[1]) - assert linear_weight is not None - - linear_bias = get_param_tensor(self.exported_program, linear.args[2]) - linear_bias_name = get_tensor_name(self.exported_program, linear.args[2]) - - # Get the parameters from the batchnorm op - assert ( - bn.target == exir_ops.edge.aten.native_batch_norm.default - and len(bn.args) == 8 - ) or ( - bn.target - == exir_ops.edge.aten._native_batch_norm_legit_no_training.default - and len(bn.args) == 7 - ) - bn_weight = get_param_tensor(self.exported_program, bn.args[1]) - bn_bias = get_param_tensor(self.exported_program, bn.args[2]) - - running_mean = get_param_tensor(self.exported_program, bn.args[3]) - assert running_mean is not None - - running_var = get_param_tensor(self.exported_program, bn.args[4]) - assert running_var is not None - - # args[7] for native_batch_norm, but args[6] for - # _native_batch_norm_legit_no_training (which doesn't have training - # as an arg) - eps = bn.args[-1] - - fused_weight, fused_bias = fuse_linear_bn_weights( - linear_weight, - linear_bias, - running_mean, - running_var, - eps, - bn_weight, - bn_bias, - ) - fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_") - if linear_bias_name == "": - fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace( - ".", "_" - ) - else: - fused_bias_name = (linear_bias_name + "_fused_bn").replace(".", "_") - - # Modify the graph by updating the weight and bias of the linear op - # with the fused weight and bias params, and replacing all the users - # of getitem(batchnorm) with the linear op. 
- - with graph.inserting_before(linear.args[1]): - fused_linear_weight_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_weight_name, - data=fused_weight, - ) - if fused_bias is not None: - fused_linear_bias_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_bias_name, - data=fused_bias, - ) - else: - fused_linear_bias_node = None - - linear.args = ( - linear.args[0], - fused_linear_weight_node, - fused_linear_bias_node, - ) - - # Remove any use of batchnorm from the graph - for user in bn.users.copy(): - assert user.target == operator.getitem - user.replace_all_uses_with(linear) - graph.erase_node(user) - - graph.erase_node(bn) - constant_placeholders_to_delete.update(linear.args[1:3] + bn.args[1:5]) - - if len(constant_placeholders_to_delete) > 0: - graph_module.graph.eliminate_dead_code() - for node in constant_placeholders_to_delete: - if (node is not None) and (len(node.users) == 0): - delete_constant_placeholder(self.exported_program, node) - - graph_module.recompile() - # To Regenerate metadata and shape information, retrace module - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) - - @staticmethod - def can_fuse( - linear: torch.fx.Node, - bn: torch.fx.Node, - program: ExportedProgram, - ) -> bool: - """ - Determine whether a batch norm node can be fused with a preceding linear node. - """ - - # All the users of the batchnorm node must be getitem ops. batchnorm - # returns a 3-element tuple. Each user must only access the first - # element of the tuple. - if [ - (user.target == operator.getitem and user.args[1] == 0) for user in bn.users - ].count(False): - return False - - linear_weights = linear.args[1] - bn_weights = bn.args[1] - - # Check that the weights for linear and batchnorm are both params - if not isinstance(linear_weights, torch.fx.Node) or not isinstance( - bn_weights, torch.fx.Node - ): - return False - - if [ - is_param_node(program, node) for node in {linear_weights, bn_weights} - ].count(False): - return False - return True diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py index 1bfc0f80eb8..4659ea05a0f 100644 --- a/backends/xnnpack/partition/config/node_configs.py +++ b/backends/xnnpack/partition/config/node_configs.py @@ -9,12 +9,7 @@ from typing import List, Optional import torch -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( - FuseBatchNormWithLinearPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack.partition.config.xnnpack_config import ( ConfigPrecisionType, XNNPartitionerConfig, @@ -49,9 +44,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: why(node, f"Invalid input target {input_name.split('.')[0]}") return False - can_fuse = FuseBatchNormWithConvPass.can_fuse( - input_node, bn, ep - ) or FuseBatchNormWithLinearPass.can_fuse(input_node, bn, ep) + can_fuse = FuseBatchNormPass.can_fuse(input_node, bn, ep) if not can_fuse: why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}") return False diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py 
b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index b8362b2ebfd..1cadb79032f 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -8,18 +8,12 @@ from typing import Tuple import torch -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import ( - FuseBatchNormWithLinearPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack.test.tester import RunPasses, Tester class TestBatchNormFusion(unittest.TestCase): - ConvPassStage = RunPasses([FuseBatchNormWithConvPass]) - LinearPassStage = RunPasses([FuseBatchNormWithLinearPass]) + PassStage = RunPasses([FuseBatchNormPass]) bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" def setUp(self): @@ -70,7 +64,7 @@ def test_fp32_conv_batch_norm_fusion(self): ) .export() .to_edge() - .run_passes(self.ConvPassStage) + .run_passes(self.PassStage) .check_count({self.bn_name: 1}) .run_method_and_compare_outputs() ) @@ -85,7 +79,7 @@ def test_q8_conv_batch_norm_fusion(self): .quantize() .export() .to_edge() - .run_passes(self.ConvPassStage) + .run_passes(self.PassStage) .check_count({self.bn_name: 1}) .run_method_and_compare_outputs() ) From 20afaa85a72de646d2cc1a6cacad3fba45a56782 Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 2 Jul 2025 12:55:51 -0700 Subject: [PATCH 7/7] Enable BatchNorm fusion for Linear with bias=False --- backends/xnnpack/_passes/fuse_batch_norm.py | 129 ++++++++++-------- .../test/passes/test_batch_norm_fusion.py | 23 ++-- 2 files changed, 84 insertions(+), 68 deletions(-) diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py index 88a912670a5..a83be194e66 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm.py +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -38,23 +38,17 @@ class FuseBatchNormPass(XNNPACKPass): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph constant_placeholders_to_delete = set() - for node in graph.nodes: + for input_node in graph.nodes: # We want to discover a chain of conv -> batch_norm or linear -> batch_norm. - # Only proceed if the current node is a conv or linear node, and has a single - # user/successor. - is_conv = node.target == exir_ops.edge.aten.convolution.default - is_linear = node.target == exir_ops.edge.aten.linear.default + # Only proceed if the current node is a conv or linear, and has a single user/successor. + is_conv = input_node.target == exir_ops.edge.aten.convolution.default + is_linear = input_node.target == exir_ops.edge.aten.linear.default - if not (is_conv or is_linear): + if not (is_conv or is_linear) or len(input_node.users) != 1: continue - if len(node.users) != 1: - continue - - # Conv or linear op to fuse. - target_op = node - # The single user of the op must be batch_norm. If not, bail. - bn = list(target_op.users.keys())[0] + # The single user of the conv or linear node must be batch_norm. If not, bail. 
+ bn = list(input_node.users.keys())[0] if ( bn.target != exir_ops.edge.aten.native_batch_norm.default and bn.target @@ -62,13 +56,13 @@ def call(self, graph_module: torch.fx.GraphModule): ): continue - if not self.can_fuse(target_op, bn, self.exported_program): + if not self.can_fuse(input_node, bn, self.exported_program): continue self._fuse_ops( graph_module, graph, - target_op, + input_node, bn, is_conv, constant_placeholders_to_delete, @@ -81,38 +75,38 @@ def call(self, graph_module: torch.fx.GraphModule): delete_constant_placeholder(self.exported_program, node) graph_module.recompile() - # To Regenerate metadata and shape information, retrace module. + # To regenerate metadata and shape information, retrace module. graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) @staticmethod def can_fuse( - target_op: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram ) -> bool: """ - Determine whether a batchnorm node can be fused with a preceding conv or linear node. + Determine whether a BatchNorm node can be fused with the preceding convolution or linear node. """ - # All the users of batchnorm node must be getitem ops. batchnorm - # returns a 3-element tuple. Each user must only access the first - # element of the tuple. + # All users of the batch_norm node must be getitem ops. + # batch_norm returns a 3-element tuple. + # Each user must only access the first element of the tuple. if [ (user.target == operator.getitem and user.args[1] == 0) for user in bn.users ].count(False): return False - target_op_weights = target_op.args[1] + input_node_weights = input_node.args[1] bn_weights = bn.args[1] - # Check that the weights for conv or linear and batchnorm are both params. - if not isinstance(target_op_weights, torch.fx.Node) or not isinstance( + # Check that the weights for conv or linear and batch_norm are both params. + if not isinstance(input_node_weights, torch.fx.Node) or not isinstance( bn_weights, torch.fx.Node ): return False if [ - is_param_node(program, node) for node in {target_op_weights, bn_weights} + is_param_node(program, node) for node in {input_node_weights, bn_weights} ].count(False): return False @@ -122,32 +116,45 @@ def _fuse_ops( self, graph_module: torch.fx.GraphModule, graph: torch.fx.Graph, - target_op: torch.fx.Node, + input_node: torch.fx.Node, bn: torch.fx.Node, is_conv: bool, constant_placeholders_to_delete: set, ) -> None: """ - Fuse a BatchNorm into the preceding conv or linear op. - Update the fused op's weight and bias, rewire users of the BatchNorm's output, and remove the BatchNorm node. + Fuse a BatchNorm node into the preceding convolution or linear node. + Update the fused node's weight and bias, rewire users of the BatchNorm output, + and remove the BatchNorm node. """ if is_conv: - assert len(target_op.args) == 9 - else: # Linear path: (input, weight, bias). - assert len(target_op.args) == 3 + assert len(input_node.args) == 9 + has_bias_arg = True + else: + # Otherwise, this is a linear node. + # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias). + assert len(input_node.args) in (2, 3) + has_bias_arg = len(input_node.args) == 3 # Get the weight and bias parameters from the conv or linear op. 
- target_op_weight = get_param_tensor(self.exported_program, target_op.args[1]) - target_op_weight_name = get_tensor_name( - self.exported_program, target_op.args[1] + input_node_weight = get_param_tensor(self.exported_program, input_node.args[1]) + input_node_weight_name = get_tensor_name( + self.exported_program, input_node.args[1] ) - assert target_op_weight is not None + assert input_node_weight is not None - target_op_bias = get_param_tensor(self.exported_program, target_op.args[2]) - target_op_bias_name = get_tensor_name(self.exported_program, target_op.args[2]) + if has_bias_arg: + input_node_bias = get_param_tensor( + self.exported_program, input_node.args[2] + ) + input_node_bias_name = get_tensor_name( + self.exported_program, input_node.args[2] + ) + else: + input_node_bias = None + input_node_bias_name = "" - # Get the parameters from the batchnorm op. + # Get the parameters from the batch_norm op. assert ( bn.target == exir_ops.edge.aten.native_batch_norm.default and len(bn.args) == 8 @@ -169,10 +176,10 @@ def _fuse_ops( # as an arg). eps = bn.args[-1] - # Compute the updated weight and bias after fusing conv or linear op with batchnorm op. + # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op. fuse_args = ( - target_op_weight, - target_op_bias, + input_node_weight, + input_node_bias, running_mean, running_var, eps, @@ -181,23 +188,24 @@ def _fuse_ops( ) if is_conv: - is_transpose = target_op.args[6] + is_transpose = input_node.args[6] fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose) - else: # Linear path. + else: + # Otherwise, this is a linear node. fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args) - fused_weight_name = (target_op_weight_name + "_fused_bn").replace(".", "_") - if target_op_bias_name == "": - fused_bias_name = (target_op_weight_name + "_bias_fused_bn").replace( + fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_") + if input_node_bias_name == "": + fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace( ".", "_" ) else: - fused_bias_name = (target_op_bias_name + "_fused_bn").replace(".", "_") + fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_") - # Modify the graph by updating the weight and bias of conv or linear op + # Modify the graph by updating the weight and bias of the conv or linear op # with the fused weight and bias params, and replacing all the users - # of getitem(batchnorm) with the conv or linear op. - with graph.inserting_before(target_op.args[1]): + # of getitem(batch_norm) with the conv or linear op. + with graph.inserting_before(input_node.args[1]): fused_op_weight_node = create_constant_placeholder( exp_program=self.exported_program, graph=graph_module.graph, @@ -216,17 +224,24 @@ def _fuse_ops( else: fused_op_bias_node = None - # Replace weight and bias with the fused batchnorm values. - args = list(target_op.args) + # Replace the original weight and bias with the fused batch_norm values. + args = list(input_node.args) args[1] = fused_op_weight_node - args[2] = fused_op_bias_node - target_op.args = tuple(args) - # Remove any use of batchnorm from the graph + if has_bias_arg: + # Overwrite original bias with the fused bias. + args[2] = fused_op_bias_node + elif fused_op_bias_node is not None: + # Add the fused bias as a new argument if no bias had originally existed in the input_node. 
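+            # Descriptive sketch of why a fused bias always exists, following the
+            # torch.nn.utils.fusion helpers used above:
+            #   scale      = bn_w / sqrt(running_var + eps)
+            #   fused_w    = w * scale               # broadcast over the output dim
+            #   fused_bias = (b - running_mean) * scale + bn_b
+            # with b treated as zero when the node had no bias, so a bias
+            # argument is appended even for bias=False.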
+ args.append(fused_op_bias_node) + + input_node.args = tuple(args) + + # Remove any use of batch_norm from the graph. for user in bn.users.copy(): assert user.target == operator.getitem - user.replace_all_uses_with(target_op) + user.replace_all_uses_with(input_node) graph.erase_node(user) graph.erase_node(bn) - constant_placeholders_to_delete.update(target_op.args[1:3] + bn.args[1:5]) + constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5]) diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index 1cadb79032f..a095fa236fe 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -41,10 +41,10 @@ def forward(self, x): return self.bn(y) class ModelLinearBN(torch.nn.Module): - def __init__(self, in_features, out_features): + def __init__(self, in_features, out_features, bias=True): super().__init__() op = torch.nn.Linear - self.linear = op(in_features, out_features) + self.linear = op(in_features, out_features, bias=bias) self.bn = torch.nn.BatchNorm1d(out_features) self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats @@ -109,16 +109,17 @@ def forward(self, x): ) def test_fp32_linear_batch_norm_fusion(self): - ( - Tester( - self.ModelLinearBN(2, 2).eval(), - (torch.randn(2, 2),), + for bias in [True, False]: + ( + Tester( + self.ModelLinearBN(2, 2, bias).eval(), + (torch.randn(2, 2),), + ) + .export() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() ) - .export() - .to_edge_transform_and_lower() - .check_count({self.bn_name: 1}) - .run_method_and_compare_outputs() - ) def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): """