11from typing import Any , Dict , List , Optional
22
3+ import PIL
4+ from blake3 import blake3
5+
36from vllm .config import ModelConfig
47from vllm .multimodal import (MULTIMODAL_REGISTRY , MultiModalDataDict ,
58 MultiModalKwargs , MultiModalRegistry )
9+ from vllm .v1 .utils import LRUDictCache
10+
11+ # Both Client and Server must use the same cache size
12+ MM_CACHE_SIZE = 128
613
714
8- class MMInputMapper :
15+ class MMInputMapperClient :
916
1017 def __init__ (
1118 self ,
@@ -18,23 +25,115 @@ def __init__(
1825 model_config )
1926 self .mm_registry .init_mm_limits_per_prompt (model_config )
2027
28+ self .mm_cache = LRUDictCache (MM_CACHE_SIZE )
29+
30+ # Set to None to disable (TODO: Disable!)
31+ self .mm_debug_cache_hit_ratio_steps = 32
32+ self .mm_cache_hits = 0
33+ self .mm_cache_misses = 0
34+
35+ def cache_hit_ratio (self , steps ) -> float :
36+ total_steps = self .mm_cache_hits + self .mm_cache_misses
37+
38+ if total_steps > 0 and total_steps % steps == 0 :
39+ print ("[debug] MMInputMapper: cache_hit_ratio = {}" .format (
40+ self .mm_cache_hits / total_steps ))
41+
2142 def process_inputs (
2243 self ,
2344 mm_data : MultiModalDataDict ,
45+ mm_hashes : Optional [List [str ]],
2446 mm_processor_kwargs : Optional [Dict [str , Any ]],
2547 ) -> List [MultiModalKwargs ]:
2648 image_inputs = mm_data ["image" ]
2749 if not isinstance (image_inputs , list ):
2850 image_inputs = [image_inputs ]
2951
52+ use_hash = mm_hashes is not None
53+ if use_hash :
54+ assert len (image_inputs ) == len (mm_hashes ) # Sanity
55+
3056 # Process each image input separately so that later we can schedule
3157 # them in a fine-grained manner.
32- mm_inputs : List [MultiModalKwargs ] = []
33- num_images = len (image_inputs )
34- for i in range (num_images ):
35- mm_input = self .multi_modal_input_mapper (
36- {"image" : image_inputs [i ]},
37- mm_processor_kwargs = mm_processor_kwargs ,
38- )
39- mm_inputs .append (mm_input )
40- return mm_inputs
58+ # Utilize caching (if enabled)
59+ ret_hashes = [] if use_hash else None
60+ ret_inputs : List [MultiModalKwargs ] = []
61+ for i in range (len (image_inputs )):
62+ if self .mm_debug_cache_hit_ratio_steps is not None :
63+ self .cache_hit_ratio (self .mm_debug_cache_hit_ratio_steps )
64+
65+ if use_hash :
66+ mm_hash = mm_hashes [i ]
67+ mm_input = self .mm_cache .get (mm_hash )
68+ else :
69+ mm_hash = None
70+ mm_input = None
71+
72+ if mm_input is None :
73+ self .mm_cache_misses += 1
74+ mm_input = self .multi_modal_input_mapper (
75+ {"image" : [image_inputs [i ]]},
76+ mm_processor_kwargs = mm_processor_kwargs ,
77+ )
78+
79+ if use_hash :
80+ self .mm_cache .put (mm_hash , mm_input )
81+ else :
82+ self .mm_cache_hits += 1
83+ mm_input = None # Avoids sending mm_input to Server
84+
85+ if use_hash :
86+ ret_hashes .append (mm_hash )
87+ ret_inputs .append (mm_input )
88+
89+ return ret_inputs , ret_hashes
90+
91+
92+ class MMInputMapperServer :
93+
94+ def __init__ (self , ):
95+ self .mm_cache = LRUDictCache (MM_CACHE_SIZE )
96+
97+ def process_inputs (
98+ self ,
99+ mm_inputs : List [Optional [MultiModalKwargs ]],
100+ mm_hashes : List [Optional [str ]],
101+ ) -> List [MultiModalKwargs ]:
102+ assert len (mm_inputs ) == len (mm_hashes )
103+
104+ full_mm_inputs = []
105+ for mm_input , mm_hash in zip (mm_inputs , mm_hashes ):
106+ if mm_input is None :
107+ mm_input = self .mm_cache .get (mm_hash )
108+ assert mm_input is not None
109+ else :
110+ self .mm_cache .put (mm_hash , mm_input )
111+
112+ full_mm_inputs .append (mm_input )
113+
114+ return full_mm_inputs
115+
116+
117+ class MMHasher :
118+
119+ def __init__ (self ):
120+ pass
121+
122+ def hash (self , mm_data : MultiModalDataDict ) -> List [str ]:
123+ image_inputs = mm_data ["image" ]
124+ if not isinstance (image_inputs , list ):
125+ image_inputs = [image_inputs ]
126+
127+ ret = []
128+ for image in image_inputs :
129+ assert isinstance (image , PIL .Image .Image )
130+
131+ # Convert image to bytes
132+ bytes = image .tobytes ()
133+
134+ # Hash image bytes
135+ hasher = blake3 ()
136+ hasher .update (bytes )
137+ ret .append (hasher .hexdigest ())
138+
139+ return ret
0 commit comments