Skip to content

Commit 73a62f6

Browse files
authored
[TIR] Preserve AllocateNode::annotations (#15242)
Prior to this commit, some lowering passes would erroneously strip out the annotations from `Allocate` nodes. This commit updates these passes to preserve the annotations where present.
1 parent 2f7c097 commit 73a62f6

File tree

6 files changed

+13
-9
lines changed

6 files changed

+13
-9
lines changed

src/tir/transforms/inject_double_buffer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ class DoubleBufferInjector : public StmtExprMutator {
119119
Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
120120
ICHECK(entry.loop != nullptr);
121121
auto& alloc_nest = loop_allocs_[entry.loop];
122-
alloc_nest.emplace_back(
123-
Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
122+
alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition,
123+
Evaluate(0), op->annotations));
124124
Stmt body = op->body;
125125
if (auto ptr = body.as<DeclBufferNode>()) {
126126
auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);

src/tir/transforms/ir_utils.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ class IRConvertSSA final : public StmtExprMutator {
335335
ScopedRedefine redefine(this, v);
336336
Stmt stmt = StmtExprMutator::VisitStmt_(op);
337337
op = stmt.as<AllocateNode>();
338-
return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
338+
return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body,
339+
op->annotations);
339340
} else {
340341
defined_.insert(v.get());
341342
return StmtExprMutator::VisitStmt_(op);

src/tir/transforms/lower_custom_datatypes.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
9797
allocate = stmt.as<AllocateNode>();
9898

9999
return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
100-
allocate->body);
100+
allocate->body, allocate->annotations);
101101
} else {
102102
return StmtExprMutator::VisitStmt_(allocate);
103103
}

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop
5353
// use volatile access to shared buffer.
5454
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
5555
}
56-
return Allocate(remapped, op->dtype, op->extents, op->condition, body);
56+
return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
5757
}
5858
return StmtExprMutator::VisitStmt_(op);
5959
}

src/tir/transforms/lower_warp_memory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
249249
alloc_size = warp_group_ * factor;
250250

251251
return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
252-
op->condition, this->VisitStmt(op->body));
252+
op->condition, this->VisitStmt(op->body), op->annotations);
253253
}
254254

255255
protected:

src/tir/transforms/update_pointer_storage_scope.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/tir/transform.h>
3030

3131
#include <unordered_map>
32+
#include <utility>
3233

3334
#include "../../runtime/thread_storage_scope.h"
3435
#include "ir_utils.h"
@@ -59,9 +60,11 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) {
5960
}
6061

6162
Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) {
62-
auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
63-
return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition),
64-
StmtExprMutator::VisitStmt(op->body));
63+
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
64+
if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) {
65+
node.CopyOnWrite()->buffer_var = it->second;
66+
}
67+
return std::move(node);
6568
}
6669

6770
template <typename Node>

0 commit comments

Comments
 (0)