@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
1515
1616 def _init_executor (self ) -> None :
1717 """Initialize the worker and load the model.
18-
19- If speculative decoding is enabled, we instead create the speculative
20- worker.
2118 """
22- if self .speculative_config is None :
23- self ._init_non_spec_worker ()
24- else :
25- self ._init_spec_worker ()
19+ assert self .parallel_config .world_size == 1 , (
20+ "GPUExecutor only supports single GPU." )
21+
22+ self .driver_worker = self ._create_worker ()
23+ self .driver_worker .init_device ()
24+ self .driver_worker .load_model ()
2625
2726 def _get_worker_kwargs (
2827 self ,
@@ -45,66 +44,30 @@ def _get_worker_kwargs(
4544 distributed_init_method = distributed_init_method ,
4645 lora_config = self .lora_config ,
4746 vision_language_config = self .vision_language_config ,
47+ speculative_config = self .speculative_config ,
4848 is_driver_worker = rank == 0 ,
4949 )
5050
5151 def _create_worker (self ,
5252 local_rank : int = 0 ,
5353 rank : int = 0 ,
5454 distributed_init_method : Optional [str ] = None ):
55+
56+ if self .speculative_config is None :
57+ worker_module_name = "vllm.worker.worker"
58+ worker_class_name = "Worker"
59+ else :
60+ worker_module_name = "vllm.spec_decode.spec_decode_worker"
61+ worker_class_name = "create_spec_worker"
62+
5563 wrapper = WorkerWrapperBase (
56- worker_module_name = "vllm.worker.worker" ,
57- worker_class_name = "Worker" ,
64+ worker_module_name = worker_module_name ,
65+ worker_class_name = worker_class_name ,
5866 )
5967 wrapper .init_worker (** self ._get_worker_kwargs (local_rank , rank ,
6068 distributed_init_method ))
6169 return wrapper .worker
6270
63- def _init_non_spec_worker (self ):
64- assert self .parallel_config .world_size == 1 , (
65- "GPUExecutor only supports single GPU." )
66-
67- self .driver_worker = self ._create_worker ()
68- self .driver_worker .init_device ()
69- self .driver_worker .load_model ()
70-
71- def _init_spec_worker (self ):
72- """Initialize a SpecDecodeWorker, using a draft model for proposals.
73- """
74- assert self .speculative_config is not None
75-
76- from vllm .spec_decode .spec_decode_worker import SpecDecodeWorker
77-
78- target_worker = self ._create_worker ()
79-
80- draft_worker_kwargs = self ._get_worker_kwargs ()
81- # Override draft-model specific worker args.
82- draft_worker_kwargs .update (
83- model_config = self .speculative_config .draft_model_config ,
84- parallel_config = self .speculative_config .draft_parallel_config ,
85- ngram_prompt_lookup_max = self .speculative_config .
86- ngram_prompt_lookup_max ,
87- ngram_prompt_lookup_min = self .speculative_config .
88- ngram_prompt_lookup_min ,
89- # TODO allow draft-model specific load config.
90- #load_config=self.load_config,
91- )
92-
93- spec_decode_worker = SpecDecodeWorker .create_worker (
94- scorer_worker = target_worker ,
95- draft_worker_kwargs = draft_worker_kwargs ,
96- disable_by_batch_size = self .speculative_config .
97- speculative_disable_by_batch_size ,
98- )
99-
100- assert self .parallel_config .world_size == 1 , (
101- "GPUExecutor only supports single GPU." )
102-
103- self .driver_worker = spec_decode_worker
104-
105- # Load model handled in spec decode worker.
106- self .driver_worker .init_device ()
107-
10871 def determine_num_available_blocks (self ) -> Tuple [int , int ]:
10972 """Determine the number of available KV blocks by invoking the
11073 underlying worker.
0 commit comments