Commit 0c5363e

fix collation

Signed-off-by: Farzad Abdolhosseini <[email protected]>

1 parent c7e0329

1 file changed: +30 -8 lines changed

vllm/model_executor/models/ultravox.py

@@ -2,6 +2,7 @@
 
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
+import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
@@ -46,14 +47,14 @@
 
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    data: NestedTensors
     """Shape: `(batch_size, num_audios, 80, M)`"""
-    lens: torch.Tensor
+    lens: NestedTensors
     """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size)`
     """
-    token_len: torch.Tensor
+    token_len: NestedTensors
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size)`
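
For readers unfamiliar with the type: in vLLM, NestedTensors is (roughly) a union over a single tensor and lists of tensors, so after this change data may arrive either as one batched tensor or as a ragged per-request list. A minimal sketch of the two forms, with made-up shapes:

import torch

# The two forms the "data" field can now take (shapes are illustrative):
# 1) batched: the processor padded every audio to the same frame count M.
batched = torch.randn(4, 2, 80, 3000)   # [batch_size, num_audios, 80, M]
# 2) ragged: one tensor per request, each with its own frame count M_i.
ragged = [torch.randn(2, 80, 500), torch.randn(1, 80, 1200)]

# _process_audio_input (further down in this diff) branches on exactly
# this distinction.
print(isinstance(batched, list), isinstance(ragged, list))  # False True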
@@ -110,7 +111,11 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        return {}
+        feature_extractor = self.get_feature_extractor()
+        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
+                                     _AUDIO_TOKENS_PER_SECOND)
+
+        return {"audio": max_audio_tokens}
 
 
 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
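
As a sanity check on the new bound, here is the arithmetic spelled out. Both inputs are assumptions: Whisper feature extractors default to a 30-second chunk_length, and _AUDIO_TOKENS_PER_SECOND is 6.25 in this module at the time of writing (verify against the file):

import math

# Hedged sanity check; both constants below are assumptions, not values
# read from this diff.
chunk_length = 30                # seconds (WhisperFeatureExtractor default)
audio_tokens_per_second = 6.25   # assumed value of _AUDIO_TOKENS_PER_SECOND
max_audio_tokens = math.ceil(chunk_length * audio_tokens_per_second)
print(max_audio_tokens)          # 188 under these assumptions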
@@ -422,6 +427,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        # Due to the batching of audio chunks, the preprocessor cache cannot
+        # do the right thing, so disable it.
+        vllm_config.model_config.disable_mm_preprocessor_cache = True
     multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multi_modal_config = multimodal_config
@@ -516,10 +524,24 @@ def _process_audio_input(
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]
 
-        # remove unneeded extra dimension added to all elements of mm_kwargs
-        audio_features = flatten_bn(audio_input["data"])
-        audio_lens = flatten_bn(audio_input["lens"])
-        audio_token_len = flatten_bn(audio_input["token_len"])
+        audio_features = audio_input["data"]
+        if isinstance(audio_features, list):
+            max_len = max(x.shape[-1] for x in audio_features)
+            # Pad and concatenate:
+            # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+            audio_features = torch.cat(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
+        else:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            audio_features = flatten_bn(audio_features)
+
+        if isinstance(audio_input['lens'], list):
+            # [B1, B2] -> [B1+B2]
+            audio_lens = torch.cat(audio_input['lens'])
+            audio_token_len = torch.cat(audio_input['token_len'])
+        else:
+            audio_lens = flatten_bn(audio_input['lens'])
+            audio_token_len = flatten_bn(audio_input['token_len'])
 
         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
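
A self-contained sketch of the list branch above, with made-up shapes, showing how the ragged features collate into one batch:

import torch
import torch.nn.functional as F

# Two requests with different frame counts: B1=2 chunks of M1=500 frames,
# B2=3 chunks of M2=800 frames, both with 80 mel bins.
features = [torch.randn(2, 80, 500), torch.randn(3, 80, 800)]

# Right-pad each tensor's last (frame) dim to the longest, then stack the
# chunk dims: [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)].
max_len = max(x.shape[-1] for x in features)
padded = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in features])
print(padded.shape)  # torch.Size([5, 80, 800])

The per-audio lens carried alongside are what keep the WhisperEncoder attention mask from attending to the padded frames.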
