-
Notifications
You must be signed in to change notification settings - Fork 75
[TEST_DEBUG] Enable test device assert #5395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
3de296d
4df011d
a1f02f2
2409c38
6dfd8ff
a7016dd
4f98c39
5fe75f1
71a3240
4b0ca6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,9 @@ | |
| import torch | ||
| import triton.language as tl | ||
| import triton | ||
| import sys | ||
| import subprocess | ||
| import os | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('cond', [True, False]) | ||
|
|
@@ -10,29 +13,35 @@ | |
| @pytest.mark.parametrize('env_var', [True, False]) | ||
| @pytest.mark.parametrize('jit_flag', [True, False]) | ||
| @pytest.mark.forked | ||
| def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device): | ||
| monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) | ||
| triton.knobs.refresh_knobs() | ||
| torch.zeros([1], dtype=torch.int32, device=device) | ||
|
|
||
| @triton.jit(debug=jit_flag) | ||
| def _kernel(COND: tl.constexpr, MASK: tl.constexpr): | ||
| tl.device_assert(COND, 'test', mask=MASK) | ||
| def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device): | ||
| """Temporary subprocess solution due to: | ||
| https://github.com/pytorch/pytorch/issues/142135""" | ||
|
|
||
| is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) | ||
|
|
||
| kwargs = {} | ||
| if opt_flag is not None: | ||
| kwargs["debug"] = opt_flag | ||
|
|
||
| if not cond and is_debug and mask is not False: | ||
| with pytest.raises(RuntimeError): | ||
| _kernel[(1, )](cond, mask, **kwargs) | ||
| getattr(torch, device).synchronize() | ||
| return | ||
|
|
||
| _kernel[(1, )](cond, mask, **kwargs) | ||
| getattr(torch, device).synchronize() | ||
| should_fail = not cond and is_debug and mask is not False | ||
| kernel_file = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py") | ||
| mask_str = "None" if mask is None else str(mask) | ||
| opt_flag_str = "None" if opt_flag is None else str(opt_flag) | ||
|
|
||
| result = subprocess.run([ | ||
| sys.executable, kernel_file, "device_assert", | ||
| str(cond), mask_str, opt_flag_str, | ||
| str(jit_flag), device, | ||
| str(env_var) | ||
| ], capture_output=True, text=True) | ||
|
|
||
| if should_fail: | ||
| abort_or_runtime_error = ( | ||
| result.returncode == 1 or # RuntimeError | ||
|
||
| result.returncode == -6 # SIGABRT | ||
| ) | ||
| assert abort_or_runtime_error, ( | ||
| f"Expected runtime error or abort signal but got unexpected exit code {result.returncode}. " | ||
| f"stdout: {result.stdout}, stderr: {result.stderr}") | ||
| else: | ||
| assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. " | ||
| f"stdout: {result.stdout}, stderr: {result.stderr}") | ||
|
|
||
|
|
||
| def test_device_assert_barrier(monkeypatch, device): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| """ | ||
| Helper module containing Triton kernels for test_debug.py. | ||
| These kernels are separated so they can be called from subprocesses. | ||
| """ | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
| import sys | ||
| import os | ||
|
|
||
|
|
||
| def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device): | ||
|
|
||
| @triton.jit(debug=jit_flag) | ||
| def _kernel(COND: tl.constexpr, MASK: tl.constexpr): | ||
| tl.device_assert(COND, 'test', mask=MASK) | ||
|
|
||
| kwargs = {} | ||
| if opt_flag is not None: | ||
| kwargs["debug"] = opt_flag | ||
|
|
||
| try: | ||
| _kernel[(1, )](cond, mask, **kwargs) | ||
| getattr(torch, device).synchronize() | ||
| return 0 | ||
| except RuntimeError: | ||
| return 1 | ||
| except Exception as e: | ||
| print(f"Unexpected error: {type(e).__name__}: {e}") | ||
| return 2 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| def parse_bool_or_none(arg_str): | ||
| if arg_str == "None": | ||
| return None | ||
| return arg_str == "True" | ||
|
|
||
| test_type = sys.argv[1] | ||
| if test_type == "device_assert": | ||
| cond = sys.argv[2] == "True" | ||
| mask = parse_bool_or_none(sys.argv[3]) | ||
| opt_flag = parse_bool_or_none(sys.argv[4]) | ||
| jit_flag = sys.argv[5] == "True" | ||
| device = sys.argv[6] | ||
| env_var = sys.argv[7] == "True" | ||
|
|
||
| os.environ["TRITON_DEBUG"] = str(int(env_var)) | ||
| triton.knobs.refresh_knobs() | ||
| exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device) | ||
| sys.exit(exit_code) | ||
|
|
||
| else: | ||
| print(f"Unknown test type: {test_type}") | ||
| sys.exit(3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can set environment variables directly when calling
`subprocess.run`, for example for `TRITON_DEBUG`. This way you can pass fewer parameters into `test_debug_kernels.py`.