Skip to content
Merged
49 changes: 29 additions & 20 deletions python/test/unit/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import torch
import triton.language as tl
import triton
import sys
import subprocess
import os


@pytest.mark.parametrize('cond', [True, False])
Expand All @@ -10,29 +13,35 @@
@pytest.mark.parametrize('env_var', [True, False])
@pytest.mark.parametrize('jit_flag', [True, False])
@pytest.mark.forked
def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device):
monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
triton.knobs.refresh_knobs()
torch.zeros([1], dtype=torch.int32, device=device)

@triton.jit(debug=jit_flag)
def _kernel(COND: tl.constexpr, MASK: tl.constexpr):
tl.device_assert(COND, 'test', mask=MASK)
def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device):
"""Temporary subprocess solution due to:
https://github.com/pytorch/pytorch/issues/142135"""

is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag)

kwargs = {}
if opt_flag is not None:
kwargs["debug"] = opt_flag

if not cond and is_debug and mask is not False:
with pytest.raises(RuntimeError):
_kernel[(1, )](cond, mask, **kwargs)
getattr(torch, device).synchronize()
return

_kernel[(1, )](cond, mask, **kwargs)
getattr(torch, device).synchronize()
should_fail = not cond and is_debug and mask is not False
kernel_file = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py")
mask_str = "None" if mask is None else str(mask)
opt_flag_str = "None" if opt_flag is None else str(opt_flag)

result = subprocess.run([
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can set environment variables directly when calling subprocess.run, for example for TRITON_DEBUG. This way you can pass fewer parameters into test_debug_kernels.py.

sys.executable, kernel_file, "device_assert",
str(cond), mask_str, opt_flag_str,
str(jit_flag), device,
str(env_var)
], capture_output=True, text=True)

if should_fail:
abort_or_runtime_error = (
result.returncode == 1 or # RuntimeError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what cases do we get a RuntimeError? At first glance, I would expect every failure to be a SIGABRT.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is needed for other devices, you could keep the old implementation for non-XPU devices (`if not is_xpu()`) and add a separate branch for XPU containing what you wrote; this might simplify the code. This is optional, since I don't know whether it will result in less code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe easier to resolve merge conflicts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving the old implementation in place would result in many redundant parts.
I agree, however, that RuntimeError should not be part of the assert condition for XPU, as only SIGABRTs are raised for us.
I'll create a separate branch for that; that way, when the PyTorch bug is fixed, this test will start failing and we can go back to the common implementation.

Copy link
Contributor Author

@dev-tomek dev-tomek Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the new commit I used `if device == 'xpu'`, because `is_xpu()` fails with "Cannot re-initialize XPU in forked subprocess." due to `@pytest.mark.forked`.

result.returncode == -6 # SIGABRT
)
assert abort_or_runtime_error, (
f"Expected runtime error or abort signal but got unexpected exit code {result.returncode}. "
f"stdout: {result.stdout}, stderr: {result.stderr}")
else:
assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. "
f"stdout: {result.stdout}, stderr: {result.stderr}")


def test_device_assert_barrier(monkeypatch, device):
Expand Down
56 changes: 56 additions & 0 deletions python/test/unit/test_debug_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Helper module containing Triton kernels for test_debug.py.
These kernels are separated so they can be called from subprocesses.
"""
import torch
import triton
import triton.language as tl
import sys
import os


def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device):
    """Launch a single-program kernel that calls ``tl.device_assert``.

    Args:
        cond: Passed as the kernel's ``COND`` constexpr, i.e. the asserted
            condition.
        mask: Passed as the kernel's ``MASK`` constexpr and forwarded to
            ``tl.device_assert`` as ``mask=``.
        opt_flag: When not ``None``, forwarded as the ``debug`` launch option;
            when ``None``, no ``debug`` option is passed at launch.
        jit_flag: Passed as ``debug=`` to ``triton.jit``.
        device: Name of the ``torch`` submodule whose ``synchronize()`` is
            called after the launch (e.g. ``"xpu"``).

    Returns:
        0 if the launch and synchronize completed, 1 if a ``RuntimeError``
        was raised, 2 for any other exception. The caller uses these values
        as the process exit code.
    """

    @triton.jit(debug=jit_flag)
    def _kernel(COND: tl.constexpr, MASK: tl.constexpr):
        tl.device_assert(COND, 'test', mask=MASK)

    kwargs = {}
    if opt_flag is not None:
        kwargs["debug"] = opt_flag

    try:
        _kernel[(1, )](cond, mask, **kwargs)
        # Synchronize so that a device-side failure surfaces inside this try
        # block rather than at some later, unrelated call.
        getattr(torch, device).synchronize()
    except RuntimeError as e:
        # Report the error instead of swallowing it: the parent process
        # captures stderr and shows it when the exit code is not the one it
        # expected, so the message is needed to diagnose the mismatch.
        print(f"RuntimeError: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Unexpected error: {type(e).__name__}: {e}", file=sys.stderr)
        return 2
    return 0


if __name__ == "__main__":
    # Command-line entry point used when this module is run in a subprocess:
    #   <test_type> <cond> <mask> <opt_flag> <jit_flag> <device> <env_var>
    # Boolean arguments are the strings "True"/"False"; mask and opt_flag may
    # also be "None".

    # Exit status for command-line problems (unknown test type or missing
    # arguments). It must differ from 1: an uncaught IndexError would make the
    # interpreter exit with status 1, which is also the status reported when
    # the kernel raises RuntimeError, so a malformed invocation would be
    # indistinguishable from a device assert firing.
    USAGE_ERROR = 3

    def parse_bool_or_none(arg_str):
        """Return None for "None"; otherwise True only for the exact string "True"."""
        if arg_str == "None":
            return None
        return arg_str == "True"

    if len(sys.argv) < 2:
        print("Missing test type argument", file=sys.stderr)
        sys.exit(USAGE_ERROR)

    test_type = sys.argv[1]
    if test_type == "device_assert":
        # Script name + test type + six parameters.
        if len(sys.argv) < 8:
            print(
                "device_assert expects 6 arguments "
                "(cond mask opt_flag jit_flag device env_var), "
                f"got {len(sys.argv) - 2}",
                file=sys.stderr,
            )
            sys.exit(USAGE_ERROR)

        cond = sys.argv[2] == "True"
        mask = parse_bool_or_none(sys.argv[3])
        opt_flag = parse_bool_or_none(sys.argv[4])
        jit_flag = sys.argv[5] == "True"
        device = sys.argv[6]
        env_var = sys.argv[7] == "True"

        # Set TRITON_DEBUG before refreshing the knobs so the refreshed
        # settings reflect the requested value.
        os.environ["TRITON_DEBUG"] = str(int(env_var))
        triton.knobs.refresh_knobs()
        exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device)
        sys.exit(exit_code)

    else:
        print(f"Unknown test type: {test_type}")
        sys.exit(USAGE_ERROR)
1 change: 0 additions & 1 deletion scripts/skiplist/a770/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/arl-h/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/arl-s/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/default/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/lts/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/mtl/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
1 change: 0 additions & 1 deletion scripts/skiplist/xe2/debug.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]
Expand Down
Loading