@@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);
142142
143143class StmtSimplifier : public IRMutatorWithAnalyzer {
144144 public:
145- static Stmt Apply (Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
145+ static PrimFunc Apply (PrimFunc func, Analyzer* analyzer,
146+ Optional<SimplifyConfig> config_opt = NullOpt) {
146147 auto config = config_opt.value_or (AttrsWithDefaultValues<arith::SimplifyConfig>());
147148 analyzer->rewrite_simplify .SetEnabledExtensions (config->GetEnabledExtensions ());
148149
149150 std::optional<ControlFlowGraph> touch_pattern = std::nullopt ;
150151 if (config->propagate_knowns_to_prove_conditional ||
151152 config->propagate_knowns_to_simplify_expressions ) {
152- touch_pattern = ControlFlowGraph (stmt );
153+ touch_pattern = ControlFlowGraph (func-> body );
153154 }
154155
155- std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition (stmt);
156+ std::unordered_set<const VarNode*> used_in_buffer_def =
157+ CollectVarsUsedInBufferDefinition (func->body );
156158 StmtSimplifier simplifier (analyzer, config, std::move (touch_pattern),
157159 std::move (used_in_buffer_def));
158- return simplifier (std::move (stmt));
160+ simplifier.MarkBufferMapShapes (func);
161+ func.CopyOnWrite ()->body = simplifier (func->body );
162+ return func;
159163 }
160164
161165 private:
@@ -335,21 +339,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
335339} // namespace arith
336340
337341namespace tir {
338-
339- Stmt Simplify (Stmt stmt, arith::Analyzer* analyzer) {
340- return arith::StmtSimplifier::Apply (stmt, analyzer);
341- }
342-
343342namespace transform {
344343
345344Pass Simplify () {
346345 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
347346 arith::Analyzer analyzer;
348347 auto cfg = ctx->GetConfig <arith::SimplifyConfig>(" tir.Simplify" );
349348
350- auto * n = f.CopyOnWrite ();
351- n->body = arith::StmtSimplifier::Apply (std::move (n->body ), &analyzer, cfg);
352- return f;
349+ return arith::StmtSimplifier::Apply (f, &analyzer, cfg);
353350 };
354351 return CreatePrimFuncPass (pass_func, 0 , " tir.Simplify" , {});
355352}
0 commit comments