@@ -28,22 +28,32 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):
2828
2929 def transform_module (self , mod : IRModule , _ctx : tvm .transform .PassContext ) -> IRModule :
3030 """Entrypoint"""
31- if str (self .target .kind ) not in ["cuda" , "vulkan" , "metal" ]:
32- # Only enable GPU sampling for CUDA, Vulkan, and Metal .
31+ if str (self .target .kind ) not in ["cuda" , "vulkan" , "metal" , "webgpu" ]:
32+ # Only enable GPU sampling for CUDA, Vulkan, Metal, and WebGPU .
3333 return mod
3434
3535 bb = relax .BlockBuilder (mod )
36- gv_names = [
37- gv .name_hint
38- for gv in [
39- _attach_multinomial_sampling_func (bb ),
40- _attach_argsort_func (bb ),
41- _attach_sample_with_top_p (bb ),
42- _attach_take_probs_func (bb ),
43- _attach_batch_verifier (bb ),
44- _attach_renormalize_by_top_p (bb , self .target ),
36+ if str (self .target .kind ) == "webgpu" :
37+ # Only attach functions that do not contain i8s for WebGPU
38+ gv_names = [
39+ gv .name_hint
40+ for gv in [
41+ _attach_argsort_func (bb ),
42+ _attach_sample_with_top_p (bb ),
43+ ]
44+ ]
45+ else :
46+ gv_names = [
47+ gv .name_hint
48+ for gv in [
49+ _attach_multinomial_sampling_func (bb ),
50+ _attach_argsort_func (bb ),
51+ _attach_sample_with_top_p (bb ),
52+ _attach_take_probs_func (bb ),
53+ _attach_batch_verifier (bb ),
54+ _attach_renormalize_by_top_p (bb , self .target ),
55+ ]
4556 ]
46- ]
4757
4858 mod = bb .finalize ()
4959 for gv_name in gv_names :
0 commit comments