
Conversation

@Phoslight
Contributor

The unit test test_rewrite_of_implicit_tuple_with_three_elements() has been failing due to a mismatch between the test IR and the expected pattern structure.

After some investigation, the failure is not caused by an issue in the rewriter implementation, but rather by a mismatch in the return structure of the before() pattern function. The pattern defined in the test expects the final expression to be a tuple, but before() directly returns the unpacked values, causing the pattern match to fail.

To fix the test, we can simply add a tuple to match the pattern. Also, the rewritten function generates a natural indexing order (0 -> 1 -> 2), but the test uses a different order (2 -> 0 -> 1), so we should update the indexing order in the test accordingly.

To reproduce:

python -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements

Thanks,

v = embedded_qkv_tuple[2]
q_embed = embedded_qkv_tuple[0]
k_embed = embedded_qkv_tuple[1]
Contributor

I think this does not impact the structural equality in TVM's IR.

Contributor Author

Thank you for the review. However, before() was rewritten as

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
    gv8: R.Tensor((4096,), dtype="float32") = gv[0]      <<<<<========  order: [0] -> [1] -> [2] here
                                                 ^
    gv9: R.Tensor((4096,), dtype="float32") = gv[1]
    gv10: R.Tensor((4096,), dtype="float32") = gv[2]
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (gv8, gv9, gv10, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

while the original version of expect() was rewritten as:

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
    v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]    <<<<<======== the original order: [2] -> [0] -> [1]
                                                               ^
    q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
    k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

I just ran the test again: if we don't switch the order of these lines, the test fails with a tvm.ir.assert_structural_equal(expected, after) assertion error.

Please let me know if this explanation sounds reasonable to you, or if further clarification is needed.

Contributor

May I ask what version of TVM you use? In the 0.21.dev0 version on my machine, the rewritten code is:

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
    v: R.Tensor((4096,), dtype="float32") = gv[2]
    q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
    k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

So it is structurally equal to the expected output. I have also debugged the underlying C++ code and could not reproduce the problem you mentioned.

Contributor Author

Thank you for the investigation. I'm using the main branch, but the test still fails after switching to the 0.21.dev0 tag. The code you posted is indeed the one generated by expected(), and it is the same as the second snippet I shared earlier, as you may have noticed :).

I was wondering if you could reproduce the test failure mentioned in this PR by running python -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements? If you can't, then that would be a bit unexpected.

But if you can, from what I understand, the issue here may still be that before() produces a natural variable definition order of 0 -> 1 -> 2, while expected() generates 2 -> 0 -> 1.

Please let me know if I misunderstood your concern.

Contributor

Thanks for your explanation, but in my version, the following three IRModules are shown for before, after, and expected.

# from tvm.script import relax as R

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.split(qkv, indices_or_sections=3, axis=0)
    q: R.Tensor((4096,), dtype="float32") = qkv_tuple[0]
    k: R.Tensor((4096,), dtype="float32") = qkv_tuple[1]
    v: R.Tensor((4096,), dtype="float32") = qkv_tuple[2]
    q_embed: R.Tensor((4096,), dtype="float32") = R.call_pure_packed("rotary_embedding", (q,), sinfo_args=(R.Tensor((4096,), dtype="float32"),))
    k_embed: R.Tensor((4096,), dtype="float32") = R.call_pure_packed("rotary_embedding", (k,), sinfo_args=(R.Tensor((4096,), dtype="float32"),))
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

# from tvm.script import relax as R

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
    v: R.Tensor((4096,), dtype="float32") = gv[2]
    q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
    k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

# from tvm.script import relax as R

@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
    v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]
    q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
    k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
    attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
    return attention

Result of running pytest:

(tvm-build-venv) [user1@localhost tvm-git]$ python3 -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements
enabled targets: llvm; cuda; nvptx
pytest marker: 
====================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.11.13, pytest-8.4.0, pluggy-1.6.0 -- /home/user1/miniconda3/envs/tvm-build-venv/bin/python3
cachedir: .pytest_cache
rootdir: /home/user1/Github/tvm-git
configfile: pyproject.toml
collected 1 item                                                                                                                                                                                                                

tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements PASSED

======================================================================================================= 1 passed in 0.14s =======================================================================================================

Contributor Author

Thank you for the detailed information. This seems a bit odd - I'll need some more time to investigate further. I'll reply to this thread if I find anything else.

Contributor Author

@Phoslight Phoslight Jul 16, 2025

Got some time this week to reinvestigate this issue, and it turns out I initially missed the root cause: undefined behavior caused by out-of-bounds iterator access.

The test environment I used earlier was macOS, where the test consistently fails. On Linux, as you pointed out, the test passes.

The root issue is that indices.begin() + (j + 1) can exceed indices.end(), leading to undefined behavior. In this case, indices.size() is 3, but info_vec.size() is 7, and j can be as large as 5, causing indices.begin() + (j + 1) to go out of bounds. In the worst case, I think std::all_of may never terminate.

I verified all tests from the Rewriter PR; they pass on both macOS and Linux.

Hopefully this explains it.

@Phoslight Phoslight changed the title [Test] Fix tuple pattern in unittest tests/python/relax/test_dataflow_rewriter.py [Fix][Relax] Fix potential out-of-bounds access in TupleRewriterNode Jul 16, 2025
Contributor

@ConvolutedDog ConvolutedDog left a comment

Thanks for reinvestigating this and the thorough analysis. I've also observed the same behavior on macOS; the out-of-bounds access fails there due to the compiler's stricter memory checks. The fix works perfectly across both macOS and Linux.

But lint failed because a space was missing (ref: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-lint/detail/PR-18120/2/pipeline/62).

@Phoslight
Contributor Author

Thanks a lot for verifying on macOS. I've fixed the formatting issue and let's wait for the CI results. :)

@ConvolutedDog
Contributor

cc @tqchen

@Hzfengsy Hzfengsy merged commit 5e12a5c into apache:main Jul 19, 2025
10 checks passed
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025