Skip to content

Commit 82a10c9

Browse files
committed
Add length check
1 parent 41d0f97 commit 82a10c9

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,13 @@ std::tuple<Tensor, Tensor> compute(
8686
aoti_torch_get_size(targets.get(), 0, &targets_size);
8787
AOTI_TORCH_CHECK(targets_size == logits_size);
8888

89-
// TORCH_CHECK(
90-
// blank >= 0 && blank < logits.size(-1),
91-
// "blank must be within [0, logits.shape[-1])");
89+
AOTI_TORCH_CHECK(
90+
blank >= 0 && blank < logits.size(-1),
91+
"blank must be within [0, logits.shape[-1])");
9292

93+
// "Max" is not ABI stable yet, but no tests check
94+
// for this error behavior, so it's okay to merge in for now.
95+
//
9396
// TORCH_CHECK(
9497
// logits.size(1) == at::max(logit_lengths).item().toInt(),
9598
// "input length mismatch");

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ std::tuple<Tensor, Tensor> compute(
8484
TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(targets.get(), 0, &targets_size));
8585
AOTI_TORCH_CHECK(targets_size == logits_size);
8686

87-
// TORCH_CHECK(
88-
// blank >= 0 && blank < logits.size(-1),
89-
// "blank must be within [0, logits.shape[-1])");
87+
AOTI_TORCH_CHECK(
88+
blank >= 0 && blank < logits.size(-1),
89+
"blank must be within [0, logits.shape[-1])");
9090

9191
// TORCH_CHECK(
9292
// logits.size(1) == at::max(logit_lengths).item().toInt(),

0 commit comments

Comments
 (0)