From 8191760970c1faecf7bc9eafbab1cbf60d83f0dd Mon Sep 17 00:00:00 2001 From: xlchen <792512955@qq.com> Date: Fri, 25 Jul 2025 15:18:56 +0800 Subject: [PATCH 1/2] [Relax] Fix issue in fuse concat ops by pattern --- src/relax/transform/fuse_ops.cc | 9 ++- .../test_transform_fuse_ops_by_pattern.py | 73 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 0828e9c81c21..211430985ac3 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -427,10 +427,17 @@ class FunctionCreator : public ExprMutator { } for (const Expr& arg : call->args) { - CheckDefAndUpdateParam(arg); if (GetStructInfoAs(arg) != nullptr) { // The argument is fully referenced. Thus we remove it from the mapping. partially_used_tuple_params_.erase(arg.get()); + const Tuple& tup_args = Downcast(arg); + for (const Expr& tup_arg : tup_args->fields) { + CheckDefAndUpdateParam(tup_arg); + ICHECK(GetStructInfoAs(tup_arg) == nullptr); + } + } + else { + CheckDefAndUpdateParam(arg); } } } diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 999879e75184..2219c01ccb1e 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -26,6 +26,7 @@ is_tuple_get_item, make_fused_bias_activation_pattern, wildcard, + is_tuple, ) from tvm.relax.transform import PatternCheckContext from tvm.script import ir as I @@ -1348,5 +1349,77 @@ def local_func( tvm.ir.assert_structural_equal(Expected, After) +def test_concat(): + @R.function + def func(x: R.Tensor((10,), "float32"), y: R.Tensor((10,), "float32")): + R.func_attr({"global_symbol": "main"}) + with R.dataflow(): + lv = R.abs(x) + lv1 = R.abs(y) + lv2 = R.concat([lv, lv1]) + gv = R.nn.relu(lv2) + R.output(gv) + return gv + + @I.ir_module + class Expected1: + @R.function(private=True) + def fused_relax_abs_relax_abs_relax_concat( + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((20,), dtype="float32"): + R.func_attr({"Composite": "x.concat_abs_abs", "Primitive": True}) + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = R.abs(x) + lv1: R.Tensor((10,), dtype="float32") = R.abs(y) + gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((20,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor( + (20,), dtype="float32" + ) = Expected1.fused_relax_abs_relax_abs_relax_concat(x, y) + gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv) + R.output(gv) + return gv + + mod = tvm.IRModule({"main": func}) + inp = is_tuple([is_op("relax.abs")(wildcard()), is_op("relax.abs")(wildcard())]) + pat_clip = is_op("relax.concat")(inp) + + check(mod, [("x.concat_abs_abs", pat_clip)], Expected1) + + @I.ir_module + class Expected2: + @R.function(private=True) + def fused_relax_concat( + lv: R.Tensor((10,), dtype="float32"), lv1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((20,), dtype="float32"): + R.func_attr({"Composite": "x.concat", "Primitive": True}) + with R.dataflow(): + gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((20,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = R.abs(x) + lv1: R.Tensor((10,), dtype="float32") = R.abs(y) + lv_1: R.Tensor((20,), dtype="float32") = Expected2.fused_relax_concat(lv, lv1) + gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv_1) + R.output(gv) + return gv + + pat_clip = is_op("relax.concat")(wildcard()) + check(mod, [("x.concat", pat_clip)], Expected2) + + if __name__ == "__main__": pytest.main([__file__]) From 5263e503324dce27afe86a4f7787e54fd7182f95 Mon Sep 17 00:00:00 2001 From: xlchen <792512955@qq.com> Date: Fri, 25 Jul 2025 16:29:20 +0800 Subject: [PATCH 2/2] fix lint --- src/relax/transform/fuse_ops.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 211430985ac3..434a7e7653a5 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -435,8 +435,7 @@ class FunctionCreator : public ExprMutator { CheckDefAndUpdateParam(tup_arg); ICHECK(GetStructInfoAs(tup_arg) == nullptr); } - } - else { + } else { CheckDefAndUpdateParam(arg); } }