
[WIP][Model] Extend Ultravox to accept audio longer than 30s #13631

Draft · farzadab wants to merge 3 commits into main

Conversation

@farzadab (Contributor) commented Feb 20, 2025

Currently the Ultravox model input is capped at 30 seconds, and any extra audio is truncated (AFAIK). Also, each sample is fed to Whisper individually (without being batched).

This PR supports longer audio by first splitting it into chunks, running the Whisper encoder in batch mode, and then concatenating the results.
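
In rough pseudocode, the chunk-then-batch-then-concatenate flow looks like the sketch below. This is illustrative only, not the PR's code: the 16 kHz sample rate, the padding step, and the batched encoder call are assumptions (the actual implementation tracks audio_lens and applies an attention mask rather than relying on padding alone).

import torch
import torch.nn.functional as F

CHUNK_SECONDS = 30
SAMPLE_RATE = 16_000  # Whisper's expected sampling rate (assumed here)
CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE

def encode_long_audios(audios, encoder):
    # 1) Split every waveform into <=30s chunks; the last chunk may be shorter.
    chunks, num_chunks = [], []
    for audio in audios:
        pieces = list(torch.split(audio, CHUNK_SAMPLES))
        chunks.extend(pieces)
        num_chunks.append(len(pieces))
    # 2) Pad chunks to a common length so they form one batch, then run a
    #    single batched Whisper-encoder pass over all chunks of all audios.
    max_len = max(c.shape[0] for c in chunks)
    batch = torch.stack([F.pad(c, (0, max_len - c.shape[0])) for c in chunks])
    embeddings = encoder(batch)  # hypothetical call -> (total_chunks, frames, dim)
    # 3) Regroup chunk embeddings per original audio and concatenate over time.
    grouped = torch.split(embeddings, num_chunks, dim=0)
    return [g.flatten(0, 1) for g in grouped]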

TODO:

  • processors on HF still need to be updated in tandem with this PR.
  • run evaluations with the updated model to verify the changes.
  • add a test with long audio

@farzadab marked this pull request as draft February 20, 2025 21:29

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@@ -15,7 +15,7 @@
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close

-MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
+MODEL_NAME = "fixie-ai/ultravox-v0_3-llama-3_2-1b"
@farzadab (Contributor Author)

This is temporary and will be reverted.

Comment on lines +173 to +176
output = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
@farzadab (Contributor Author)

The processor has been updated to handle multiple audios. This is still a bit of a WIP, since 1) not all models are updated (only the v0_3 version of the 1B model is), and 2) during this PR I realized that the new processor will break this vLLM implementation, so I have to figure out what to do there.

Comment on lines 102 to 113
-feature_extractor = self.get_feature_extractor()
-max_audio_tokens = math.ceil(feature_extractor.chunk_length *
-                             _AUDIO_TOKENS_PER_SECOND)
-
-return {"audio": max_audio_tokens}
+return {}
@farzadab (Contributor Author)

I'm not sure if I'm doing the right thing here. I just want to say that there's no longer a limit on input audio length.

@mgoin (Member) commented Feb 21, 2025

FYI @NickLucche for the usage of Whisper.

@NickLucche (Contributor)

Thanks for the contrib!
What is the chunking logic for tiling the audio? Feel free to link the HF processor PR.

Comment on lines +198 to +205
# To handle longer-than-30s audio, each audio may be split into
# multiple chunks; as such, the batch dimension can be larger
# than the number of audio samples.
audio_features=MultiModalFieldConfig.batched("audio_chunked"),
audio_token_len=MultiModalFieldConfig.batched("audio_chunked"),
audio_lens=MultiModalFieldConfig.batched("audio_chunked"),
# audio_num_chunks maps the "audio_chunked" batch dimension back
# to the "audio" batch dimension.
audio_num_chunks=MultiModalFieldConfig.batched("audio"),
@farzadab (Contributor Author) commented Feb 21, 2025

I might be doing something wrong here; perhaps I should use .flat_from_sizes, but I'm not sure exactly how that works.

@DarkLight1337 (Member) commented Feb 22, 2025

You can use .flat_from_sizes in the processor to represent a variable batch size per audio (assuming that the processor concatenates the results across each audio). The processor cache should still work correctly as long as the processor output for a given audio input doesn't depend on other audio inputs.

Inside the model, you can torch.split them according to audio_num_chunks to recover the original batches.
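
A minimal sketch of that model-side regrouping (the function name and tensor shapes are assumed for illustration):

import torch

def regroup_by_audio(audio_features: torch.Tensor,
                     audio_num_chunks: torch.Tensor) -> list[torch.Tensor]:
    # audio_features: (total_chunks, tokens, dim); audio_num_chunks holds the
    # chunk count per original audio, so the split recovers the original batches.
    per_audio = torch.split(audio_features, audio_num_chunks.tolist(), dim=0)
    # Merge the chunk and token dimensions so each audio is one sequence again.
    return [chunks.flatten(0, 1) for chunks in per_audio]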

@farzadab (Contributor Author) commented Feb 21, 2025

re @NickLucche: Here's the processor link: https://huggingface.co/fixie-ai/ultravox-v0_3-llama-3_2-1b/blob/main/ultravox_processing.py#L209

The logic: each audio is split into 30-second chunks (but the last chunk is not padded to 30s, which is the same as before).
We then flatten and batch everything and run Whisper as if the chunks were separate audios. We use audio_lens to compute an attention_mask for the last chunk of each audio. The final embeddings are then concatenated.

There are other ways we could have done this, but it matches what we do on the Ultravox side for both our fine-tuning and our evals. If we end up updating those, I'll update vLLM as well.

Also, note that since we don't pad the last chunk, and since most audio is shorter than 30s, the number of frames does not match across samples. I didn't see a collator anywhere that I could update. I suspect I'll have to update _process_audio_input further to handle that. Edit: updated _process_audio_input.

@NickLucche (Contributor)

Ok, I see; that's naive chunking, then: it doesn't account for splitting mid-word, nor does it add any overlap or a prompt from the previous chunk.

This case seems much easier to handle on the vLLM side, given that the changes are already in HF. Let's just make sure the batched Whisper forward pass is accounted for by the initial profiler run to avoid OOM.
