
Commit 92247c5

[Bug] Fix moe_sum signature (#18440)
Signed-off-by: Bill Nell <[email protected]>
1 parent 0c15c2e commit 92247c5

2 files changed: +19 -1 lines changed

csrc/moe/torch_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   // Calculate the result of moe by summing up the partial results
   // from all selected experts.
-  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.def("moe_sum(Tensor input, Tensor! output) -> ()");
   m.impl("moe_sum", torch::kCUDA, &moe_sum);
 
   // Aligning the number of tokens to be processed by each expert such
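For context on the one-line change above: in a PyTorch operator schema, the "Tensor!" alias annotation declares that the kernel mutates that argument in place. moe_sum reads input and writes the reduced result into output, so the old schema described the mutation backwards; the fix moves the "!" to output. Below is a minimal sketch of the corrected annotation style using a toy op registered from Python (the demo namespace and sum_into op are hypothetical illustrations, not vLLM code; it assumes a recent PyTorch with torch.library).

import torch
from torch.library import Library

# Hypothetical namespace and op, for illustration only.
demo_lib = Library("demo", "DEF")

# Corrected annotation style: "output" carries the "!" because it is the
# tensor the implementation writes into; "input" is read-only.
demo_lib.define("sum_into(Tensor input, Tensor! output) -> ()")

def sum_into_impl(input: torch.Tensor, output: torch.Tensor) -> None:
    # Reduce over dim=1 (the topk/expert dimension in moe_sum's case)
    # directly into the preallocated output tensor.
    torch.sum(input, dim=1, out=output)

demo_lib.impl("sum_into", sum_into_impl, "CompositeExplicitAutograd")

x = torch.randn(4, 2, 8)
out = torch.empty(4, 8)
torch.ops.demo.sum_into(x, out)
torch.testing.assert_close(out, x.sum(dim=1))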

tests/kernels/moe/test_moe.py

Lines changed: 18 additions & 0 deletions
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
     opcheck(torch.ops._moe_C.moe_align_block_size,
             (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
              num_tokens_post_pad))
+
+
+@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
+    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
+    actual = torch.empty((m, k), device="cuda", dtype=dtype)
+
+    expected = input.sum(dim=1)
+    torch.ops._moe_C.moe_sum(input, actual)
+
+    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
+
+    opcheck(torch.ops._moe_C.moe_sum, (input, actual))
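The added test compares moe_sum against a plain sum over dim=1 and then runs opcheck, whose schema test verifies that the mutations the kernel actually performs match the mutations the schema declares, which is exactly the mismatch the old "Tensor! input" annotation introduced. A minimal sketch of that check, assuming PyTorch 2.4+ with torch.library.opcheck (the demo namespace and the sum_into / sum_into_bad ops are hypothetical, not vLLM code):

import torch
from torch.library import Library, opcheck

lib = Library("demo", "DEF")

# Schema matching the fixed moe_sum: "output" is declared mutable.
lib.define("sum_into(Tensor input, Tensor! output) -> ()")
# Schema analogous to the pre-fix moe_sum: "input" is declared mutable,
# even though the implementation below actually writes "output".
lib.define("sum_into_bad(Tensor! input, Tensor output) -> ()")

def _impl(input: torch.Tensor, output: torch.Tensor) -> None:
    # Reduce over dim=1 into the preallocated output tensor.
    torch.sum(input, dim=1, out=output)

lib.impl("sum_into", _impl, "CompositeExplicitAutograd")
lib.impl("sum_into_bad", _impl, "CompositeExplicitAutograd")

x = torch.randn(4, 2, 8)
out = torch.empty(4, 8)

# Passes: the observed in-place write to "out" matches the declared schema.
opcheck(torch.ops.demo.sum_into.default, (x, out), test_utils="test_schema")

# Raises: the kernel mutates "output", which this schema declares read-only.
# opcheck(torch.ops.demo.sum_into_bad.default, (x, out), test_utils="test_schema")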
