Commit 3554439

[V1] VLM - preprocessor hashing

alexm-redhat, mgoin, and ywang96 committed

Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

1 parent d1c2e15 · commit 3554439

File tree: 8 files changed, +220 -32 lines

examples/offline_inference_vision_language.py
Lines changed: 48 additions & 8 deletions

@@ -5,6 +5,8 @@
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import random
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -23,7 +25,11 @@ def run_llava(question: str, modality: str):
 
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_model_len=4096,
+        # TODO: Fix this!
+        mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -524,14 +530,35 @@ def main(args):
 
     else:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
-
+        if args.image_repeat_ratio is not None:
+            assert (args.image_repeat_ratio <= 1.0
+                    and args.image_repeat_ratio >= 0)
+            no_yes = [0, 1]
+            probs = [1.0 - args.image_repeat_ratio, args.image_repeat_ratio]
+
+        inputs = []
+        cur_image = data
+        for i in range(args.num_prompts):
+            if args.image_repeat_ratio is not None:
+                res = random.choices(no_yes, probs)[0]
+                if res == 0:
+                    # No repeat => Modify one pixel
+                    cur_image = cur_image.copy()
+                    new_val = (i // 256 // 256, i // 256, i % 256)
+                    cur_image.putpixel((0, 0), new_val)
+
+            inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: cur_image
+                }
+            })
+
+    import time
+    start_time = time.time()
     outputs = llm.generate(inputs, sampling_params=sampling_params)
+    elapsed_time = time.time() - start_time
+    print("-- generate time = {}".format(elapsed_time))
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -561,5 +588,18 @@ def main(args):
                         type=int,
                         default=16,
                         help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-ratio',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()
     main(args)
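
A minimal sketch of how the new flag could be exercised offline, assuming the LLaVA setup from the example above and a hypothetical local image path; since identical PIL images hash to the same key, repeats of one image should be served from the preprocessor cache:

    # Sketch: enable the multimodal preprocessor cache and repeat one image.
    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        mm_cache_preprocessor=True,  # flag added by this commit
    )

    prompt = "USER: <image>\nWhat is in this picture?\nASSISTANT:"
    image = Image.open("example.jpg")  # hypothetical input image

    # The same image repeated maps to one blake3 hash, so only the first
    # request should pay the full preprocessing cost.
    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    } for _ in range(8)]

    outputs = llm.generate(inputs, SamplingParams(max_tokens=64))
    for o in outputs:
        print(o.outputs[0].text)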

vllm/config.py
Lines changed: 7 additions & 2 deletions

@@ -133,6 +133,8 @@ class ModelConfig:
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        mm_cache_preprocessor: If True, enable caching of multi-modal
+            preprocessor/mapper.
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -171,6 +173,7 @@ def __init__(
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        mm_cache_preprocessor: bool = False,
         override_neuron_config: Optional[Dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None) -> None:
         self.model = model
@@ -237,6 +240,7 @@ def __init__(
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
+        self.mm_cache_preprocessor = mm_cache_preprocessor
 
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -2610,9 +2614,10 @@ def __str__(self):
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"use_async_output_proc={self.model_config.use_async_output_proc}, "
+            f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, "  # noqa
             f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
-            f"pooler_config={self.model_config.pooler_config!r},"
-            f" compilation_config={self.compilation_config!r}")
+            f"pooler_config={self.model_config.pooler_config!r}, "
+            f"compilation_config={self.compilation_config!r}")
 
 
 _current_vllm_config: Optional[VllmConfig] = None

vllm/engine/arg_utils.py
Lines changed: 6 additions & 0 deletions

@@ -143,6 +143,7 @@ class EngineArgs:
     tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    mm_cache_preprocessor: bool = False
     enable_lora: bool = False
     enable_lora_bias: bool = False
     max_loras: int = 1
@@ -590,6 +591,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            type=json.loads,
            help=('Overrides for the multimodal input mapping/processing, '
                  'e.g., image processor. For example: {"num_crops": 4}.'))
+        parser.add_argument(
+            '--mm-cache-preprocessor',
+            action='store_true',
+            help='If True, enable caching of multi-modal preprocessor/mapper.')
 
         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -962,6 +967,7 @@ def create_model_config(self) -> ModelConfig:
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
+            mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
         )
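
For the engine/CLI path, the flag travels from --mm-cache-preprocessor through EngineArgs into ModelConfig via create_model_config(), both of which appear in the hunks above; a small sketch of that wiring:

    # Sketch: the CLI flag lands on EngineArgs and is forwarded to ModelConfig.
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",
        mm_cache_preprocessor=True,
    )
    model_config = engine_args.create_model_config()
    assert model_config.mm_cache_preprocessor is True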

vllm/v1/engine/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,8 @@ class EngineCoreRequest:
     # always be tokenized?
     prompt: Optional[str]
     prompt_token_ids: List[int]
-    mm_inputs: Optional[List[MultiModalKwargs]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
+    mm_hashes: Optional[List[Optional[str]]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
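
The two parallel lists encode the client/server cache protocol: mm_inputs[i] is None exactly when the client already cached the mapped input for mm_hashes[i], in which case only the hash crosses the process boundary and the engine core restores the kwargs from its own cache. A toy illustration of the invariant, with hypothetical values:

    # Sketch: wire format after a client-side cache hit on the second image.
    # Values are hypothetical; real hashes are blake3 hexdigests.
    mm_hashes = ["hash_a", "hash_a"]
    mm_inputs = [{"pixel_values": "..."}, None]  # None => server fills it in
    assert len(mm_inputs) == len(mm_hashes)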

vllm/v1/engine/core.py
Lines changed: 10 additions & 4 deletions

@@ -19,7 +19,7 @@
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -55,16 +55,15 @@ def __init__(
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
 
-        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
-        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
-
         # Setup scheduler.
         self.scheduler = Scheduler(vllm_config.scheduler_config,
                                    vllm_config.cache_config,
                                    vllm_config.lora_config)
 
         self._last_logging_time = time.time()
 
+        self.mm_input_mapper_server = MMInputMapperServer()
+
     def _initialize_kv_caches(self,
                               cache_config: CacheConfig) -> Tuple[int, int]:
         start = time.time()
@@ -88,7 +87,14 @@ def _initialize_kv_caches(self,
 
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
+
+        # Add doc
+        if request.mm_hashes is not None:
+            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+                request.mm_inputs, request.mm_hashes)
+
         req = Request.from_engine_core_request(request)
+
         self.scheduler.add_request(req)
 
     def abort_requests(self, request_ids: List[str]):
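
add_request() now rehydrates any elided inputs before building the Request. A standalone sketch of that round trip, using a placeholder object and a hypothetical hash in place of real values:

    # Sketch: server-side rehydration through MMInputMapperServer.
    from vllm.v1.engine.mm_input_mapper import MMInputMapperServer

    server = MMInputMapperServer()
    fake_kwargs = object()  # stand-in for a real MultiModalKwargs

    # Miss path: the input is present, so the server caches it under its hash.
    assert server.process_inputs([fake_kwargs], ["abc123"]) == [fake_kwargs]

    # Hit path: the client sent None for the same hash; the server fills it in.
    assert server.process_inputs([None], ["abc123"]) == [fake_kwargs]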

vllm/v1/engine/mm_input_mapper.py
Lines changed: 109 additions & 10 deletions

@@ -1,11 +1,18 @@
 from typing import Any, Dict, List, Optional
 
+import PIL
+from blake3 import blake3
+
 from vllm.config import ModelConfig
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
+from vllm.v1.utils import LRUDictCache
+
+# Both Client and Server must use the same cache size
+MM_CACHE_SIZE = 128
 
 
-class MMInputMapper:
+class MMInputMapperClient:
 
     def __init__(
         self,
@@ -18,23 +25,115 @@ def __init__(
             model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)
 
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+        # Set to None to disable (TODO: Disable!)
+        self.mm_debug_cache_hit_ratio_steps = 32
+        self.mm_cache_hits = 0
+        self.mm_cache_misses = 0
+
+    def cache_hit_ratio(self, steps) -> float:
+        total_steps = self.mm_cache_hits + self.mm_cache_misses
+
+        if total_steps > 0 and total_steps % steps == 0:
+            print("[debug] MMInputMapper: cache_hit_ratio = {}".format(
+                self.mm_cache_hits / total_steps))
+
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
+        mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
     ) -> List[MultiModalKwargs]:
         image_inputs = mm_data["image"]
         if not isinstance(image_inputs, list):
             image_inputs = [image_inputs]
 
+        use_hash = mm_hashes is not None
+        if use_hash:
+            assert len(image_inputs) == len(mm_hashes)  # Sanity
+
         # Process each image input separately so that later we can schedule
         # them in a fine-grained manner.
-        mm_inputs: List[MultiModalKwargs] = []
-        num_images = len(image_inputs)
-        for i in range(num_images):
-            mm_input = self.multi_modal_input_mapper(
-                {"image": image_inputs[i]},
-                mm_processor_kwargs=mm_processor_kwargs,
-            )
-            mm_inputs.append(mm_input)
-        return mm_inputs
+        # Utilize caching (if enabled)
+        ret_hashes = [] if use_hash else None
+        ret_inputs: List[MultiModalKwargs] = []
+        for i in range(len(image_inputs)):
+            if self.mm_debug_cache_hit_ratio_steps is not None:
+                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
+
+            if use_hash:
+                mm_hash = mm_hashes[i]
+                mm_input = self.mm_cache.get(mm_hash)
+            else:
+                mm_hash = None
+                mm_input = None
+
+            if mm_input is None:
+                self.mm_cache_misses += 1
+                mm_input = self.multi_modal_input_mapper(
+                    {"image": [image_inputs[i]]},
+                    mm_processor_kwargs=mm_processor_kwargs,
+                )
+
+                if use_hash:
+                    self.mm_cache.put(mm_hash, mm_input)
+            else:
+                self.mm_cache_hits += 1
+                mm_input = None  # Avoids sending mm_input to Server
+
+            if use_hash:
+                ret_hashes.append(mm_hash)
+            ret_inputs.append(mm_input)
+
+        return ret_inputs, ret_hashes
+
+
+class MMInputMapperServer:
+
+    def __init__(self, ):
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+    def process_inputs(
+        self,
+        mm_inputs: List[Optional[MultiModalKwargs]],
+        mm_hashes: List[Optional[str]],
+    ) -> List[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        full_mm_inputs = []
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_input is None:
+                mm_input = self.mm_cache.get(mm_hash)
+                assert mm_input is not None
+            else:
+                self.mm_cache.put(mm_hash, mm_input)
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+
+class MMHasher:
+
+    def __init__(self):
+        pass
+
+    def hash(self, mm_data: MultiModalDataDict) -> List[str]:
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+
+        ret = []
+        for image in image_inputs:
+            assert isinstance(image, PIL.Image.Image)
+
+            # Convert image to bytes
+            bytes = image.tobytes()
+
+            # Hash image bytes
+            hasher = blake3()
+            hasher.update(bytes)
+            ret.append(hasher.hexdigest())
+
+        return ret
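
LRUDictCache comes from vllm.v1.utils, one of the changed files not shown in this excerpt. A plausible minimal implementation matching the get/put interface used above (an assumption, not the committed code) would be an OrderedDict-backed LRU:

    # Assumed sketch of LRUDictCache; the real one lives in vllm/v1/utils.py.
    from collections import OrderedDict
    from typing import Any, Optional


    class LRUDictCache:

        def __init__(self, size: int):
            self.cache: OrderedDict[Any, Any] = OrderedDict()
            self.size = size

        def get(self, key: Any) -> Optional[Any]:
            value = self.cache.get(key)
            if value is not None:
                # Refresh recency on a hit.
                self.cache.move_to_end(key)
            return value

        def put(self, key: Any, value: Any) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                # Evict the least-recently-used entry.
                self.cache.popitem(last=False)

Note that MMHasher hashes raw pixel bytes (image.tobytes()) with blake3, so two distinct PIL objects with identical pixel data map to the same cache key; this is what lets the client and server caches deduplicate repeated images.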
