diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 5d7e388ef0a..fe81d817491 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -21,9 +21,7 @@ ) from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass @@ -60,7 +58,7 @@ def __init__( ConvertToLinearPass, ConvertToSDPAPass, ConstPropPass, - FuseBatchNormWithConvPass, + FuseBatchNormPass, FuseActivationPass, DecomposeConcatenate, RemoveGetItemPass, diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py new file mode 100644 index 00000000000..a83be194e66 --- /dev/null +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +import torch +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass + +from executorch.backends.xnnpack.utils.utils import ( + get_param_tensor, + get_tensor_name, + is_param_node, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind + +from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights + + +class FuseBatchNormPass(XNNPACKPass): + """ + BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase + memory usage since we serialize new weights to represent the convolution. In most cases, + BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused + with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer. + """ + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + constant_placeholders_to_delete = set() + for input_node in graph.nodes: + # We want to discover a chain of conv -> batch_norm or linear -> batch_norm. + # Only proceed if the current node is a conv or linear, and has a single user/successor. + is_conv = input_node.target == exir_ops.edge.aten.convolution.default + is_linear = input_node.target == exir_ops.edge.aten.linear.default + + if not (is_conv or is_linear) or len(input_node.users) != 1: + continue + + # The single user of the conv or linear node must be batch_norm. If not, bail. 
+ bn = list(input_node.users.keys())[0] + if ( + bn.target != exir_ops.edge.aten.native_batch_norm.default + and bn.target + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default + ): + continue + + if not self.can_fuse(input_node, bn, self.exported_program): + continue + + self._fuse_ops( + graph_module, + graph, + input_node, + bn, + is_conv, + constant_placeholders_to_delete, + ) + + if len(constant_placeholders_to_delete) > 0: + graph_module.graph.eliminate_dead_code() + for node in constant_placeholders_to_delete: + if (node is not None) and (len(node.users) == 0): + delete_constant_placeholder(self.exported_program, node) + + graph_module.recompile() + # To regenerate metadata and shape information, retrace module. + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) + + @staticmethod + def can_fuse( + input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + ) -> bool: + """ + Determine whether a BatchNorm node can be fused with the preceding convolution or linear node. + """ + + # All users of the batch_norm node must be getitem ops. + # batch_norm returns a 3-element tuple. + # Each user must only access the first element of the tuple. + if [ + (user.target == operator.getitem and user.args[1] == 0) for user in bn.users + ].count(False): + return False + + input_node_weights = input_node.args[1] + bn_weights = bn.args[1] + + # Check that the weights for conv or linear and batch_norm are both params. + if not isinstance(input_node_weights, torch.fx.Node) or not isinstance( + bn_weights, torch.fx.Node + ): + return False + + if [ + is_param_node(program, node) for node in {input_node_weights, bn_weights} + ].count(False): + return False + + return True + + def _fuse_ops( + self, + graph_module: torch.fx.GraphModule, + graph: torch.fx.Graph, + input_node: torch.fx.Node, + bn: torch.fx.Node, + is_conv: bool, + constant_placeholders_to_delete: set, + ) -> None: + """ + Fuse a BatchNorm node into the preceding convolution or linear node. + Update the fused node's weight and bias, rewire users of the BatchNorm output, + and remove the BatchNorm node. + """ + + if is_conv: + assert len(input_node.args) == 9 + has_bias_arg = True + else: + # Otherwise, this is a linear node. + # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias). + assert len(input_node.args) in (2, 3) + has_bias_arg = len(input_node.args) == 3 + + # Get the weight and bias parameters from the conv or linear op. + input_node_weight = get_param_tensor(self.exported_program, input_node.args[1]) + input_node_weight_name = get_tensor_name( + self.exported_program, input_node.args[1] + ) + assert input_node_weight is not None + + if has_bias_arg: + input_node_bias = get_param_tensor( + self.exported_program, input_node.args[2] + ) + input_node_bias_name = get_tensor_name( + self.exported_program, input_node.args[2] + ) + else: + input_node_bias = None + input_node_bias_name = "" + + # Get the parameters from the batch_norm op. 
+ assert ( + bn.target == exir_ops.edge.aten.native_batch_norm.default + and len(bn.args) == 8 + ) or ( + bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default + and len(bn.args) == 7 + ) + bn_weight = get_param_tensor(self.exported_program, bn.args[1]) + bn_bias = get_param_tensor(self.exported_program, bn.args[2]) + + running_mean = get_param_tensor(self.exported_program, bn.args[3]) + assert running_mean is not None + + running_var = get_param_tensor(self.exported_program, bn.args[4]) + assert running_var is not None + + # args[7] for native_batch_norm, but args[6] for + # _native_batch_norm_legit_no_training (which doesn't have training + # as an arg). + eps = bn.args[-1] + + # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op. + fuse_args = ( + input_node_weight, + input_node_bias, + running_mean, + running_var, + eps, + bn_weight, + bn_bias, + ) + + if is_conv: + is_transpose = input_node.args[6] + fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose) + else: + # Otherwise, this is a linear node. + fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args) + + fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_") + if input_node_bias_name == "": + fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace( + ".", "_" + ) + else: + fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_") + + # Modify the graph by updating the weight and bias of the conv or linear op + # with the fused weight and bias params, and replacing all the users + # of getitem(batch_norm) with the conv or linear op. + with graph.inserting_before(input_node.args[1]): + fused_op_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_weight_name, + data=fused_weight, + ) + if fused_bias is not None: + fused_op_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_bias_name, + data=fused_bias, + ) + else: + fused_op_bias_node = None + + # Replace the original weight and bias with the fused batch_norm values. + args = list(input_node.args) + args[1] = fused_op_weight_node + + if has_bias_arg: + # Overwrite original bias with the fused bias. + args[2] = fused_op_bias_node + elif fused_op_bias_node is not None: + # Add the fused bias as a new argument if no bias had originally existed in the input_node. + args.append(fused_op_bias_node) + + input_node.args = tuple(args) + + # Remove any use of batch_norm from the graph. + for user in bn.users.copy(): + assert user.target == operator.getitem + user.replace_all_uses_with(input_node) + graph.erase_node(user) + + graph.erase_node(bn) + constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5]) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py deleted file mode 100644 index 6f31fe698ba..00000000000 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import operator - -import torch -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - delete_constant_placeholder, -) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass - -from executorch.backends.xnnpack.utils.utils import ( - get_param_tensor, - get_tensor_name, - is_param_node, -) -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult -from torch.export.graph_signature import InputKind - -from torch.nn.utils.fusion import fuse_conv_bn_weights - - -class FuseBatchNormWithConvPass(XNNPACKPass): - """ - Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase - memory usage since we serialize new weights to represent the convolution. In most cases, - Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused - with the previous convolution - """ - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - constant_placeholders_to_delete = set() - for conv in graph.nodes: - # We want to discover a chain of conv -> batch_norm. - # Only proceed if the current node is a conv node, and has a single - # user/successor. - if ( - conv.target != exir_ops.edge.aten.convolution.default - or len(conv.users) != 1 - ): - continue - - # The single user of conv op must be batch_norm. If not, bail. - bn = list(conv.users.keys())[0] - if ( - bn.target != exir_ops.edge.aten.native_batch_norm.default - and bn.target - != exir_ops.edge.aten._native_batch_norm_legit_no_training.default - ): - continue - - if not self.can_fuse(conv, bn, self.exported_program): - continue - - # Get the parameters from conv op - assert len(conv.args) == 9 - - conv_weight = get_param_tensor(self.exported_program, conv.args[1]) - conv_weight_name = get_tensor_name(self.exported_program, conv.args[1]) - assert conv_weight is not None - - conv_bias = get_param_tensor(self.exported_program, conv.args[2]) - conv_bias_name = get_tensor_name(self.exported_program, conv.args[2]) - - # Get the parameters from the batchnorm op - assert ( - bn.target == exir_ops.edge.aten.native_batch_norm.default - and len(bn.args) == 8 - ) or ( - bn.target - == exir_ops.edge.aten._native_batch_norm_legit_no_training.default - and len(bn.args) == 7 - ) - bn_weight = get_param_tensor(self.exported_program, bn.args[1]) - bn_bias = get_param_tensor(self.exported_program, bn.args[2]) - - running_mean = get_param_tensor(self.exported_program, bn.args[3]) - assert running_mean is not None - - running_var = get_param_tensor(self.exported_program, bn.args[4]) - assert running_var is not None - - # args[7] for native_batch_norm, but args[6] for - # _native_batch_norm_legit_no_training (which doesn't have training - # as an arg) - eps = bn.args[-1] - - is_transpose = conv.args[6] - # Compute the updated weight and bias after fusing conv op - # with batchnorm op. 
- fused_weight, fused_bias = fuse_conv_bn_weights( - conv_weight, - conv_bias, - running_mean, - running_var, - eps, - bn_weight, - bn_bias, - is_transpose, - ) - fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_") - if conv_bias_name == "": - fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace( - ".", "_" - ) - else: - fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_") - - # Modify the graph by updating the weight and bias of conv op - # with the fused weight and bias params, and replacing all the users - # of getitem(batchnorm) with the conv op. - with graph.inserting_before(conv.args[1]): - fused_conv_weight_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_weight_name, - data=fused_weight, - ) - if fused_bias is not None: - fused_conv_bias_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=fused_bias_name, - data=fused_bias, - ) - else: - fused_conv_bias_node = None - - conv.args = ( - conv.args[0], - fused_conv_weight_node, - fused_conv_bias_node, - *conv.args[3:], - ) - - # Remove any use of batchnorm from the graph - for user in bn.users.copy(): - assert user.target == operator.getitem - user.replace_all_uses_with(conv) - graph.erase_node(user) - - graph.erase_node(bn) - constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5]) - - if len(constant_placeholders_to_delete) > 0: - graph_module.graph.eliminate_dead_code() - for node in constant_placeholders_to_delete: - if (node is not None) and (len(node.users) == 0): - delete_constant_placeholder(self.exported_program, node) - - graph_module.recompile() - # To Regenerate meta data and shape information, retrace module - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) - - @staticmethod - def can_fuse( - conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram - ) -> bool: - """ - Determine whether a batch norm node can be fused with a preceding conv node. - """ - - # All the users of batchnorm node must be getitem ops. batchnorm - # returns a 3-element tuple. Each user must only access the first - # element of the tuple. 
- if [ - (user.target == operator.getitem and user.args[1] == 0) for user in bn.users - ].count(False): - return False - - conv_weights = conv.args[1] - bn_weights = bn.args[1] - - # Check that the weights for conv and batchnorm are both params - if not isinstance(conv_weights, torch.fx.Node) or not isinstance( - bn_weights, torch.fx.Node - ): - return False - - if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count( - False - ): - return False - - return True diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py index 23acfbfb8c4..4659ea05a0f 100644 --- a/backends/xnnpack/partition/config/node_configs.py +++ b/backends/xnnpack/partition/config/node_configs.py @@ -9,9 +9,7 @@ from typing import List, Optional import torch -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack.partition.config.xnnpack_config import ( ConfigPrecisionType, XNNPartitionerConfig, @@ -35,20 +33,20 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False bn = node - conv = node.all_input_nodes[0] + input_node = node.all_input_nodes[0] - if conv.op != "call_function": + if input_node.op != "call_function": return False - conv_name = format_target_name(conv.target.__name__) # pyre-ignore + input_name = format_target_name(input_node.target.__name__) # pyre-ignore - if conv_name not in ["convolution.default"]: - why(node, f"Invalid conv target {conv_name}") + if input_name not in ["convolution.default", "linear.default"]: + why(node, f"Invalid input target {input_name.split('.')[0]}") return False - can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep) + can_fuse = FuseBatchNormPass.can_fuse(input_node, bn, ep) if not can_fuse: - why(node, "BatchNorm cannot be fused with Convolution") + why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}") return False return True diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index 70c93c3751b..a095fa236fe 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -8,14 +8,12 @@ from typing import Tuple import torch -from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( - FuseBatchNormWithConvPass, -) +from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack.test.tester import RunPasses, Tester class TestBatchNormFusion(unittest.TestCase): - PassStage = RunPasses([FuseBatchNormWithConvPass]) + PassStage = RunPasses([FuseBatchNormPass]) bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" def setUp(self): @@ -42,7 +40,22 @@ def forward(self, x): y = y + y return self.bn(y) - def test_fp32_batch_norm_fusion(self): + class ModelLinearBN(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + op = torch.nn.Linear + self.linear = op(in_features, out_features, bias=bias) + self.bn = torch.nn.BatchNorm1d(out_features) + self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats + + def forward(self, x): + y = self.linear(x) + y = self.bn(y) + y = self.linear(y) + y = y + y + return self.bn(y) + + def test_fp32_conv_batch_norm_fusion(self): for transpose in 
[False, True]: ( Tester( @@ -56,7 +69,7 @@ def test_fp32_batch_norm_fusion(self): .run_method_and_compare_outputs() ) - def test_q8_batch_norm_fusion(self): + def test_q8_conv_batch_norm_fusion(self): for transpose in [False, True]: ( Tester( @@ -71,7 +84,7 @@ def test_q8_batch_norm_fusion(self): .run_method_and_compare_outputs() ) - def test_fp32_batch_norm_no_fusion_doesnt_partition(self): + def test_fp32_conv_batch_norm_no_fusion_doesnt_partition(self): """ We do not currently support standalone batch norms (i.e. batch norms that are not fused with a conv). This is planned, but until implemented, this test ensures @@ -94,3 +107,38 @@ def forward(self, x): .partition() .check_count({self.bn_name: 1}) ) + + def test_fp32_linear_batch_norm_fusion(self): + for bias in [True, False]: + ( + Tester( + self.ModelLinearBN(2, 2, bias).eval(), + (torch.randn(2, 2),), + ) + .export() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) + + def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): + """ + We do not currently support standalone batch norms (i.e. batch norms that are + not fused with a linear). This is planned, but until implemented, this test ensures + that we do not partition the standalone batch norm and then fail to lower. + """ + + class BN(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(2) + + def forward(self, x): + return self.bn(x) + + ( + Tester(BN(), (torch.randn(2, 2),)) + .export() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 1}) + )
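
For reference, the updated weight and bias that fuse_conv_bn_weights and fuse_linear_bn_weights return are the standard inference-time batch-norm fold: with per-output-channel scale s = bn_weight / sqrt(running_var + eps), the fused weight is the original weight scaled by s along its output-channel dimension, and the fused bias is (bias - running_mean) * s + bn_bias. The short sketch below checks the linear case end to end. It is illustrative only (module sizes and tensor names are made up, not taken from this change) and uses the same argument order the pass assembles in fuse_args:

    import torch
    from torch.nn.utils.fusion import fuse_linear_bn_weights

    torch.manual_seed(0)
    linear = torch.nn.Linear(4, 4)
    bn = torch.nn.BatchNorm1d(4)
    with torch.no_grad():
        bn(torch.randn(32, 4) * 3 + 1)  # populate running_mean / running_var
    linear.eval()
    bn.eval()

    with torch.no_grad():
        # Argument order: (weight, bias, running_mean, running_var, eps, bn_weight, bn_bias)
        fused_w, fused_b = fuse_linear_bn_weights(
            linear.weight, linear.bias,
            bn.running_mean, bn.running_var, bn.eps,
            bn.weight, bn.bias,
        )
        fused = torch.nn.Linear(4, 4)
        fused.weight, fused.bias = fused_w, fused_b  # both returned as nn.Parameter

        x = torch.randn(8, 4)
        # Folding BN into the preceding linear layer is exact at inference time.
        assert torch.allclose(bn(linear(x)), fused(x), atol=1e-5)

The conv case differs only in that fuse_conv_bn_weights additionally takes the transpose flag (conv.args[6]) so the per-channel scale is broadcast over the correct weight dimension for transposed convolutions.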