Skip to content

Commit 862a731

Browse files
authored
[Python][Web] Enable sampler on WebGPU (#3365)
Enable the sampler on WebGPU. For WebGPU targets, attach only the necessary functions (`argsort` and `sample_with_top_p`), which do not contain unsupported types (such as `i8`).
1 parent 4362567 commit 862a731

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

python/mlc_llm/compiler_pass/attach_sampler.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -28,22 +28,32 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):
2828

2929
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
3030
"""Entrypoint"""
31-
if str(self.target.kind) not in ["cuda", "vulkan", "metal"]:
32-
# Only enable GPU sampling for CUDA, Vulkan, and Metal.
31+
if str(self.target.kind) not in ["cuda", "vulkan", "metal", "webgpu"]:
32+
# Only enable GPU sampling for CUDA, Vulkan, Metal, and WebGPU.
3333
return mod
3434

3535
bb = relax.BlockBuilder(mod)
36-
gv_names = [
37-
gv.name_hint
38-
for gv in [
39-
_attach_multinomial_sampling_func(bb),
40-
_attach_argsort_func(bb),
41-
_attach_sample_with_top_p(bb),
42-
_attach_take_probs_func(bb),
43-
_attach_batch_verifier(bb),
44-
_attach_renormalize_by_top_p(bb, self.target),
36+
if str(self.target.kind) == "webgpu":
37+
# Only attach functions that do not contain i8s for WebGPU
38+
gv_names = [
39+
gv.name_hint
40+
for gv in [
41+
_attach_argsort_func(bb),
42+
_attach_sample_with_top_p(bb),
43+
]
44+
]
45+
else:
46+
gv_names = [
47+
gv.name_hint
48+
for gv in [
49+
_attach_multinomial_sampling_func(bb),
50+
_attach_argsort_func(bb),
51+
_attach_sample_with_top_p(bb),
52+
_attach_take_probs_func(bb),
53+
_attach_batch_verifier(bb),
54+
_attach_renormalize_by_top_p(bb, self.target),
55+
]
4556
]
46-
]
4757

4858
mod = bb.finalize()
4959
for gv_name in gv_names:

0 commit comments

Comments
 (0)