5 changes: 5 additions & 0 deletions .github/workflows/build-test-reusable.yml
@@ -311,6 +311,11 @@ jobs:
run: |
${{ env.TRITON_TEST_CMD }} --interpreter

- name: Run triton kernels tests
if: matrix.suite == 'rest'
run: |
${{ env.TRITON_TEST_CMD }} --triton-kernels

# FIXME: make sure new tutorials are added to one of the groups (scaled_dot, rest, tutorial-faX)
- name: Select tutorials to run (scaled_dot)
if: matrix.suite == 'scaled_dot'
7 changes: 7 additions & 0 deletions .github/workflows/build-test-windows.yml
@@ -148,6 +148,13 @@ jobs:
cd ${{ env.NEW_WORKSPACE }}
${{ env.TRITON_TEST_CMD }} --core

- name: Run triton kernels tests
run: |
.venv\Scripts\activate.ps1
Invoke-BatchFile "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
cd ${{ env.NEW_WORKSPACE }}
${{ env.TRITON_TEST_CMD }} --triton-kernels

- name: Run interpreter tests
run: |
.venv\Scripts\activate.ps1
7 changes: 7 additions & 0 deletions .github/workflows/pip-test-windows.yml
@@ -131,6 +131,13 @@ jobs:
cd ${{ env.NEW_WORKSPACE }}
${{ env.TRITON_TEST_CMD }} --interpreter

- name: Run triton kernels tests
run: |
.venv\Scripts\activate.ps1
Invoke-BatchFile "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
cd ${{ env.NEW_WORKSPACE }}
${{ env.TRITON_TEST_CMD }} --triton-kernels

- name: Run tutorials
run: |
.venv\Scripts\activate.ps1
4 changes: 4 additions & 0 deletions .github/workflows/pip-test.yml
@@ -71,6 +71,10 @@ jobs:
run: |
${{ env.TRITON_TEST_CMD }} --interpreter --skip-pip-install

- name: Run triton kernels tests
run: |
${{ env.TRITON_TEST_CMD }} --triton-kernels --skip-pip-install

- name: Run Tutorials
run: |
${{ env.TRITON_TEST_CMD }} --tutorial --skip-pip-install
6 changes: 6 additions & 0 deletions python/triton/language/target_info.py
@@ -52,3 +52,9 @@ def is_hip_cdna3():
def is_hip_cdna4():
target = current_target()
return target is not None and target.arch == "gfx950"


@constexpr_function
def is_xpu():
target = current_target()
return target is not None and target.backend == "xpu"
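For orientation, a minimal, hypothetical usage sketch (not part of this diff): because is_xpu() is a @constexpr_function like the neighbouring is_hip_cdna3()/is_hip_cdna4() helpers, it folds to a compile-time constant inside a @triton.jit function, so the unused branch is pruned during specialization. The kernel name and branch bodies below are invented purely for illustration.

import torch
import triton
import triton.language as tl
from triton.language.target_info import is_xpu

@triton.jit
def backend_tag_kernel(out_ptr):
    # is_xpu() is resolved while the kernel is compiled, so only one branch
    # survives in the generated code.
    if is_xpu():
        tag = 1  # hypothetical XPU-specific path
    else:
        tag = 0  # hypothetical default path
    tl.store(out_ptr, tag)

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
out = torch.zeros(1, dtype=torch.int32, device=device)
backend_tag_kernel[(1, )](out)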
17 changes: 9 additions & 8 deletions python/triton_kernels/tests/test_mxfp.py
@@ -16,9 +16,9 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:


@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
def test_mxfp4_rounding_cases(dst_dtype):
def test_mxfp4_rounding_cases(dst_dtype, device):
dst_dtype = dtype_str_to_torch(dst_dtype)
x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).cuda().bfloat16().view(1, -1, 1)
x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).to(device).bfloat16().view(1, -1, 1)
quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
assert dequant.flatten().tolist() == [6, 0, 0, 0.5, 1.0, 1.0, 1.0, 1.5], f"{dequant=}"
@@ -33,8 +33,8 @@ def test_mxfp4_rounding_cases(dst_dtype):

@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
def test_mxfp_quant_dequant(src_dtype, dst_dtype):
if "float8" in src_dtype and torch.cuda.get_device_capability()[0] < 9:
def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
if "float8" in src_dtype and (device == "cuda" and torch.cuda.get_device_capability()[0] < 9):
pytest.skip("Float8 not tested on A100")
limit_range = src_dtype == "float8_e5m2" and dst_dtype == "float16"

@@ -48,14 +48,14 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype):
max_val = 128

# These are all the valid mxfp4 positive values.
pos_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val], device="cuda", dtype=dst_dtype)
pos_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val], device=device, dtype=dst_dtype)
neg_vals = -pos_vals
k_dim = torch.cat([pos_vals, neg_vals])
k_dim = k_dim.reshape([k_dim.shape[0], 1])

# We pick power of 2 scales since both the scales and their inverse only require exponent bits to be exactly
# represented. This means we can store the scales exactly in the e8m0 format.
powers = torch.arange(-8, 8, device="cuda", dtype=dst_dtype)
powers = torch.arange(-8, 8, device=device, dtype=dst_dtype)
scales = 2**powers
scales = scales.reshape([1, powers.shape[0]])
weight = k_dim * scales
@@ -85,13 +85,14 @@ def test_mxfp_casting(
quant_dtype: str,
dequant_dtype: str,
rounding_mode: DequantScaleRoundingMode,
device,
):
if "float8" in quant_dtype and torch.cuda.get_device_capability()[0] < 9:
if "float8" in quant_dtype and (device == "cuda" and torch.cuda.get_device_capability()[0] < 9):
pytest.skip("Float8 not tested on A100")
quant_torch_type = dtype_str_to_torch(quant_dtype)
dequant_torch_type = dtype_str_to_torch(dequant_dtype)
# Generate random input tensor that is contiguous once axis is the last dimension
x = torch.randn(shape, device="cuda", dtype=dequant_torch_type)
x = torch.randn(shape, device=device, dtype=dequant_torch_type)

# Quantize and check equivalence
quant, scale = downcast_to_mxfp(x, quant_torch_type, axis, DEQUANT_SCALE_ROUNDING_MODE=rounding_mode)
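The tests above now take a device argument instead of hard-coding "cuda". How that argument is supplied is not shown in this diff; a plausible conftest.py wiring, matching the --device xpu option that run_triton_kernels_tests passes to pytest further down, would look roughly like this (sketch only; the suite may already provide an equivalent fixture):

# conftest.py (illustrative sketch, not part of this PR)
import pytest

def pytest_addoption(parser):
    parser.addoption("--device", action="store", default="cuda",
                     help="device to run the triton_kernels tests on, e.g. cuda or xpu")

@pytest.fixture
def device(request):
    return request.config.getoption("--device")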
8 changes: 4 additions & 4 deletions python/triton_kernels/tests/test_routing.py
@@ -5,7 +5,7 @@
from triton_kernels.testing import assert_equal


def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
def init_data(n_tokens, n_expts_tot, device, dtype=torch.float16):
logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True)
return logits

@@ -32,7 +32,7 @@ def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_
ref_logits = tri_logits.clone().detach().requires_grad_(True)

if use_expt_indx:
rand_idx = lambda: torch.randperm(n_expts_tot, device="cuda", dtype=torch.int64)
rand_idx = lambda: torch.randperm(n_expts_tot, device=device, dtype=torch.int64)
tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens_pad)])
tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1)
tri_expt_indx[n_tokens_raw:] = -99999 # should not be used
@@ -76,11 +76,11 @@ def _assert_indx_equal(ref, tri):
assert_close(ref_logits.grad[:n_tokens_raw], tri_logits.grad[:n_tokens_raw])


def bench_routing():
def bench_routing(device):
import triton.profiler as proton
n_tokens = 8192
n_expts_tot, n_expts_act = 128, 4
tri_logits = init_data(n_tokens, n_expts_tot)
tri_logits = init_data(n_tokens, n_expts_tot, device)
proton.start("routing")
proton.activate()
for i in range(100):
2 changes: 1 addition & 1 deletion python/triton_kernels/tests/test_swiglu.py
@@ -30,7 +30,7 @@ def test_op(M, N, limit, device, alpha=0.5):
# initialize expert data
n_expts_tot = 6
n_expts_act = 2
logits = init_routing_data(M, n_expts_tot).detach()
logits = init_routing_data(M, n_expts_tot, device).detach()
routing_data, _, _ = routing_torch(logits, n_expts_act)
n_tokens = routing_data.expt_hist.sum()

@@ -1,4 +1,5 @@
import pytest
from triton._internal_testing import is_cuda
import torch
from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout

@@ -17,6 +18,7 @@
(3, 2, 36),
],
)
@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
def test_mxfp4_scale_roundtrip(shape):
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
layout = BlackwellMXScaleLayout(x.shape)
@@ -1,5 +1,5 @@
import pytest
from triton._internal_testing import is_cuda
from triton._internal_testing import is_cuda, is_xpu
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
@@ -19,6 +19,7 @@
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("mma_version", [2, 3])
@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
if trans:
@@ -33,6 +34,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
@@ -71,8 +73,9 @@ def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m
tl.store(Y + offs_y, y)


@pytest.mark.skipif(not is_cuda(), reason="Only supported on cuda")
@pytest.mark.skipif(not cuda_capability_geq(9), reason="Only supported for capability >= 9")
@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on cuda")
@pytest.mark.skipif(not is_cuda() and not is_xpu(), reason="Only supported on cuda")
@pytest.mark.skipif(is_cuda() and not cuda_capability_geq(9), reason="Only supported for capability >= 9")
def test_upcast_mxfp4_to_bf16():
mx_axis = 0
num_warps = 4
2 changes: 1 addition & 1 deletion python/triton_kernels/triton_kernels/swiglu.py
@@ -35,7 +35,7 @@ def forward(ctx, a, alpha, precision_config, routing_data):
# optimization hyperparameters
BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
num_warps = 4
kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
# launch semi-persistent kernel
N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
num_sms = target_info.num_sms()
7 changes: 6 additions & 1 deletion python/triton_kernels/triton_kernels/target_info.py
@@ -8,6 +8,7 @@
is_hip,
is_hip_cdna3,
is_hip_cdna4,
is_xpu,
)

__all__ = [
@@ -19,6 +20,7 @@
"is_hip",
"is_hip_cdna3",
"is_hip_cdna4",
"is_xpu",
"num_sms",
]

@@ -51,4 +53,7 @@ def has_native_mxfp():


def num_sms():
return torch.cuda.get_device_properties(0).multi_processor_count
if is_cuda():
return torch.cuda.get_device_properties(0).multi_processor_count
if is_xpu():
return torch.xpu.get_device_properties(0).max_compute_units
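A short, hedged sketch of the kind of call site num_sms() serves, to show why it needs a per-backend implementation rather than a CUDA-only query: persistent kernels in triton_kernels cap their launch grid at the number of streaming multiprocessors (CUDA) or compute units (XPU). The helper below and its numbers are illustrative, not taken from the library.

import triton
from triton_kernels import target_info

def persistent_grid(n_tiles: int) -> tuple:
    # One program per SM / compute unit; remaining tiles are processed by
    # looping inside the (hypothetical) persistent kernel.
    return (min(n_tiles, target_info.num_sms()), )

# e.g. a 4096x4096 output tiled into 128x128 blocks -> 32 * 32 = 1024 tiles
grid = persistent_grid(triton.cdiv(4096, 128) * triton.cdiv(4096, 128))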
3 changes: 3 additions & 0 deletions scripts/skiplist/a770/triton_kernels.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/arl-h/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/arl-s/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/conda/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/default/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/lts/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/mtl/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
3 changes: 3 additions & 0 deletions scripts/skiplist/xe2/triton_kernels.txt
@@ -0,0 +1,3 @@
tests/test_matmul.py::test_op
tests/test_matmul.py::test_fused_act
tests/test_matmul.py::test_zero_reduction_dim
35 changes: 28 additions & 7 deletions scripts/test-triton.sh
@@ -6,13 +6,14 @@ HELP="\
Example usage: ./test-triton.sh [TEST]... [OPTION]...

TEST:
--unit default
--core default
--tutorial default
--microbench default
--minicore part of core
--mxfp part of core
--scaled-dot part of core
--unit default
--core default
--tutorial default
--microbench default
--triton-kernels default
--minicore part of core
--mxfp part of core
--scaled-dot part of core
--interpreter
--benchmarks
--softmax
@@ -64,6 +65,7 @@ TEST_BENCHMARK_FLEX_ATTENTION=false
TEST_INSTRUMENTATION=false
TEST_INDUCTOR=false
TEST_SGLANG=false
TEST_TRITON_KERNELS=false
VENV=false
TRITON_TEST_REPORTS=false
TRITON_TEST_WARNING_REPORTS=false
@@ -179,6 +181,11 @@ while (( $# != 0 )); do
TEST_DEFAULT=false
shift
;;
--triton-kernels)
TEST_TRITON_KERNELS=true
TEST_DEFAULT=false
shift
;;
--venv)
VENV=true
shift
@@ -234,6 +241,7 @@ if [ "$TEST_DEFAULT" = true ]; then
TEST_CORE=true
TEST_TUTORIAL=true
TEST_MICRO_BENCHMARKS=true
TEST_TRITON_KERNELS=true
fi

if [ "$VENV" = true ]; then
@@ -562,6 +570,16 @@ run_sglang_tests() {
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-4} test/srt/test_triton_attention_kernels.py
}

run_triton_kernels_tests() {
echo "***************************************************"
echo "****** Running Triton Kernels tests ******"
echo "***************************************************"
cd $TRITON_PROJ/python/triton_kernels/tests

TRITON_TEST_SUITE=triton_kernels \
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu .
}

test_triton() {
if [ "$TEST_UNIT" = true ]; then
run_unit_tests
@@ -615,6 +633,9 @@ test_triton() {
if [ "$TEST_SGLANG" == true ]; then
run_sglang_tests
fi
if [ "$TEST_TRITON_KERNELS" == true ]; then
run_triton_kernels_tests
fi
}

install_deps
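For reference, example invocations of the updated script; --triton-kernels and its inclusion in the default selection come from the diff above, and --venv is one of the script's existing options:

# run only the new triton_kernels suite
./scripts/test-triton.sh --triton-kernels --venv

# a plain run now includes the suite, since TEST_DEFAULT enables it
./scripts/test-triton.sh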