Skip to content

Commit c1b22d5

Browse files
committed
Test case rollback and fixes.
1 parent bc5071b commit c1b22d5

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

src/relax/transform/legalize_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class LegalizeMutator : public ExprMutator {
161161
return NullOpt;
162162
}
163163

164-
Expr UpdateVDeviceOutStructInfo(Expr expr, Call& visited_call) {
164+
Expr UpdateVDeviceOutStructInfo(Expr expr, const Call& visited_call) {
165165
static const auto& infer_struct_info_map = Op::GetAttrMap<FInferStructInfo>("FInferStructInfo");
166166
static const Op& call_tir_op = Op::Get("relax.call_tir");
167167
auto* op_node = visited_call->op.as<OpNode>();

src/relax/transform/optimize_to_vdevice_for_scope_change.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ class CollectConsumerDetails : public ExprVisitor {
102102
public:
103103
using ExprVisitor::VisitExpr_;
104104

105-
Map<Expr, Array<Expr>> Collect(
106-
const IRModule& mod, Function func, const Target& target) {
105+
Map<Expr, Array<Expr>> Collect(const IRModule& mod, Function func, const Target& target) {
107106
mod_ = mod;
108107
target_ = target;
109108
VisitExpr(func->body);
@@ -159,7 +158,6 @@ class CollectConsumerDetails : public ExprVisitor {
159158
}
160159

161160
private:
162-
163161
/* Map of each Var consumption by a call node */
164162
Map<Expr, Array<Expr>> consumers;
165163
Map<Expr, Expr> arg_to_binding;
@@ -172,7 +170,8 @@ namespace transform {
172170
Pass OptimizeToVDeviceForScopeChange() {
173171
auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
174172
/* here Target doesn't matter as the consumers we use only to find multiple consumers */
175-
auto consumers = CollectConsumerDetails().Collect(mod, Downcast<Function>(func), Target("opencl"));
173+
auto consumers =
174+
CollectConsumerDetails().Collect(mod, Downcast<Function>(func), Target("opencl"));
176175
auto [pattern, rewriter] = CreatePatterns(consumers);
177176
return RewriteCall(pattern, rewriter, func);
178177
};

src/relax/transform/remove_redundant_assignments.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class RemoveRedundantAssignments : public ExprMutator {
4444
public:
4545
using ExprMutator::VisitExpr_;
4646

47-
IRModule Run(IRModule& mod) {
47+
IRModule Run(const IRModule& mod) {
4848
mod_ = mod;
4949
for (const auto& [gv, func] : mod_->functions) {
5050
if (func->IsInstance<relax::FunctionNode>()) {

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ def concatenate(
10491049
),
10501050
T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
10511051
):
1052-
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
1052+
T.func_attr({"tir.noalias": T.bool(True)})
10531053
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)):
10541054
with T.block("T_concat"):
10551055
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -1069,7 +1069,7 @@ def transpose2(
10691069
rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"),
10701070
T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"),
10711071
):
1072-
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
1072+
T.func_attr({"tir.noalias": T.bool(True)})
10731073
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)):
10741074
with T.block("T_transpose"):
10751075
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -1117,7 +1117,7 @@ def fused_concatenate_transpose2(
11171117
(T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
11181118
),
11191119
):
1120-
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
1120+
T.func_attr({"tir.noalias": T.bool(True)})
11211121
T_concat_handle_intermediate = T.alloc_buffer(
11221122
(T.int64(2), T.int64(4), T.int64(64), T.int64(64))
11231123
)
@@ -1242,7 +1242,7 @@ def reshape(
12421242
A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
12431243
T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"),
12441244
):
1245-
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
1245+
T.func_attr({"tir.noalias": T.bool(True)})
12461246
# with T.block("root"):
12471247
for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
12481248
with T.block("T_reshape"):
@@ -1307,7 +1307,7 @@ def fused_reshape(
13071307
(T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"
13081308
),
13091309
):
1310-
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
1310+
T.func_attr({"tir.noalias": T.bool(True)})
13111311
# with T.block("root"):
13121312
for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)):
13131313
with T.block("T_reshape"):

0 commit comments

Comments
 (0)