Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cc2df56
update ultravox to accept more than 30s audio
farzadab Feb 20, 2025
c7e0329
temporarily use model with updated processor for tests
farzadab Feb 20, 2025
0c5363e
fix collation
farzadab Feb 21, 2025
189f5cc
revert audio_replacement -> audio_token_replacement
farzadab Feb 25, 2025
bc3ba8c
increase max mm tokens
farzadab Feb 25, 2025
618e752
increase max mm tokens
farzadab Feb 25, 2025
0e62945
reduce max mm tokens
farzadab Feb 25, 2025
69278e2
revert increasing max mm tokens
farzadab Feb 25, 2025
788fd59
Merge remote-tracking branch 'upstream/main' into farzad-long-audio
farzadab Feb 25, 2025
75c138b
fix <|begin_of_text|> not being included
farzadab Feb 25, 2025
3b0e237
batching for whisper to avoid oom
farzadab Feb 26, 2025
97f6f5b
add comment
farzadab Feb 26, 2025
bea5a31
use flat_from_sizes for ultravox mm_fields_config
farzadab Feb 26, 2025
28f16ce
revert ultravox test model id
farzadab Feb 26, 2025
48c359b
improve documentation for double bos_id case
farzadab Feb 26, 2025
e920ab9
Merge remote-tracking branch 'upstream/main' into farzad-long-audio
farzadab Feb 26, 2025
4a54ea1
do not use vocab in get_hf_processor
farzadab Feb 26, 2025
347ada8
revert tests to use v0_5
farzadab Feb 26, 2025
e829dac
Merge remote-tracking branch 'upstream/main' into farzad-long-audio
farzadab Mar 1, 2025
b04878e
revert tests to use v0_5
farzadab Mar 1, 2025
631487f
adding tests for both ultravox v0.4 and v0.5
farzadab Mar 1, 2025
a9828ea
handle audio_num_chunks when no audio is passed
farzadab Mar 3, 2025
33a9cf0
drop test for ultravox v0_4
farzadab Mar 3, 2025
7ca61cf
drop matching Ultravox audio_features with cache
farzadab Mar 4, 2025
48f7da3
ignore exact match for audio_features in _items_by_modality
farzadab Mar 5, 2025
2813a47
fix type hint
farzadab Mar 5, 2025
66c10e4
Merge remote-tracking branch 'vllm-base/main' into farzad-long-audio
farzadab Mar 5, 2025
11ff27f
debug logs for ci
farzadab Mar 10, 2025
2776a31
if all else fails just stack?
farzadab Mar 10, 2025
3cb3583
recursive pad_and_concat
farzadab Mar 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close

MODEL_NAME = "fixie-ai/ultravox-v0_4"
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

AudioTuple = Tuple[np.ndarray, int]

