 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import numpy as np
+
 import tvm
 import tvm.testing
 from tvm.script import tir as T
 
+import pytest
+import numpy as np
+
 
 def count_cp_async(stmt):
     num_alloc = [0]
@@ -351,36 +354,54 @@ def test_inject_async_copy_shared_dyn():
 """
 
 
-generated_code = ""
-support_async = True
+@pytest.fixture
+def postproc_if_missing_async_support():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    support_async = major >= 8
+
+    func_name = "tvm_callback_cuda_postproc"
+    prev_postproc = tvm.get_global_func(func_name, allow_missing=True)
+
+    # Store the generated code prior to the post-processing. This
+    # way, even though the generated code doesn't compile on platforms
+    # that do not support async, the comparison against an expected
+    # output can still be performed. We cannot use
+    # `mod.get_source()`, as that contains the source after all
+    # post-processing.
+    original_code = None
+
+    def get_original_code():
+        nonlocal original_code
+        return original_code
+
+    @tvm.register_func(func_name, override=True)
+    def tvm_callback_cuda_postproc(code, _):
+        nonlocal original_code
+        original_code = code
+        if support_async:
+            return code
+        else:
+            ret = []
+            for line in code.split("\n"):
+                ret.append(line)
+                ret.append("\n")
+                if line.startswith('extern "C" __global__') and line.endswith("{"):
+                    break
+            ret.append("}")
+            return "".join(ret)
 
+    yield get_original_code
 
-@tvm.register_func
-def tvm_callback_cuda_postproc(code, _):
-    global generated_code
-    global support_async
-    generated_code = code
-    # return a dummy code so that device < sm80 could build correctly
-    if not support_async:
-        ret = ""
-        for line in code.split("\n"):
-            ret += line + "\n"
-            if line.startswith('extern "C" __global__'):
-                break
-        ret += "}"
-        return ret
-    return code
+    # Restore previous postproc func to avoid impacting other tests
+    if prev_postproc is None:
+        tvm._ffi.registry.remove_global_func(func_name)
+    else:
+        tvm.register_func(func_name, prev_postproc, override=True)
 
 
 @tvm.testing.requires_cuda
-def test_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
-
+def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def simple_compute(
         A: T.Buffer((16, 14), "float32"),
@@ -422,22 +443,12 @@ def simple_compute(
     mod = tvm.IRModule.from_expr(simple_compute)
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
         tvm.build(mod, target="cuda")
+    generated_code = postproc_if_missing_async_support()
     assert generated_code == expected_cuda_script
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 @tvm.testing.requires_cuda
-def test_vectorize_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
-
+def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def complex_compute(
         A: T.Buffer((2, 16, 16, 1280), "float16"),
@@ -887,16 +898,10 @@ def complex_compute(
     mod = tvm.IRModule.from_expr(complex_compute)
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
         tvm.build(mod, target="cuda")
+    generated_code = postproc_if_missing_async_support()
     # generated_code must contain " setp.ne.b32 p, %0, 0;"
     assert "setp.ne.b32" in generated_code
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 if __name__ == "__main__":
-    test_inject_async_copy()
-    test_inject_async_copy_shared_dyn()
-    test_cp_async_in_if_then_else()
-    test_vectorize_cp_async_in_if_then_else()
+    tvm.testing.main()
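
For context, here is a minimal standalone sketch (not taken from the diff above) of what the fallback branch of the new `tvm_callback_cuda_postproc` produces on devices below sm80: it keeps everything up to and including the first `extern "C" __global__ ... {` line and closes the kernel immediately, so nvcc still gets something it can compile while the untouched source is what the tests compare against. The helper name `truncate_after_first_kernel` and the example CUDA string are illustrative only.

# Illustration only: mimic the pre-sm80 fallback of the postproc callback.
def truncate_after_first_kernel(code: str) -> str:
    ret = []
    for line in code.split("\n"):
        ret.append(line)
        ret.append("\n")
        # Stop after the opening line of the first kernel definition.
        if line.startswith('extern "C" __global__') and line.endswith("{"):
            break
    # Close the now-empty kernel body so the dummy source still compiles.
    ret.append("}")
    return "".join(ret)


example = (
    "#include <cuda_fp16.h>\n"
    'extern "C" __global__ void main_kernel(float* A, float* B) {\n'
    "  // ... cp.async instructions that require sm80 or newer ...\n"
    "}\n"
)
print(truncate_after_first_kernel(example))
# Prints the include line, the kernel signature, and a closing brace:
# an empty but compilable dummy kernel.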