[Fix][Relax] Fix potential out-of-bounds access in TupleRewriterNode
#18120
Conversation
The review comment is attached to this diff hunk in the test (the v binding is moved after q_embed and k_embed):

- v = embedded_qkv_tuple[2]
  q_embed = embedded_qkv_tuple[0]
  k_embed = embedded_qkv_tuple[1]
+ v = embedded_qkv_tuple[2]
I think this does not impact the structural equality in TVM's IR.
Thank you for the review. However, before() was rewritten as:
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
gv8: R.Tensor((4096,), dtype="float32") = gv[0] <<<<<======== order: [0] -> [1] -> [2] here
^
gv9: R.Tensor((4096,), dtype="float32") = gv[1]
gv10: R.Tensor((4096,), dtype="float32") = gv[2]
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (gv8, gv9, gv10, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention
while the original version of expect() was rewritten as:
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2] <<<<<======== the original order: [2] -> [0] -> [1]
^
q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention
I just ran the test again: if we don't switch the order of these lines, the test fails with a tvm.ir.assert_structural_equal(expected, after) assertion failure.
Please let me know if this explanation sounds reasonable to you, or if further clarification is needed.
May I ask which version of TVM you are using? In the 0.21.dev0 version on my machine, the rewritten code is:
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = gv[2]
q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention

So it is structurally equal to the expected output. I have also debugged the underlying C++ code and have not reproduced the problem you mentioned.
Thank you for the investigation. I'm using the main branch, but the test still fails after switching to the 0.21.dev0 tag. The code you posted is indeed the one generated by expected(), and it is the same as the second snippet I shared earlier, as you may have noticed :).
I was wondering if you could reproduce the test failure mentioned in this PR by running:

<python> -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements

If you can't, that would be a bit unexpected.
But if you can, from what I understand, the issue here may still be that before() produces a natural variable definition order of 0 -> 1 -> 2, while expected() generates 2 -> 0 -> 1.
Please let me know if I misunderstood your concern.
Thanks for your explanation, but in my version, the following three IRModules are shown for before, after, and expected:
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.split(qkv, indices_or_sections=3, axis=0)
q: R.Tensor((4096,), dtype="float32") = qkv_tuple[0]
k: R.Tensor((4096,), dtype="float32") = qkv_tuple[1]
v: R.Tensor((4096,), dtype="float32") = qkv_tuple[2]
q_embed: R.Tensor((4096,), dtype="float32") = R.call_pure_packed("rotary_embedding", (q,), sinfo_args=(R.Tensor((4096,), dtype="float32"),))
k_embed: R.Tensor((4096,), dtype="float32") = R.call_pure_packed("rotary_embedding", (k,), sinfo_args=(R.Tensor((4096,), dtype="float32"),))
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = gv[2]
q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.call_pure_packed("split_rotary_embedding", (qkv,), sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]
q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
attention: R.Tensor((4096,)) = R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), sinfo_args=(R.Tensor((4096,)),))
return attention

Result of running pytest:
(tvm-build-venv) [user1@localhost tvm-git]$ python3 -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements
enabled targets: llvm; cuda; nvptx
pytest marker:
====================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.11.13, pytest-8.4.0, pluggy-1.6.0 -- /home/user1/miniconda3/envs/tvm-build-venv/bin/python3
cachedir: .pytest_cache
rootdir: /home/user1/Github/tvm-git
configfile: pyproject.toml
collected 1 item
tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements PASSED
======================================================================================================= 1 passed in 0.14s =======================================================================================================
Thank you for the detailed information. This seems a bit odd - I'll need some more time to investigate further. I'll reply to this thread if I find anything else.
Got some time this week to reinvestigate this issue, and it turns out I initially missed the root cause: undefined behavior caused by an out-of-bounds iterator access.
The test environment I used earlier was macOS, where the test consistently fails. On Linux, as you pointed out, the test passes.
The root issue is that indices.begin() + (j + 1) can exceed indices.end(), leading to undefined behavior. In this case, indices.size() is 3, but info_vec.size() is 7, and j can be 5, causing indices.begin() + (j + 1) to go out of bounds. In the worst case, I think std::all_of may never terminate.
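The guard can be sketched as follows. This is a minimal, illustrative example, not TVM's actual code: the function name and the predicate are hypothetical, and only the pattern matters — check the offset against the container size before forming the iterator, since even constructing an iterator past end() is undefined behavior, regardless of whether it is dereferenced.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch of the fix: clamp/guard the offset so that
// indices.begin() + (j + 1) can never pass indices.end(). Without the
// guard, j taken from a larger container (e.g. info_vec with size 7
// while indices has size 3) produces an invalid range for std::all_of.
bool all_nonneg_after(const std::vector<int>& indices, std::size_t j) {
    if (j + 1 > indices.size()) {
        // Out-of-range start: return a well-defined result instead of
        // handing std::all_of an invalid iterator range.
        return true;  // vacuously true over an empty range
    }
    return std::all_of(indices.begin() + (j + 1), indices.end(),
                       [](int v) { return v >= 0; });
}
```

On Linux the unguarded version often happens to "work" because the past-the-end pointer arithmetic silently lands in adjacent memory, while hardened standard libraries (as on macOS here) trap on it — which matches the platform-dependent test results discussed above.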
I verified all tests from the Rewriter PR; they pass on both macOS and Linux.
Hopefully this explains it.
ConvolutedDog left a comment:
Thanks for reinvestigating this and for the thorough analysis. I've also observed the same behavior on macOS: the out-of-bounds access fails there because of the standard library's stricter memory checks. The fix works on both macOS and Linux.
But lint failed because a space was missing (ref: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-lint/detail/PR-18120/2/pipeline/62).
Thanks a lot for verifying on macOS. I've fixed the formatting issue; let's wait for the CI results. :)

cc @tqchen
[Fix][Relax] Fix potential out-of-bounds access in TupleRewriterNode (apache#18120)

* Root cause
* Update

The unit test test_rewrite_of_implicit_tuple_with_three_elements() has been failing due to a mismatch between the test IR and the expected pattern structure. After some investigation, the failure is not caused by an issue in the rewriter implementation, but rather by a mismatch in the return structure of the before() pattern function. The pattern defined in the test expects the final expression to be a tuple, but before() directly uses unpacked values, causing the pattern match to fail.

To fix the test, we can simply add a tuple to match the pattern. Also, the rewritten function generates a natural indexing order (0 -> 1 -> 2), but the test has a different order (2 -> 0 -> 1), so we should update the indexing order in the test accordingly.

To reproduce:

python3 -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements

Thanks,