44Test the piecewise compilation with a simple model so that we
55can exactly calculate the expected output and side effects.
66"""
7+
78import pytest
89import torch
910from torch import nn
10- from torch .library import Library
1111
1212from vllm .compilation .counter import compilation_counter
1313from vllm .compilation .decorators import support_torch_compile
1414from vllm .config import (CompilationConfig , CompilationLevel , CUDAGraphMode ,
1515 VllmConfig , set_current_vllm_config )
1616from vllm .envs import VLLM_USE_V1
1717from vllm .forward_context import BatchDescriptor , set_forward_context
18- from vllm .utils import direct_register_custom_op
19-
20- global_counter = 0
21-
22- # create a library to hold the custom op
23- silly_lib = Library ("silly" , "FRAGMENT" ) # noqa
24-
25-
26- def silly_attention (q : torch .Tensor , k : torch .Tensor , v : torch .Tensor ,
27- out : torch .Tensor ) -> None :
28- global global_counter
29- global_counter += 1
30- print (f"{ global_counter = } " )
31- out .copy_ (q )
32- out [0 ] += 1
33-
34-
35- def silly_attention_fake (q : torch .Tensor , k : torch .Tensor , v : torch .Tensor ,
36- out : torch .Tensor ) -> None :
37- return
38-
3918
40- direct_register_custom_op (
41- op_name = "attention" ,
42- op_func = silly_attention ,
43- mutates_args = ["out" ],
44- fake_impl = silly_attention_fake ,
45- target_lib = silly_lib ,
46- )
19+ # This import automatically registers `torch.ops.silly.attention`
20+ from ..silly_attention import get_global_counter , reset_global_counter
4721
4822
4923@support_torch_compile
@@ -59,8 +33,7 @@ def __init__(self,
5933 def forward (self , x : torch .Tensor ) -> torch .Tensor :
6034 """
6135 Overall effect:
62- x += 1
63- x[0] += 2
36+ x = 3 * x + 19
6437 global_counter += 2
6538 """
6639 x = x + 1
@@ -78,6 +51,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7851
7952
8053@pytest .mark .parametrize ("use_inductor" , [True , False ])
54+ @torch .inference_mode ()
8155def test_simple_piecewise_compile (use_inductor ):
8256 assert VLLM_USE_V1
8357
@@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
12195 model (torch .randn (1 ).cuda ())
12296
12397 input = torch .zeros (2 ).cuda ()
124- global global_counter
125- global_counter = 0
98+ reset_global_counter ()
12699 with set_forward_context (
127100 None ,
128101 vllm_config = vllm_config ,
129102 cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE ,
130103 batch_descriptor = BatchDescriptor (num_tokens = 2 , )):
131104 output = model (input )
132- assert global_counter == 2
133- assert torch .allclose (output .cpu (), torch .tensor ([3. , 1. ]))
105+ assert get_global_counter () == 2
106+ assert torch .allclose (output .cpu (), torch .tensor ([19.0 , 19.0 ]))
0 commit comments