From c0149942a85b2233ff1eecc1301927948d9e1874 Mon Sep 17 00:00:00 2001 From: kim hyun gyu Date: Tue, 29 Jul 2025 19:45:21 +0900 Subject: [PATCH 1/2] [Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion This PR introduces an operator fusion for the common `conv2d` followed by `reshape`, `add`, and `relu` sequence, commonly found in deep learning models (e.g., convolution + bias + activation pattern in PyTorch). This optimization significantly improves performance and efficiency by reducing overhead and optimizing memory usage. 1. **Performance Improvement:** * **Reduced Kernel Launch Overhead:** Previously, `conv2d`, `reshape`, `add`, and `relu` each required separate kernel calls. By fusing these four operations into a single, unified DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`), the overhead from multiple kernel launches is significantly reduced. This is evident from `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all operations are handled by a single `execute` call. * **Decreased Memory Bandwidth Consumption:** Intermediate results of individual operations (e.g., `conv_out`, `bias_add`) traditionally required costly memory write-backs and reads. Fusion allows these intermediate values to be processed directly in registers or cache, reducing unnecessary memory accesses, and thus decreasing memory bandwidth usage and overall execution time. 2. **Increased Efficiency:** * **Leveraging Compiler Optimizations:** By utilizing TVM's `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends (like DNNL). This ensures that common patterns from frontends like PyTorch are automatically recognized within the TVM graph and mapped to high-performance fused kernels provided by libraries like DNNL. * **Simplified IR Module:** Compilers' Intermediate Representation (IR) becomes less complex as multiple operation nodes are condensed into a single composite node. This simplification enhances efficiency in subsequent optimization and code generation stages. This fusion is achieved through a two-stage transformation within the TVM Relax framework: 1. **Pattern Recognition and Composite Function Creation (`FuseConv2dReshapeAddRelu` Pass):** * The `FuseConv2dReshapeAddRelu` class, registered as a `tvm.transform.module_pass`, transforms the `IRModule`. * The `_conv2d_reshape_add_relu_pattern()` helper function defines the specific sequence: `conv2d` -> `reshape` (applied to bias) -> `add` -> `relu` using TVM's Declarative Pattern Language (DPL). This includes matching input tensors (`data`, `weight`, `bias`, `shape`) using `wildcard()` and identifying operation sequence with `is_op()`. * The `relax.transform.FuseOpsByPattern` pass identifies this pattern in the input `IRModule`. Upon detection, the operation sequence is encapsulated into a new Relax function with `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}` attributes, marking it as a logical "composite" unit. 2. **Composite Function Merging and Codegen Attribute Assignment (`MergeCompositeFunctions` Pass):** * Following the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`. * This pass identifies functions marked with the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute. This `Codegen` attribute indicates that the composite operation should be offloaded to a specific TVM backend, such as DNNL. * Consequently, during graph execution, the fused function with the `Codegen` attribute will be mapped and executed by an optimized, single DNNL kernel, for instance, `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`). This implementation successfully enables the fusion of the `conv2d + reshape + add + relu` pattern. This ensures that common convolution + bias + activation patterns originating from frontends like PyTorch are now fully optimized and executed as a single, highly efficient DNNL kernel within TVM. --- To verify this fusion, you can directly run the specific test case: python tests/python/relax/test_conv2d_reshape_add_relu.py --- .lesshst | 1 + python/tvm/relax/transform/__init__.py | 3 + .../transform/fuse_conv2d_reshape_add_relu.py | 115 ++++++++++++++++++ .../relax/test_conv2d_reshape_add_relu.py | 75 ++++++++++++ 4 files changed, 194 insertions(+) create mode 100644 .lesshst create mode 100644 python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py create mode 100644 tests/python/relax/test_conv2d_reshape_add_relu.py diff --git a/.lesshst b/.lesshst new file mode 100644 index 000000000000..4d1c30b7a584 --- /dev/null +++ b/.lesshst @@ -0,0 +1 @@ +.less-history-file: diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 724921e5fee7..833810e70e0d 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -97,5 +97,8 @@ from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D from .remove_redundant_reshape import RemoveRedundantReshape +# Import the specific fusion pass for Conv2d-Reshape-Add-ReLU. +from .fuse_conv2d_reshape_add_relu import FuseConv2dReshapeAddRelu + # Import to register the legalization functions. from . import legalize_ops diff --git a/python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py b/python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py new file mode 100644 index 000000000000..f5abc860a51b --- /dev/null +++ b/python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module provides a TVM Relax pass for fusing Conv2d-Reshape-Add-ReLU pattern.""" + +import tvm +from tvm import IRModule, relax +from tvm.relax.dpl.pattern import is_op, wildcard + +# Define a TVM module pass for fusing specific operations. +# @tvm.transform.module_pass decorates a class to turn it into a TVM IRModule pass. +# opt_level=0 means this pass can be run at any optimization level. +# name="FuseConv2dReshapeAddRelu" gives a descriptive name to the pass. + + +@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu") +class FuseConv2dReshapeAddRelu: + """A Relax pass that fuses the Conv2d-Reshape-Add-ReLU pattern into a composite function.""" + + # The main transformation method that applies the pass to an IRModule. + # mod: The input IRModule to be transformed. + # _ctx: PassContext (unused in this specific pass but required by the decorator). + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Transforms the input IRModule by applying the Conv2d-Reshape-Add-ReLU fusion. + + Parameters + ---------- + mod : IRModule + The input IRModule to be transformed. + _ctx : tvm.transform.PassContext + The pass context (unused in this specific pass but required by the decorator). + + Returns + ------- + IRModule + The transformed IRModule with the fused pattern. + """ + # Apply the FuseOpsByPattern transformation. + # This pass identifies specific operator patterns in the IRModule + # and fuses them into a single composite function. + mod = relax.transform.FuseOpsByPattern( + # Define the patterns to fuse. It's a list of tuples: + # ("composite_function_name", pattern_root, annotations, check_function) + # "dnnl.conv2d_reshape_add_relu" is the name given to the fused operation, + # indicating it's suitable for DNNL backend. + [("dnnl.conv2d_reshape_add_relu", *_conv2d_reshape_add_relu_pattern())], + # bind_constants=False means that constants in the pattern (like shapes) + # are not treated as part of the pattern to be matched, allowing for more flexibility. + bind_constants=False, + )(mod) + + # Return the transformed IRModule. + return mod + + +# Helper function to define the operator fusion pattern for Conv2d-Reshape-Add-ReLU. +# This function uses TVM's declarative pattern language (DPL). +def _conv2d_reshape_add_relu_pattern(): + # Define wildcard placeholders for the input tensors. + # 'wildcard()' matches any Relax expression. + data = wildcard() + weight = wildcard() + bias = wildcard() + shape = wildcard() # Wildcard for the target shape of the reshape operation + + # Define the sequence of operations in the pattern: + # 1. Convolution (relax.nn.conv2d) + # varg_default_wildcard=True means that any variadic arguments (like strides, padding) + # will also be matched by wildcards, making the pattern more general. + conv_out = is_op("relax.nn.conv2d")(data, weight, varg_default_wildcard=True) + # 2. Reshape (relax.reshape) + # This matches a reshape operation applied to the 'bias' tensor with any 'shape'. + reshaped_bias = is_op("relax.reshape")(bias, shape) + # 3. Addition (relax.add) + # This matches an add operation where 'conv_out' and 'reshaped_bias' are inputs. + add_out = is_op("relax.add")(conv_out, reshaped_bias) + # 4. ReLU (relax.nn.relu) + # This matches a ReLU operation applied to the output of the add operation. + relu_out = is_op("relax.nn.relu")(add_out) + + # Define annotations for the pattern. + # These map internal names (keys) to the matched Relax expressions (values). + # This is useful for debugging and for custom check functions. + annotations = { + "conv_out": conv_out, + "reshaped_bias": reshaped_bias, + "add_out": add_out, + "relu_out": relu_out, + } + + # Define a custom check function for the pattern. + # This function is executed after a potential match is found. + # It can be used to add more specific conditions for the fusion. + # In this case, 'return True' means it always matches if the structure is found. + def _check(_context): + """A check function for the pattern (currently always returns True).""" + return True + + # Return the root of the pattern, the annotations, and the check function. + # The 'relu_out' is the final output of the sequence being matched. + return relu_out, annotations, _check diff --git a/tests/python/relax/test_conv2d_reshape_add_relu.py b/tests/python/relax/test_conv2d_reshape_add_relu.py new file mode 100644 index 000000000000..164c7160a21b --- /dev/null +++ b/tests/python/relax/test_conv2d_reshape_add_relu.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +from tvm.relax.transform import FuseConv2dReshapeAddRelu +from tvm.script import relax as R + + +def test_transform_pass(): + + # Define the initial IRModule + @tvm.script.ir_module + class TestModule: + @R.function + def main( + data: R.Tensor((1, 3, 224, 224), dtype="float32"), + weight: R.Tensor((64, 3, 3, 3), dtype="float32"), + bias: R.Tensor((64,), dtype="float32"), + ): + with R.dataflow(): + conv_out = R.nn.conv2d(data, weight) + bias_reshaped = R.reshape(bias, [1, 64, 1, 1]) + bias_add = R.add(conv_out, bias_reshaped) + relu_out = R.nn.relu(bias_add) + R.output(relu_out) + return relu_out + + print(TestModule) + + # Step 1: Apply the FuseConv2dReshapeAddRelu pass + # This pass identifies the fusion pattern (conv2d-reshape-add-relu) + # and encapsulates it into a new Relax function with "Composite" attribute. + fused_mod = FuseConv2dReshapeAddRelu()(TestModule) + print("=== IR after Step 1 (FuseConv2dReshapeAddRelu) ===") + print(fused_mod) + + # Step 2: Apply Sequential passes including MergeCompositeFunctions + # MergeCompositeFunctions takes functions marked with "Composite" + # and transforms them into functions with a "Codegen" attribute, + # indicating they should be offloaded to an external backend (e.g., DNNL). + final_mod = tvm.ir.transform.Sequential( + [ + relax.transform.FuseConv2dReshapeAddRelu(), + relax.transform.MergeCompositeFunctions(), + ] + )(TestModule) + + print("=== IR after Final Fusion (Sequential Passes) ===") + print(final_mod) + + # Check attributes of functions in the final module + # This helps confirm if "Codegen" attribute was successfully added to the fused function. + print("=== Function Attributes in Final IR ===") + for name, func in final_mod.functions.items(): + if hasattr(func, "attrs") and func.attrs: + print(f"Function {name} attributes:", func.attrs) + + +if __name__ == "__main__": + test_transform_pass() From b2404e79c0ce924805312a210d102d8081be1a58 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Wed, 30 Jul 2025 14:37:40 +0900 Subject: [PATCH 2/2] [Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion This PR introduces an operator fusion for the common conv2d followed by reshape, add, and relu sequence, commonly found in deep learning models (e.g., convolution + bias + activation pattern in PyTorch). This optimization significantly improves performance and efficiency by reducing overhead and optimizing memory usage. Performance Improvement: Reduced Kernel Launch Overhead: Previously, conv2d, reshape, add, and relu each required separate kernel calls. By fusing these four operations into a single, unified DNNL kernel (e.g., dnnl_fused_conv2d_bias_relu), the overhead from multiple kernel launches is significantly reduced. This is evident from src/runtime/contrib/dnnl/dnnl.cc:154-158, where all operations are handled by a single execute call. Decreased Memory Bandwidth Consumption: Intermediate results of individual operations (e.g., conv_out, bias_add) traditionally required costly memory write-backs and reads. Fusion allows these intermediate values to be processed directly in registers or cache, reducing unnecessary memory accesses, and thus decreasing memory bandwidth usage and overall execution time. Increased Efficiency: Leveraging Compiler Optimizations: By utilizing TVM's FuseOpsByPattern and MergeCompositeFunctions passes, this change generates a composite operation optimized for specific backends (like DNNL). This ensures that common patterns from frontends like PyTorch are automatically recognized within the TVM graph and mapped to high-performance fused kernels provided by libraries like DNNL. Simplified IR Module: Compilers' Intermediate Representation (IR) becomes less complex as multiple operation nodes are condensed into a single composite node. This simplification enhances efficiency in subsequent optimization and code generation stages. This fusion is achieved through a two-stage transformation within the TVM Relax framework: Pattern Recognition and Composite Function Creation (FuseConv2dReshapeAddRelu Pass): The FuseConv2dReshapeAddRelu class, registered as a tvm.transform.module_pass, transforms the IRModule. The _conv2d_reshape_add_relu_pattern() helper function defines the specific sequence: conv2d -> reshape (applied to bias) -> add -> relu using TVM's Declarative Pattern Language (DPL). This includes matching input tensors (data, weight, bias, shape) using wildcard() and identifying operation sequence with is_op(). The relax.transform.FuseOpsByPattern pass identifies this pattern in the input IRModule. Upon detection, the operation sequence is encapsulated into a new Relax function with {"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True} attributes, marking it as a logical "composite" unit. Composite Function Merging and Codegen Attribute Assignment (MergeCompositeFunctions Pass): Following the FuseConv2dReshapeAddRelu pass, the MergeCompositeFunctions pass is applied via tvm.ir.transform.Sequential. This pass identifies functions marked with the Composite attribute and transforms them into external functions bearing the {"Codegen": "dnnl"} attribute. This Codegen attribute indicates that the composite operation should be offloaded to a specific TVM backend, such as DNNL. Consequently, during graph execution, the fused function with the Codegen attribute will be mapped and executed by an optimized, single DNNL kernel, for instance, dnnl_fused_conv2d_bias_relu (defined in src/runtime/contrib/dnnl/dnnl.cc:199-207). This implementation successfully enables the fusion of the conv2d + reshape + add + relu pattern. This ensures that common convolution + bias + activation patterns originating from frontends like PyTorch are now fully optimized and executed as a single, highly efficient DNNL kernel within TVM. To verify this fusion, you can directly run the specific test case: python tests/python/relax/test_conv2d_reshape_add_relu.py --- .lesshst | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .lesshst diff --git a/.lesshst b/.lesshst deleted file mode 100644 index 4d1c30b7a584..000000000000 --- a/.lesshst +++ /dev/null @@ -1 +0,0 @@ -.less-history-file: