Skip to content

Commit f8cb85e

Browse files
apatkesumitd2
authored andcommitted
[Core] Adding Priority Scheduling (vllm-project#5958)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent 72f244d commit f8cb85e

File tree

6 files changed

+410
-8
lines changed

6 files changed

+410
-8
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""Benchmark offline prioritization."""
2+
import argparse
3+
import json
4+
import random
5+
import time
6+
from typing import List, Optional, Tuple
7+
8+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
9+
10+
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11+
12+
13+
def sample_requests(
14+
dataset_path: str,
15+
num_requests: int,
16+
tokenizer: PreTrainedTokenizerBase,
17+
fixed_output_len: Optional[int],
18+
) -> List[Tuple[str, int, int]]:
19+
if fixed_output_len is not None and fixed_output_len < 4:
20+
raise ValueError("output_len too small")
21+
22+
# Load the dataset.
23+
with open(dataset_path) as f:
24+
dataset = json.load(f)
25+
# Filter out the conversations with less than 2 turns.
26+
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
27+
# Only keep the first two turns of each conversation.
28+
dataset = [(data["conversations"][0]["value"],
29+
data["conversations"][1]["value"]) for data in dataset]
30+
31+
# Shuffle the dataset.
32+
random.shuffle(dataset)
33+
34+
# Filter out sequences that are too long or too short
35+
filtered_dataset: List[Tuple[str, int, int]] = []
36+
for i in range(len(dataset)):
37+
if len(filtered_dataset) == num_requests:
38+
break
39+
40+
# Tokenize the prompts and completions.
41+
prompt = dataset[i][0]
42+
prompt_token_ids = tokenizer(prompt).input_ids
43+
completion = dataset[i][1]
44+
completion_token_ids = tokenizer(completion).input_ids
45+
prompt_len = len(prompt_token_ids)
46+
output_len = len(completion_token_ids
47+
) if fixed_output_len is None else fixed_output_len
48+
if prompt_len < 4 or output_len < 4:
49+
# Prune too short sequences.
50+
continue
51+
if prompt_len > 1024 or prompt_len + output_len > 2048:
52+
# Prune too long sequences.
53+
continue
54+
55+
#Select a equi-probable random priority
56+
priority = 0 if random.random() < 0.5 else 1
57+
58+
filtered_dataset.append((prompt, prompt_len, output_len, priority))
59+
60+
return filtered_dataset
61+
62+
63+
def run_vllm(
64+
requests: List[Tuple[str, int, int]],
65+
model: str,
66+
tokenizer: str,
67+
quantization: Optional[str],
68+
tensor_parallel_size: int,
69+
seed: int,
70+
n: int,
71+
use_beam_search: bool,
72+
trust_remote_code: bool,
73+
dtype: str,
74+
max_model_len: Optional[int],
75+
enforce_eager: bool,
76+
kv_cache_dtype: str,
77+
quantization_param_path: Optional[str],
78+
device: str,
79+
enable_prefix_caching: bool,
80+
enable_chunked_prefill: bool,
81+
max_num_batched_tokens: int,
82+
gpu_memory_utilization: float = 0.9,
83+
download_dir: Optional[str] = None,
84+
) -> float:
85+
from vllm import LLM, SamplingParams
86+
llm = LLM(
87+
model=model,
88+
tokenizer=tokenizer,
89+
quantization=quantization,
90+
tensor_parallel_size=tensor_parallel_size,
91+
seed=seed,
92+
trust_remote_code=trust_remote_code,
93+
dtype=dtype,
94+
max_model_len=max_model_len,
95+
gpu_memory_utilization=gpu_memory_utilization,
96+
enforce_eager=enforce_eager,
97+
kv_cache_dtype=kv_cache_dtype,
98+
quantization_param_path=quantization_param_path,
99+
device=device,
100+
enable_prefix_caching=enable_prefix_caching,
101+
download_dir=download_dir,
102+
enable_chunked_prefill=enable_chunked_prefill,
103+
max_num_batched_tokens=max_num_batched_tokens,
104+
disable_log_stats=False,
105+
)
106+
107+
# Add the requests to the engine.
108+
prompts = []
109+
sampling_params = []
110+
priority = []
111+
for prompt, _, output_len, _priority in requests:
112+
prompts.append(prompt)
113+
priority.append(_priority)
114+
sampling_params.append(
115+
SamplingParams(
116+
n=n,
117+
temperature=0.0 if use_beam_search else 1.0,
118+
top_p=1.0,
119+
use_beam_search=use_beam_search,
120+
ignore_eos=True,
121+
max_tokens=output_len,
122+
))
123+
124+
start = time.perf_counter()
125+
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
126+
end = time.perf_counter()
127+
return end - start
128+
129+
130+
def main(args: argparse.Namespace):
131+
print(args)
132+
random.seed(args.seed)
133+
134+
# Sample the requests.
135+
tokenizer = AutoTokenizer.from_pretrained(
136+
args.tokenizer, trust_remote_code=args.trust_remote_code)
137+
if args.dataset is None:
138+
# Synthesize a prompt with the given input length.
139+
prompt = "hi" * (args.input_len - 1)
140+
requests = [(prompt, args.input_len, args.output_len)
141+
for _ in range(args.num_prompts)]
142+
else:
143+
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
144+
args.output_len)
145+
146+
if args.backend == "vllm":
147+
elapsed_time = run_vllm(
148+
requests, args.model, args.tokenizer, args.quantization,
149+
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
150+
args.trust_remote_code, args.dtype, args.max_model_len,
151+
args.enforce_eager, args.kv_cache_dtype,
152+
args.quantization_param_path, args.device,
153+
args.enable_prefix_caching, args.enable_chunked_prefill,
154+
args.max_num_batched_tokens, args.gpu_memory_utilization,
155+
args.download_dir)
156+
else:
157+
raise ValueError(f"Unknown backend: {args.backend}")
158+
total_num_tokens = sum(prompt_len + output_len
159+
for _, prompt_len, output_len, priority in requests)
160+
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
161+
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
162+
163+
# Output JSON results if specified
164+
if args.output_json:
165+
results = {
166+
"elapsed_time": elapsed_time,
167+
"num_requests": len(requests),
168+
"total_num_tokens": total_num_tokens,
169+
"requests_per_second": len(requests) / elapsed_time,
170+
"tokens_per_second": total_num_tokens / elapsed_time,
171+
}
172+
with open(args.output_json, "w") as f:
173+
json.dump(results, f, indent=4)
174+
175+
176+
if __name__ == "__main__":
177+
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
178+
parser.add_argument("--backend",
179+
type=str,
180+
choices=["vllm", "hf", "mii"],
181+
default="vllm")
182+
parser.add_argument("--dataset",
183+
type=str,
184+
default=None,
185+
help="Path to the dataset.")
186+
parser.add_argument("--input-len",
187+
type=int,
188+
default=None,
189+
help="Input prompt length for each request")
190+
parser.add_argument("--output-len",
191+
type=int,
192+
default=None,
193+
help="Output length for each request. Overrides the "
194+
"output length from the dataset.")
195+
parser.add_argument("--model", type=str, default="facebook/opt-125m")
196+
parser.add_argument("--tokenizer", type=str, default=None)
197+
parser.add_argument('--quantization',
198+
'-q',
199+
choices=[*QUANTIZATION_METHODS, None],
200+
default=None)
201+
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
202+
parser.add_argument("--n",
203+
type=int,
204+
default=1,
205+
help="Number of generated sequences per prompt.")
206+
parser.add_argument("--use-beam-search", action="store_true")
207+
parser.add_argument("--num-prompts",
208+
type=int,
209+
default=200,
210+
help="Number of prompts to process.")
211+
parser.add_argument("--seed", type=int, default=0)
212+
parser.add_argument('--trust-remote-code',
213+
action='store_true',
214+
help='trust remote code from huggingface')
215+
parser.add_argument(
216+
'--max-model-len',
217+
type=int,
218+
default=None,
219+
help='Maximum length of a sequence (including prompt and output). '
220+
'If None, will be derived from the model.')
221+
parser.add_argument(
222+
'--dtype',
223+
type=str,
224+
default='auto',
225+
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
226+
help='data type for model weights and activations. '
227+
'The "auto" option will use FP16 precision '
228+
'for FP32 and FP16 models, and BF16 precision '
229+
'for BF16 models.')
230+
parser.add_argument('--gpu-memory-utilization',
231+
type=float,
232+
default=0.9,
233+
help='the fraction of GPU memory to be used for '
234+
'the model executor, which can range from 0 to 1.'
235+
'If unspecified, will use the default value of 0.9.')
236+
parser.add_argument("--enforce-eager",
237+
action="store_true",
238+
help="enforce eager execution")
239+
parser.add_argument(
240+
'--kv-cache-dtype',
241+
type=str,
242+
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
243+
default="auto",
244+
help='Data type for kv cache storage. If "auto", will use model '
245+
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
246+
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
247+
parser.add_argument(
248+
'--quantization-param-path',
249+
type=str,
250+
default=None,
251+
help='Path to the JSON file containing the KV cache scaling factors. '
252+
'This should generally be supplied, when KV cache dtype is FP8. '
253+
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
254+
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
255+
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
256+
'instead supported for common inference criteria.')
257+
parser.add_argument(
258+
"--device",
259+
type=str,
260+
default="cuda",
261+
choices=["cuda", "cpu"],
262+
help='device type for vLLM execution, supporting CUDA and CPU.')
263+
parser.add_argument(
264+
"--enable-prefix-caching",
265+
action='store_true',
266+
help="enable automatic prefix caching for vLLM backend.")
267+
parser.add_argument("--enable-chunked-prefill",
268+
action='store_true',
269+
help="enable chunked prefill for vLLM backend.")
270+
parser.add_argument('--max-num-batched-tokens',
271+
type=int,
272+
default=None,
273+
help='maximum number of batched tokens per '
274+
'iteration')
275+
parser.add_argument('--download-dir',
276+
type=str,
277+
default=None,
278+
help='directory to download and load the weights, '
279+
'default to the default cache dir of huggingface')
280+
parser.add_argument(
281+
'--output-json',
282+
type=str,
283+
default=None,
284+
help='Path to save the throughput results in JSON format.')
285+
286+
args = parser.parse_args()
287+
if args.tokenizer is None:
288+
args.tokenizer = args.model
289+
if args.dataset is None:
290+
assert args.input_len is not None
291+
assert args.output_len is not None
292+
else:
293+
assert args.input_len is None
294+
295+
main(args)

vllm/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ class SchedulerConfig:
961961
workers instead of an entire data. It should be enabled only
962962
when SPMD worker architecture is enabled. I.e.,
963963
VLLM_USE_RAY_SPMD_WORKER=1
964-
964+
policy: The scheduling policy to use. "fcfs" (default) or "priority".
965965
"""
966966

967967
def __init__(self,
@@ -977,7 +977,8 @@ def __init__(self,
977977
preemption_mode: Optional[str] = None,
978978
num_scheduler_steps: int = 1,
979979
multi_step_stream_outputs: bool = False,
980-
send_delta_data: bool = False) -> None:
980+
send_delta_data: bool = False,
981+
policy: str = "fcfs") -> None:
981982
if max_num_batched_tokens is None:
982983
if enable_chunked_prefill:
983984
# It is the values that have the best balance between ITL
@@ -1019,6 +1020,7 @@ def __init__(self,
10191020
self.num_scheduler_steps = num_scheduler_steps
10201021
self.multi_step_stream_outputs = multi_step_stream_outputs
10211022
self.send_delta_data = send_delta_data
1023+
self.policy = policy
10221024
self._verify_args()
10231025

10241026
def _verify_args(self) -> None:

0 commit comments

Comments
 (0)