 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
+import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)

 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    data: NestedTensors
     """Shape: `(batch_size, num_audios, 80, M)`"""
-    lens: torch.Tensor
+    lens: NestedTensors
     """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size)`
     """
-    token_len: torch.Tensor
+    token_len: NestedTensors
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size)`
@@ -110,7 +111,11 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        return {}
+        feature_extractor = self.get_feature_extractor()
+        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
+                                     _AUDIO_TOKENS_PER_SECOND)
+
+        return {"audio": max_audio_tokens}


 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
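For a quick sanity check of the worst-case token budget computed above: a minimal sketch, assuming Whisper's default 30-second `chunk_length` and using 6.25 tokens per second as a stand-in for `_AUDIO_TOKENS_PER_SECOND` (illustrative values only, not taken from this diff):

```python
import math

# Illustrative assumptions, not values from the diff:
chunk_length = 30          # seconds; WhisperFeatureExtractor default chunk size
tokens_per_second = 6.25   # stand-in for _AUDIO_TOKENS_PER_SECOND

# Worst-case number of audio tokens a single chunked audio item can contribute.
max_audio_tokens = math.ceil(chunk_length * tokens_per_second)
print(max_audio_tokens)  # 188
```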
@@ -422,6 +427,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): |
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        # Due to the batching of audio chunks, the preprocessor cache cannot
+        # do the right thing so disable it.
+        vllm_config.model_config.disable_mm_preprocessor_cache = True
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multi_modal_config = multimodal_config
@@ -516,10 +524,24 @@ def _process_audio_input( |
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

-        # remove unneeded extra dimension added to all elements of mm_kwargs
-        audio_features = flatten_bn(audio_input["data"])
-        audio_lens = flatten_bn(audio_input["lens"])
-        audio_token_len = flatten_bn(audio_input["token_len"])
+        audio_features = audio_input["data"]
+        if isinstance(audio_features, list):
+            max_len = max(x.shape[-1] for x in audio_features)
+            # Pad and concatenate:
+            # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+            audio_features = torch.cat(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
+        else:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            audio_features = flatten_bn(audio_features)
+
+        if isinstance(audio_input['lens'], list):
+            # [B1, B2] -> [B1+B2]
+            audio_lens = torch.cat(audio_input['lens'])
+            audio_token_len = torch.cat(audio_input['token_len'])
+        else:
+            audio_lens = flatten_bn(audio_input['lens'])
+            audio_token_len = flatten_bn(audio_input['token_len'])

         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
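To illustrate the pad-and-concatenate branch above in isolation (the list case arises when requests contribute differing numbers of audio chunks with different mel lengths), here is a minimal sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

# Two requests with different chunk counts (B1=2, B2=1) and mel lengths (M1=100, M2=160).
features = [torch.randn(2, 80, 100), torch.randn(1, 80, 160)]

max_len = max(x.shape[-1] for x in features)
# Right-pad each tensor along the last (time) dimension, then concatenate along
# the batch dimension: [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)].
padded = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in features])
assert padded.shape == (3, 80, 160)
```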