
Commit ecea5b1

Lunderberg authored and junrushao committed
[UnitTest][NVPTX] Avoid cascading failures from CUDA postproc (apache#15136)
Prior to this commit, the tests in `test_tir_transform_inject_ptx_async_copy.py` registered the `"tvm_callback_cuda_postproc"` function during pytest collection, and used a global variable to disable its functionality outside of the tests in this file. This had two major issues. First, if any other test also installed a postproc function, the postproc function required by these NVPTX tests would be overwritten. Second, if one of the NVPTX tests failed, the global variable controlling the postproc function would not be reset, causing any subsequent CUDA-related tests to also fail.

This commit updates these NVPTX tests to conditionally install the postproc function, to de-register it after the test instead of disabling its functionality, and to de-register it regardless of the test result.

This issue was initially found when debugging apache#15103, where a failure in `test_tir_transform_inject_ptx_async_copy.py::test_cp_async_in_if_then_else` caused failures in 32 unrelated tests ([CI link](https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-gpu/detail/PR-15103/7/tests)).
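The fix leans on pytest's yield-fixture teardown, which runs whether the test body passes or raises. Below is a minimal sketch of the register/restore pattern the commit adopts, using a hypothetical REGISTRY dict in place of TVM's global function registry:

    import pytest

    # Hypothetical stand-in for TVM's global function registry.
    REGISTRY = {}


    @pytest.fixture
    def temporary_callback():
        # Save whatever was registered before the test (may be None).
        prev = REGISTRY.get("callback")
        REGISTRY["callback"] = lambda code: code.upper()
        yield
        # Code after `yield` runs even if the test raised, so later
        # tests always see the original registry state.
        if prev is None:
            REGISTRY.pop("callback", None)
        else:
            REGISTRY["callback"] = prev

This is what stops the cascading failures: a failing assertion inside one test can no longer leave the dummy postproc installed for the next test.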
1 parent 5de7598 commit ecea5b1

File tree

1 file changed (+51, -46 lines)


tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py

Lines changed: 51 additions & 46 deletions
@@ -14,11 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import numpy as np
+
 import tvm
 import tvm.testing
 from tvm.script import tir as T
 
+import pytest
+import numpy as np
+
 
 def count_cp_async(stmt):
     num_alloc = [0]
@@ -351,36 +354,54 @@ def test_inject_async_copy_shared_dyn():
     """
 
 
-generated_code = ""
-support_async = True
+@pytest.fixture
+def postproc_if_missing_async_support():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    support_async = major >= 8
+
+    func_name = "tvm_callback_cuda_postproc"
+    prev_postproc = tvm.get_global_func(func_name, allow_missing=True)
+
+    # Store the generated code prior to the post-processing.  This
+    # way, even though the generated code doesn't compile on platforms
+    # that do not support async, the comparison against an expected
+    # output can still be performed.  We cannot use
+    # `mod.get_source()`, as that contains the source after all
+    # post-processing.
+    original_code = None
+
+    def get_original_code():
+        nonlocal original_code
+        return original_code
+
+    @tvm.register_func(func_name, override=True)
+    def tvm_callback_cuda_postproc(code, _):
+        nonlocal original_code
+        original_code = code
+        if support_async:
+            return code
+        else:
+            ret = []
+            for line in code.split("\n"):
+                ret.append(line)
+                ret.append("\n")
+                if line.startswith('extern "C" __global__') and line.endswith("{"):
+                    break
+            ret.append("}")
+            return "".join(ret)
 
+    yield get_original_code
 
-@tvm.register_func
-def tvm_callback_cuda_postproc(code, _):
-    global generated_code
-    global support_async
-    generated_code = code
-    # return a dummy code so that device < sm80 could build correctly
-    if not support_async:
-        ret = ""
-        for line in code.split("\n"):
-            ret += line + "\n"
-            if line.startswith('extern "C" __global__'):
-                break
-        ret += "}"
-        return ret
-    return code
+    # Restore previous postproc func to avoid impacting other tests
+    if prev_postproc is None:
+        tvm._ffi.registry.remove_global_func(func_name)
+    else:
+        tvm.register_func(func_name, prev_postproc, override=True)
 
 
 @tvm.testing.requires_cuda
-def test_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
-
+def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def simple_compute(
         A: T.Buffer((16, 14), "float32"),
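The dummy-code branch above is what lets `tvm.build` succeed on pre-sm80 devices: it keeps everything up to and including the first `extern "C" __global__` kernel signature, then closes the body immediately, so NVCC compiles a trivial kernel instead of unsupported cp.async code. A standalone sketch of that truncation, outside TVM, with a made-up `sample` input:

    def truncate_after_first_kernel(code: str) -> str:
        # Keep every line up to and including the first kernel
        # signature, then close the body right away.
        ret = []
        for line in code.split("\n"):
            ret.append(line)
            ret.append("\n")
            if line.startswith('extern "C" __global__') and line.endswith("{"):
                break
        ret.append("}")
        return "".join(ret)


    sample = (
        "#include <cuda_runtime.h>\n"
        'extern "C" __global__ void main_kernel(float* A) {\n'
        "  // cp.async sequence elided\n"
        "}\n"
    )
    print(truncate_after_first_kernel(sample))
    # Prints the include, the kernel signature, and an empty body.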
@@ -422,22 +443,12 @@ def simple_compute(
     mod = tvm.IRModule.from_expr(simple_compute)
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
         tvm.build(mod, target="cuda")
+    generated_code = postproc_if_missing_async_support()
     assert generated_code == expected_cuda_script
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 @tvm.testing.requires_cuda
-def test_vectorize_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
-
+def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def complex_compute(
         A: T.Buffer((2, 16, 16, 1280), "float16"),
@@ -887,16 +898,10 @@ def complex_compute(
     mod = tvm.IRModule.from_expr(complex_compute)
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
        tvm.build(mod, target="cuda")
+    generated_code = postproc_if_missing_async_support()
     # generated_code must contain " setp.ne.b32 p, %0, 0;"
     assert "setp.ne.b32" in generated_code
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 if __name__ == "__main__":
-    test_inject_async_copy()
-    test_inject_async_copy_shared_dyn()
-    test_cp_async_in_if_then_else()
-    test_vectorize_cp_async_in_if_then_else()
+    tvm.testing.main()
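A subtlety worth noting: at `yield` time no CUDA code has been generated yet, so the fixture yields a getter closure rather than a value; the callback later writes into the fixture's scope via `nonlocal`, and the test reads the capture back only after `tvm.build` has run. A minimal illustration of that closure pattern, independent of TVM:

    def make_capture():
        captured = None

        def callback(value):
            # Invoked later (here, by the build's postproc hook);
            # writes into the enclosing scope, not a module global.
            nonlocal captured
            captured = value

        def get():
            return captured

        return callback, get


    callback, get = make_capture()
    assert get() is None  # nothing captured yet
    callback("generated code")
    assert get() == "generated code"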
