3 changes: 3 additions & 0 deletions python/tvm/relax/transform/__init__.py
@@ -97,5 +97,8 @@
from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D
from .remove_redundant_reshape import RemoveRedundantReshape

# Import the specific fusion pass for Conv2d-Reshape-Add-ReLU.
from .fuse_conv2d_reshape_add_relu import FuseConv2dReshapeAddRelu

# Import to register the legalization functions.
from . import legalize_ops
115 changes: 115 additions & 0 deletions python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module provides a TVM Relax pass for fusing Conv2d-Reshape-Add-ReLU pattern."""

import tvm
from tvm import IRModule, relax
from tvm.relax.dpl.pattern import is_op, wildcard

# Define a TVM module pass for fusing specific operations.
# @tvm.transform.module_pass decorates a class to turn it into a TVM IRModule pass.
# opt_level=0 means this pass can be run at any optimization level.
# name="FuseConv2dReshapeAddRelu" gives a descriptive name to the pass.


@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
class FuseConv2dReshapeAddRelu:
yongwww (Member) commented:
I am wondering if transform.FuseOps will fuse them, I guess it might work

Author replied:
@yongwww
Excellent point! However, after checking the actual implementation, I've confirmed that the generic FuseOps cannot handle this specific pattern.

Summary

The generic relax.transform.FuseOps pass is currently unable to fuse the common conv2d + bias + activation pattern when a model is imported from PyTorch. The root cause is that the PyTorch frontend generates a conv2d -> reshape -> add sequence for the bias term, which the existing pattern matcher used by FuseOps does not recognize. This leaves a critical, common pattern unoptimized.

The Pattern Generated by the PyTorch Frontend

When handling a torch.nn.Conv2d layer with bias=True, the PyTorch frontend consistently generates a reshape + add pattern for the bias (a sketch of the emitted IR follows the references below). This is not specific to Conv2d and is standard behavior for other convolution types as well:

Conv1d: See test_frontend_from_exported_program.py:1752-1753

Conv2d: See test_frontend_from_fx.py:269-270

Conv3d: See test_frontend_from_exported_program.py:3822-3823
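
As a concrete illustration, the emitted IR is roughly equivalent to the following TVMScript. This is a sketch, not verbatim frontend output; the module name FrontendBiasPattern and the shapes (which mirror the Conv2d(3, 6, 3) example further below) are chosen for illustration only.

import tvm
from tvm.script import relax as R


@tvm.script.ir_module
class FrontendBiasPattern:
    @R.function
    def main(
        data: R.Tensor((1, 3, 10, 10), dtype="float32"),
        weight: R.Tensor((6, 3, 3, 3), dtype="float32"),
        bias: R.Tensor((6,), dtype="float32"),
    ):
        with R.dataflow():
            # The bias is not added directly: it is first reshaped so that it
            # broadcasts over the NCHW conv2d output.
            conv_out = R.nn.conv2d(data, weight)
            bias_reshaped = R.reshape(bias, [1, 6, 1, 1])
            bias_add = R.add(conv_out, bias_reshaped)
            relu_out = R.nn.relu(bias_add)
            R.output(relu_out)
        return relu_out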

Limitation of TVM's Current Pattern Matching

The helper designed to fuse bias and activation, make_fused_bias_activation_pattern, is defined in pattern.py:1179-1181. It currently matches only a plain relax.add immediately following the convolution; it cannot see past the reshape inserted by the frontend, so the sequence never matches.
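
For context, a rough sketch of the shape that the generic conv2d + bias + activation pattern effectively requires (an illustration built with the public DPL helpers, not the actual TVM source of make_fused_bias_activation_pattern):

from tvm.relax.dpl.pattern import is_op, wildcard

# The bias must be a direct argument of relax.add; there is no slot for the
# frontend's intermediate reshape between conv2d and add.
data, weight, bias = wildcard(), wildcard(), wildcard()
conv_out = is_op("relax.nn.conv2d")(data, weight)
add_out = is_op("relax.add")(conv_out, bias)   # expects the bias directly
relu_out = is_op("relax.nn.relu")(add_out)     # conv2d -> add -> relu only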

Proof by Code: A Reproducible Example

The following test case demonstrates that FuseOps fails to fuse this pattern.

import torch
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# 1. PyTorch Conv2d model with bias and ReLU
class Conv2dWithBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# 2. Trace and convert the model to TVM Relax IR
model = Conv2dWithBias()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]
mod = from_fx(graph_model, input_info)

print("### Original Relax IR (Before FuseOps):")
print(mod)

# 3. Apply the generic FuseOps pass
fused_mod = relax.transform.FuseOps()(mod)

print("\n### Relax IR After Applying FuseOps:")
print(fused_mod)

Execution Results

Converted IR (Before FuseOps): A sequence of four separate operations is generated: conv2d → reshape → add → relu.

IR After FuseOps: The IR remains completely unchanged, confirming that the fusion failed.

This failure is a direct result of the pattern in pattern.py:1179-1181 matching only relax.add and not the reshape + add sequence.

Conclusion and Proposal

The generic FuseOps pass cannot handle this frontend-specific pattern, leaving a common PyTorch model structure (conv2d + bias + relu) unoptimized.

Therefore, a specialized pass like FuseConv2dReshapeAddRelu is needed to correctly identify and fuse this pattern. It bridges the gap between the IR the PyTorch frontend generates and TVM's pattern-based fusion, so that the common conv2d + bias + relu block can be offloaded (e.g., to DNNL) instead of being left unfused.

Member commented:
Maybe we could extend FuseOps to handle this - that way, other cases could benefit as well.
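
For illustration, one way such an extension might look is to register two pattern variants, one of which tolerates the frontend's reshape. This is a sketch only; the helper name _conv2d_bias_relu_variants and the pattern names are hypothetical, not existing TVM API.

from tvm.relax.dpl.pattern import is_op, wildcard

def _conv2d_bias_relu_variants():
    """Return (name, pattern_root) pairs covering both bias layouts (sketch)."""
    variants = []
    for name, with_reshape in [("conv2d_add_relu", False),
                               ("conv2d_reshape_add_relu", True)]:
        data, weight, bias, shape = wildcard(), wildcard(), wildcard(), wildcard()
        conv_out = is_op("relax.nn.conv2d")(data, weight)
        # Optionally allow the bias to sit behind a reshape, as the PyTorch
        # frontend emits it.
        bias_like = is_op("relax.reshape")(bias, shape) if with_reshape else bias
        add_out = is_op("relax.add")(conv_out, bias_like)
        variants.append((name, is_op("relax.nn.relu")(add_out)))
    return variants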

Author replied:
Just a moment, I'll get to it.

"""A Relax pass that fuses the Conv2d-Reshape-Add-ReLU pattern into a composite function."""

# The main transformation method that applies the pass to an IRModule.
# mod: The input IRModule to be transformed.
# _ctx: PassContext (unused in this specific pass but required by the decorator).
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""Transforms the input IRModule by applying the Conv2d-Reshape-Add-ReLU fusion.

Parameters
----------
mod : IRModule
The input IRModule to be transformed.
_ctx : tvm.transform.PassContext
The pass context (unused in this specific pass but required by the decorator).

Returns
-------
IRModule
The transformed IRModule with the fused pattern.
"""
        # Apply the FuseOpsByPattern transformation.
        # This pass identifies specific operator patterns in the IRModule
        # and fuses them into a single composite function.
        mod = relax.transform.FuseOpsByPattern(
            # Define the patterns to fuse. It's a list of tuples:
            # ("composite_function_name", pattern_root, annotations, check_function).
            # "dnnl.conv2d_reshape_add_relu" is the name given to the fused function,
            # indicating it is suitable for the DNNL backend.
            [("dnnl.conv2d_reshape_add_relu", *_conv2d_reshape_add_relu_pattern())],
            # bind_constants=False keeps matched constants as parameters of the
            # composite function rather than binding them into its body.
            bind_constants=False,
        )(mod)

        # Return the transformed IRModule.
        return mod


# Helper function to define the operator fusion pattern for Conv2d-Reshape-Add-ReLU.
# This function uses TVM's declarative pattern language (DPL).
def _conv2d_reshape_add_relu_pattern():
    # Define wildcard placeholders for the input tensors.
    # 'wildcard()' matches any Relax expression.
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    shape = wildcard()  # Wildcard for the target shape of the reshape operation

    # Define the sequence of operations in the pattern:
    # 1. Convolution (relax.nn.conv2d)
    # varg_default_wildcard=True lets any call arguments not listed in the pattern
    # default to wildcards, keeping the pattern general.
    conv_out = is_op("relax.nn.conv2d")(data, weight, varg_default_wildcard=True)
    # 2. Reshape (relax.reshape)
    # This matches a reshape operation applied to the 'bias' tensor with any 'shape'.
    reshaped_bias = is_op("relax.reshape")(bias, shape)
    # 3. Addition (relax.add)
    # This matches an add operation where 'conv_out' and 'reshaped_bias' are inputs.
    add_out = is_op("relax.add")(conv_out, reshaped_bias)
    # 4. ReLU (relax.nn.relu)
    # This matches a ReLU operation applied to the output of the add operation.
    relu_out = is_op("relax.nn.relu")(add_out)

    # Define annotations for the pattern.
    # These map internal names (keys) to the matched Relax expressions (values).
    # This is useful for debugging and for custom check functions.
    annotations = {
        "conv_out": conv_out,
        "reshaped_bias": reshaped_bias,
        "add_out": add_out,
        "relu_out": relu_out,
    }

    # Define a custom check function for the pattern.
    # This function is executed after a potential match is found.
    # It can be used to add more specific conditions for the fusion.
    # In this case, 'return True' means it always matches if the structure is found.
    def _check(_context):
        """A check function for the pattern (currently always returns True)."""
        return True

    # Return the root of the pattern, the annotations, and the check function.
    # The 'relu_out' is the final output of the sequence being matched.
    return relu_out, annotations, _check
75 changes: 75 additions & 0 deletions tests/python/relax/test_conv2d_reshape_add_relu.py
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relax
from tvm.relax.transform import FuseConv2dReshapeAddRelu
from tvm.script import relax as R


def test_transform_pass():
    # Define the initial IRModule
    @tvm.script.ir_module
    class TestModule:
        @R.function
        def main(
            data: R.Tensor((1, 3, 224, 224), dtype="float32"),
            weight: R.Tensor((64, 3, 3, 3), dtype="float32"),
            bias: R.Tensor((64,), dtype="float32"),
        ):
            with R.dataflow():
                conv_out = R.nn.conv2d(data, weight)
                bias_reshaped = R.reshape(bias, [1, 64, 1, 1])
                bias_add = R.add(conv_out, bias_reshaped)
                relu_out = R.nn.relu(bias_add)
                R.output(relu_out)
            return relu_out

    print(TestModule)

    # Step 1: Apply the FuseConv2dReshapeAddRelu pass.
    # This pass identifies the fusion pattern (conv2d-reshape-add-relu)
    # and encapsulates it into a new Relax function with the "Composite" attribute.
    fused_mod = FuseConv2dReshapeAddRelu()(TestModule)
    print("=== IR after Step 1 (FuseConv2dReshapeAddRelu) ===")
    print(fused_mod)

    # Step 2: Apply Sequential passes including MergeCompositeFunctions.
    # MergeCompositeFunctions takes functions marked with "Composite"
    # and transforms them into functions with a "Codegen" attribute,
    # indicating they should be offloaded to an external backend (e.g., DNNL).
    final_mod = tvm.ir.transform.Sequential(
        [
            relax.transform.FuseConv2dReshapeAddRelu(),
            relax.transform.MergeCompositeFunctions(),
        ]
    )(TestModule)

    print("=== IR after Final Fusion (Sequential Passes) ===")
    print(final_mod)

    # Check attributes of functions in the final module.
    # This helps confirm that the "Codegen" attribute was added to the fused function.
    print("=== Function Attributes in Final IR ===")
    for name, func in final_mod.functions.items():
        if hasattr(func, "attrs") and func.attrs:
            print(f"Function {name} attributes:", func.attrs)


if __name__ == "__main__":
test_transform_pass()