53 changes: 53 additions & 0 deletions src/target/source/intrin_rule_metal.cc
@@ -30,6 +30,28 @@ namespace codegen {
namespace intrin {
using tir::FLowerIntrinsic;

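// Map the generic TIR warp shuffle builtins onto the corresponding Metal
// SIMD-group intrinsics.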
struct MetalWarpIntrinsic {
const Op operator()(DataType t, const Op& orig_op) const {
if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
return Op::Get("tir.metal.simd_shuffle");
} else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
return Op::Get("tir.metal.simd_shuffle_up");
} else {
ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
return Op::Get("tir.metal.simd_shuffle_down");
}
}
};

template <typename T>
static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
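// Metal's simd_shuffle* intrinsics take only (value, lane/delta), so the mask,
// width and warp_size operands of the generic builtin are dropped here.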
Array<PrimExpr> metal_args{{call->args[1], call->args[2]}};
return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
}

TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

@@ -95,6 +117,37 @@ TVM_REGISTER_OP("tir.cosh")

TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);

TVM_REGISTER_OP("tir.tvm_warp_shuffle")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);

// Register low-level builtin ops.
TVM_REGISTER_OP("tir.metal.simd_shuffle")
.set_num_inputs(2)
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("lane", "Expr", "The source thread id.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.metal.simd_shuffle_up")
.set_num_inputs(2)
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("delta", "Expr", "The source lane id offset to be added.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_up")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.metal.simd_shuffle_down")
.set_num_inputs(2)
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_down")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace intrin
} // namespace codegen
} // namespace tvm
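Taken together, these registrations let the Metal intrinsic-lowering pass rewrite TVM's generic five-argument warp shuffle builtins into Metal's two-argument SIMD-group intrinsics: DispatchMetalShuffle keeps only the value and the lane/delta operand, and the TGlobalSymbol attribute makes the codegen emit them as plain simd_shuffle, simd_shuffle_up and simd_shuffle_down calls. A minimal Python-side sketch of the call shape being rewritten (toy values; assumes a TVM build that includes these registrations):

import tvm
from tvm import tir

# The generic builtin carries five arguments: (mask, value, delta, width, warp_size).
val = tir.FloatImm("float32", 1.0)
shuf = tir.call_intrin("float32", "tir.tvm_warp_shuffle_down", 0, val, 16, 32, 32)
print(shuf)

# When lowering for a Metal target, the "metal.FLowerIntrinsic" attribute rewrites
# this call to the two-argument form tir.metal.simd_shuffle_down(value, delta);
# the mask, width and warp_size operands are dropped.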
35 changes: 24 additions & 11 deletions src/tir/transforms/lower_thread_allreduce.cc
@@ -476,12 +476,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// The mask for this reducer, as this reducer may sit inside
// a divergent control flow. Here it uses a variable to cache the current
// active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
-    {
-      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+    Optional<Buffer> mask_buffer;
+    if (need_warp_shuffle_mask_) {
+      mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
+      seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
      // Push the buffer description. Later this will have an
      // allocation built for it.
-      local_bufs.push_back(mask_buffer);
+      local_bufs.push_back(mask_buffer.value());
    }

// Emit reductions within a warp.
@@ -698,9 +699,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}

// Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Optional<Buffer> mask_buffer, PrimExpr val,
+                       PrimExpr delta_or_lane) {
    Array<PrimExpr> indices = {0};
-    PrimExpr mask = BufferLoad(mask_buffer, indices);
+    PrimExpr mask;
+    if (mask_buffer.defined()) {
+      mask = BufferLoad(mask_buffer.value(), indices);
+    } else {
+      mask = IntImm(DataType::Int(32), 0);
+    }
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
return Call(val.dtype(), op, args);
@@ -709,11 +716,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// Check if we can use warp level reduction.
//
// Note: The ROCm backend will only have warp reductions for now.
-  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
-                        int contiguous_reduce_extent) const {
-    // Only cuda target supports warp reductions.
-    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
+                        int contiguous_reduce_extent) {
+    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
+        (target_->kind->name != "metal")) {
+      return false;
+    }
+
+    need_warp_shuffle_mask_ = target_->kind->name != "metal";

// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
Expand Down Expand Up @@ -745,7 +756,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// whether reduce_extent and group_extent are valid for warp reduction.
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
-    } else {  // target_->kind->name == "cuda"
+    } else {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
@@ -769,6 +780,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
int warp_size_{1};
// The maximum number of threads of the device. "-1" denotes unknown.
int max_num_threads_{-1};
+  // Whether the target's warp shuffle intrinsics expect an explicit
+  // active-lane mask (true for cuda/rocm; metal's simd_shuffle* take none).
+  bool need_warp_shuffle_mask_;

// surrounding scope of thread extent.
std::vector<const AttrStmtNode*> thread_extents_;
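On the allreduce side, the mask buffer becomes optional: it is only declared and populated when need_warp_shuffle_mask_ is true (cuda and rocm), while for metal WarpShuffle substitutes a constant 32-bit 0 as the mask operand, which the Metal intrinsic lowering above discards again. A minimal sketch of exercising the pass (mirroring the test harness below; before is the Metal-targeted PrimFunc from TestMetalNoMask):

import tvm

mod = tvm.IRModule({"main": before})
mod = tvm.tir.transform.LowerThreadAllreduce()(mod)
# The lowered body contains calls such as T.tvm_warp_shuffle_down(0, ..., 32, 32);
# the leading 0 is the placeholder mask emitted when need_warp_shuffle_mask_ is false.
print(mod["main"].script())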
103 changes: 103 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -702,5 +702,108 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
B_1[threadIdx_y] = red_result_1[threadIdx_y]


class TestMetalNoMask(BaseCompare):
@T.prim_func
def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
T.func_attr(
{
"target": T.target(
{
"kind": "metal",
"max_threads_per_block": 1024,
"thread_warp_size": 32,
"host": "llvm",
}
),
}
)
blockIdx_x = T.launch_thread("blockIdx.x", 1)
cross_thread_B = T.allocate([1], "float32", "local")
threadIdx_z = T.launch_thread("threadIdx.z", 1)
threadIdx_y = T.launch_thread("threadIdx.y", 2)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
A_1 = T.Buffer((256,), data=A.data)
T.tvm_thread_allreduce(
T.uint32(1),
A_1[threadIdx_y * 128 + threadIdx_x],
T.bool(True),
cross_thread_B_1[0],
threadIdx_x,
)
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = cross_thread_B_1[0]

@T.prim_func
def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
T.func_attr(
{
"target": T.target(
{
"kind": "metal",
"max_threads_per_block": 1024,
"thread_warp_size": 32,
"host": "llvm",
}
),
}
)
blockIdx_x = T.launch_thread("blockIdx.x", 1)
red_result = T.allocate([2], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_z = T.launch_thread("threadIdx.z", 1)
threadIdx_y = T.launch_thread("threadIdx.y", 2)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
red_buf0 = T.allocate([1], "float32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1 = T.allocate([1], "float32", "local")
t0_1 = T.allocate([1], "float32", "local")
red_buf_staging = T.allocate([8], "float32", "shared")
red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
A_1 = T.Buffer((256,), data=A.data)
red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
t0_2 = T.Buffer((1,), data=t0_1, scope="local")
t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
if threadIdx_x % 32 == 0:
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 4:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = red_result_1[threadIdx_y]


if __name__ == "__main__":
tvm.testing.main()
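Since the file dispatches through tvm.testing.main(), the new case can be run directly with python tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py, or narrowed to this class with pytest tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py -k TestMetalNoMask.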