Expand Down
2 changes: 1 addition & 1 deletion tests/models/multimodal/processing/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_4",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
Expand Down
3 changes: 1 addition & 2 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,7 @@ def check_available_online(
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4",
extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
Expand Down
226 changes: 157 additions & 69 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,23 @@
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
_MAX_ENCODER_BATCH_SIZE = 16


class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)`"""
"""Shape: `(batch_size, num_chunks, 80, M)`"""
lens: NestedTensors
"""
Length of the audio frames. Used for attention mask in WhisperEncoder.
Shape: `(batch_size, num_chunks)`
"""
token_len: NestedTensors
"""
Length of the audio tokens. Used for flattening the audio features.
Shape: `(batch_size, num_chunks)`
"""


class UltravoxAudioEmbeddingInputs(TypedDict):
Expand Down Expand Up @@ -78,6 +89,7 @@ def get_hf_processor(
# token, thus we override placeholder with a reserved special
# token.
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
return hf_processor

def get_feature_extractor(
Expand All @@ -104,7 +116,7 @@ def get_mm_max_tokens_per_item(
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)

return {"audio": max_audio_tokens}
return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE}


class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
Expand All @@ -118,7 +130,8 @@ def get_dummy_processor_inputs(
feature_extractor = self.info.get_feature_extractor()

sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
audio_len = (feature_extractor.chunk_length * sampling_rate *
_MAX_ENCODER_BATCH_SIZE)
num_audios = mm_counts.get("audio", 0)

mm_data = {
Expand Down Expand Up @@ -160,41 +173,38 @@ def _call_hf_processor(
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
include_audio_num_chunks=True,
)

# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
audio_features, audio_token_len = [], []
shared_outputs = {}
for audio in audios:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
item_processor_data = dict(**mm_data, audio=audio)

item_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)

audio_features.append(item_outputs.pop("audio_values")[0])
audio_token_len.append(item_outputs.pop("audio_token_len").item())
shared_outputs = item_outputs
item_processor_data = dict(**mm_data, audios=audios)

combined_outputs = dict(
**shared_outputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
output = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
return BatchFeature(combined_outputs)
output['audio_features'] = output.pop('audio_values')

return output

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
return dict(
audio_features=MultiModalFieldConfig.batched("audio"),
audio_token_len=MultiModalFieldConfig.batched("audio"),
# to handle longer than 30s audio, each audio might be split
# into multiple chunks as such, their batch dimension can be
# higher than the number of audio samples
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", num_chunks),
audio_token_len=MultiModalFieldConfig.flat_from_sizes(
"audio", num_chunks),
audio_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", num_chunks),
# num_chunks can convert audio_chunked to audio batch dimension
audio_num_chunks=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)

Expand All @@ -205,14 +215,22 @@ def _get_prompt_updates(
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()

replacement_id = vocab[
hf_processor.audio_token_replacement] # type: ignore
replacement_id = hf_processor.audio_replacement_token_id # type: ignore

# Each audio can be split into multiple chunks.
# chunks_start_idx[i] indicates the start index of the chunks
# belonging to the i-th audio.
chunks_start_idx: torch.Tensor = torch.cumsum(
out_mm_kwargs["audio_num_chunks"], dim=0, dtype=torch.int32)
chunks_start_idx = torch.cat(
[torch.tensor([0], dtype=torch.int32), chunks_start_idx])
out_mm_kwargs.pop("audio_num_chunks", None)

def get_replacement_ultravox(item_idx: int):
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
start = chunks_start_idx[item_idx]
end = chunks_start_idx[item_idx + 1]
audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum()
return [replacement_id] * int(audio_token_len) # type: ignore

return [
Expand Down Expand Up @@ -304,12 +322,49 @@ class ModifiedWhisperEncoder(WhisperEncoder):

base_model_prefix = "model.encoder"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config.is_decoder = False

@property
def max_context_length(self):
return (self.config.max_source_positions * self.conv1.stride[0] *
self.conv2.stride[0])

def get_attention_mask_by_audio_len(self,
audio_lens: Optional[torch.Tensor],
hidden_states: torch.Tensor):
"""
Create attention mask based on audio lengths to mask out padding tokens
For each sample in batch:
- Convert raw audio length to feature length after convolutions
- Create bool mask: True for valid positions and False for padding
- Convert to attention mask format expected by transformer layers
(1.0 for positions to attend to, large negative for positions to ignore)
This masking ensures consistent behavior between training and inference
by preventing the model from attending to padding tokens in both cases
"""
if audio_lens is None:
return None

audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
max_seq_len = hidden_states.shape[1]
attention_mask = torch.arange(max_seq_len,
device=hidden_states.device)[None, :].lt(
audio_feature_len.view(-1, 1))
attention_mask = self.get_extended_attention_mask(
attention_mask,
None,
dtype=hidden_states.dtype,
)
return attention_mask

def forward(
self,
input_features,
input_features: torch.Tensor,
audio_lens: Optional[torch.Tensor] = None,
):
expected_seq_length = (self.config.max_source_positions *
self.conv1.stride[0] * self.conv2.stride[0])
expected_seq_length = self.max_context_length
if input_features.shape[-1] > expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length "
Expand All @@ -328,10 +383,13 @@ def forward(
p=self.dropout,
training=self.training)

attention_mask = self.get_attention_mask_by_audio_len(
audio_lens, hidden_states)

for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
None,
attention_mask,
layer_head_mask=None,
)

Expand Down Expand Up @@ -409,17 +467,34 @@ def get_mm_mapping(self) -> MultiModelKeys:
)

def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
audio_features = self.audio_tower(audio_input)
audio_features = audio_features.to(self.audio_tower.dtype)
audio_embeddings = self.multi_modal_projector(audio_features)
self, input_features: torch.Tensor,
audio_lens: torch.Tensor) -> torch.Tensor:
audio_features = input_features.to(self.audio_tower.dtype)
batch_size = audio_features.size(0)
audio_embeddings = []

# Process audio features in batches to keep memory usage predictable
for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
# Process through audio tower
batch_features = self.audio_tower(audio_features[start:end],
audio_lens[start:end])
batch_features = batch_features.to(self.audio_tower.dtype)

# Process through projector
batch_embeddings = self.multi_modal_projector(batch_features)
audio_embeddings.append(batch_embeddings)

# Concatenate results
audio_embeddings = torch.cat(audio_embeddings, dim=0)
return audio_embeddings

def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
audio_lens = kwargs.pop("audio_lens", None)
audio_token_len = kwargs.pop("audio_token_len", None)

if audio_features is None and audio_embeds is None:
return None
Expand All @@ -430,7 +505,9 @@ def _parse_and_validate_audio_input(
f"Got type: {type(audio_features)}")

return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features)
data=audio_features,
lens=audio_lens,
token_len=audio_token_len)

if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
Expand All @@ -448,33 +525,40 @@ def _process_audio_input(
return audio_input["data"]

audio_features = audio_input["data"]
if isinstance(audio_features, torch.Tensor):
# Combine the B and N dimensions for the encoder/projector
flattened = flatten_bn(audio_features)
flattened_embeddings = self._audio_features_to_embeddings(
flattened)

# Restore the original dimensions
embeddings = flattened_embeddings.unflatten(
0, audio_features.shape[:2])
return embeddings

result = []
# TODO: Batch heterogeneous tensors through the encoder/projector
for audio_features_item in audio_features:
if isinstance(audio_features_item, torch.Tensor):
result.append(
self._audio_features_to_embeddings(audio_features_item))
else:
embeddings = [
# Add a batch dimension to embed it, then remove it.
self._audio_features_to_embeddings(tensor.unsqueeze(0)
).squeeze(0)
for tensor in audio_features_item
]
result.append(embeddings)
if isinstance(audio_features, list):
max_len = max(x.shape[-1] for x in audio_features)
# Pad and concatenate:
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
audio_features = torch.cat(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
else:
# Flatten [B, N, 80, M] -> [B * N, 80, M]
audio_features = flatten_bn(audio_features)

return result
if isinstance(audio_input['lens'], list):
# [B1, B2] -> [B1+B2]
audio_lens = torch.cat(audio_input['lens'])
audio_token_len = torch.cat(audio_input['token_len'])
else:
audio_lens = flatten_bn(audio_input['lens'])
audio_token_len = flatten_bn(audio_input['token_len'])

embeddings = self._audio_features_to_embeddings(
audio_features, audio_lens)

# We should flatten and concatenate embeddings based on token lengths
# For example, with token_len = [4, 2, 3], flattened_embeddings will be
# concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

# Create a mask of valid indices based on token lengths
max_len = embeddings.shape[1]
indices = torch.arange(max_len, device=embeddings.device).expand(
embeddings.shape[0], -1)
mask = indices < audio_token_len[:, None]
# Apply mask and flatten
flattened_embeddings = embeddings[mask]

return flattened_embeddings

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
Expand Down Expand Up @@ -519,7 +603,11 @@ def forward(self,
with the `input_ids`.

Args:
audio_features: A batch of audio inputs [B, N, 80, M].
audio_features: A batch of audio input chunks [B, N, 80, M].
audio_lens: Length of audio frames for each audio chunk [B].
audio_token_len: Length of audio tokens for each audio chunk [B'].
Note: batch dim is different from batch dim in audio chunks.

"""

if intermediate_tensors is not None:
Expand Down