 from vllm.utils import is_openvino_optimum_intel
 
 import openvino as ov
+from openvino import Type
 
 
 def _flattenize_inputs(inputs):
@@ -56,7 +57,8 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:
 
 def patch_stateful_model(
         model: ov.Model,
-        factory):
+        factory,
+        kv_cache_dtype: Type):
     print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
     from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher, AnyInput, Or
     from openvino.runtime import opset13
@@ -128,8 +130,8 @@ def callback(m: Matcher) -> bool:
             real_v = mapping[v_current]
             hidden_shape = real_q.get_partial_shape()
             hidden_dim = hidden_shape[hidden_shape.rank.get_length() - 1].get_length()  # TODO: What if it is a dynamic? Need to insert a ShapeOf sub-graph instead
-            k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=np.float32)
-            v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=np.float32)
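+            # Create the KV-cache Parameters with the configured cache element type instead of hard-coded f32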
+            k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=kv_cache_dtype)
+            v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=kv_cache_dtype)
             kv_parameters.append(k_parameter)
             kv_parameters.append(v_parameter)
             # TODO: The rank 4 is used in the following code, but it is not guaranteed for all models, adopt to other ranks.
@@ -274,7 +276,8 @@ def has_parameter(model, name):
 
 def _patch_model_with_openvino(
         pt_model: torch.nn.Module,
-        model_config: ModelConfig):
+        model_config: ModelConfig,
+        kv_cache_dtype: Type):
     print(' ============= PATCHING MODEL =============')
     from vllm.model_executor.layers.attention.attention import Attention
     from openvino.frontend.pytorch import ModuleExtension
@@ -294,7 +297,15 @@ def _patch_model_with_openvino(
 
     # Prepare example inputs
 
-    kv_cache_dtype = torch.float32
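+    # Map the OpenVINO element type of the KV cache to the matching torch dtype
+    # so the example cache tensors below are built with the same precision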
+    torch_dtype_mapping = {
+        Type.boolean: torch.bool,
+        Type.f32: torch.float32,
+        Type.f16: torch.float16,
+        Type.bf16: torch.bfloat16,
+        Type.i32: torch.int32,
+        Type.i64: torch.int64
+    }
+    kv_cache_dtype = torch_dtype_mapping[kv_cache_dtype]
     num_heads = pt_model.config.num_attention_heads
     num_kv_heads = num_heads
     head_size = pt_model.config.hidden_size // num_kv_heads
@@ -423,6 +434,7 @@ def ov_sample(
 
 def get_model(model_config: ModelConfig,
               device_config: DeviceConfig,
+              kv_cache_dtype: Type,
               **kwargs) -> torch.nn.Module:
     lora_config = kwargs.get("lora_config", None)
     if lora_config:
@@ -443,7 +455,7 @@ def get_model(model_config: ModelConfig,
         # Keep factory to destroy it in a particular moment when all other objects referencing custom nodes are destoyed
         pt_model.ov_node_factory = NodeFactory()
         pt_model.ov_node_factory.add_extension('libuser_ov_extensions.so')
-        patch_stateful_model(pt_model.model, pt_model.ov_node_factory)
+        patch_stateful_model(pt_model.model, pt_model.ov_node_factory, kv_cache_dtype)
         core = ov.Core()
         ov_compiled = core.compile_model(pt_model.model, "CPU")
         pt_model._ov_request = ov_compiled.create_infer_request()
@@ -457,6 +469,6 @@ def get_model(model_config: ModelConfig,
     else:
         from vllm.model_executor.model_loader import get_model
         pt_model = get_model(model_config, device_config, **kwargs)
-        _patch_model_with_openvino(pt_model, model_config)
+        _patch_model_with_openvino(pt_model, model_config, kv_cache_dtype)
 
     return pt_model