
Commit b21585a

petersalas authored and JC1DA committed
[Model][OpenVINO] Fix regressions from vllm-project#8346 (vllm-project#10045)
Signed-off-by: Peter Salas <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
1 parent 2f5095b commit b21585a

3 files changed: 15 additions, 5 deletions

.buildkite/run-openvino-test.sh

Lines changed: 1 addition & 1 deletion
@@ -11,4 +11,4 @@ trap remove_docker_container EXIT
 remove_docker_container
 
 # Run the image and launch offline inference
-docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py
+docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py

vllm/attention/backends/openvino.py

Lines changed: 11 additions & 1 deletion
@@ -1,12 +1,13 @@
 from dataclasses import dataclass
-from typing import List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import openvino as ov
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.multimodal import MultiModalPlaceholderMap
 
 
 def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
@@ -128,3 +129,12 @@ class OpenVINOAttentionMetadata:
     # Shape: scalar
     # Type: i32
     max_context_len: torch.Tensor
+
+    # The index maps that relate multi-modal embeddings to the corresponding
+    # placeholders.
+    #
+    # N.B. These aren't really related to attention and don't belong on this
+    # type -- this is just a temporary solution to make them available to
+    # `model_executable`.
+    multi_modal_placeholder_index_maps: Optional[Dict[
+        str, MultiModalPlaceholderMap.IndexMap]]
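
For orientation, here is a minimal standalone sketch of the shape of the new field. It deliberately does not import vllm: `IndexMap` below is a hypothetical stand-in for `MultiModalPlaceholderMap.IndexMap`, and its `src`/`dest` fields plus the example positions are assumptions for illustration, not taken from this diff.

```python
# Minimal sketch of the new `multi_modal_placeholder_index_maps` field on
# OpenVINOAttentionMetadata. `IndexMap` is a hypothetical stand-in for
# MultiModalPlaceholderMap.IndexMap; the real type may differ.
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional


class IndexMap(NamedTuple):
    src: List[int]   # rows in the flattened multi-modal embedding tensor
    dest: List[int]  # placeholder-token positions in the input sequence


@dataclass
class MetadataSketch:
    # Keyed by modality ("image", "audio", ...); None when the batch has no
    # multi-modal inputs.
    multi_modal_placeholder_index_maps: Optional[Dict[str, IndexMap]]


# Example: a prompt whose image placeholder occupies token positions 5..7.
meta = MetadataSketch(
    multi_modal_placeholder_index_maps={
        "image": IndexMap(src=[0, 1, 2], dest=[5, 6, 7]),
    },
)

# A consumer (e.g. model_executable, per the diff's comment) could scatter
# embedding rows into the hidden states at the placeholder positions:
for modality, index_map in (meta.multi_modal_placeholder_index_maps or {}).items():
    for src_row, dest_pos in zip(index_map.src, index_map.dest):
        print(f"{modality}: embedding row {src_row} -> token position {dest_pos}")
```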

vllm/model_executor/models/molmo.py

Lines changed: 3 additions & 3 deletions
@@ -21,8 +21,8 @@
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -915,7 +915,7 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
     if "image_masks" in out:
         dummy_imgdata["image_masks"] = out["image_masks"]
     dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
-    return dummy_seqdata, {"image": dummy_imgdata}
+    return DummyData(dummy_seqdata, {"image": dummy_imgdata})
 
 
 def pad_images(
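
To illustrate why this return-type change stays compatible with callers that still unpack a `(seq_data, multi_modal_data)` pair, here is a standalone sketch. The `DummyData` stand-in below assumes the real class in vllm.inputs is a NamedTuple over those fields, and the placeholder objects are hypothetical.

```python
# Standalone sketch of the tuple -> DummyData change; does not import vllm.
# This stand-in assumes vllm.inputs.DummyData is a NamedTuple over
# (seq_data, multi_modal_data); the exact field set is an assumption.
from typing import Any, Dict, NamedTuple, Optional


class DummyData(NamedTuple):
    seq_data: Any
    multi_modal_data: Optional[Dict[str, Any]] = None


dummy_seqdata = object()         # placeholder for a SequenceData instance
dummy_imgdata = {"seq_len": 42}  # placeholder for the dummy image tensors

# New style, as in the diff:
result = DummyData(dummy_seqdata, {"image": dummy_imgdata})

# A NamedTuple still unpacks like the old bare tuple, so pre-existing callers
# written as `seq_data, mm_data = dummy_data_for_molmo(...)` keep working:
seq_data, mm_data = result
assert seq_data is dummy_seqdata
assert mm_data == {"image": dummy_imgdata}
```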
