
[Model][VLM] Add LLaVA-Onevision model support #8486

Open
litianjian wants to merge 11 commits into base: main

Conversation


@litianjian litianjian commented Sep 14, 2024

This PR adds support for the LLaVA-OneVision model.

FIX #7420

Requirements

This PR requires transformers with this PR merged (you can install it via pip install git+https://github.com/huggingface/transformers).

Example Usage

import av
import numpy as np
from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams

MODEL = "llava-hf/llava-onevision-qwen2-7b-ov-hf"

text_prompt = "<|im_start|>user <video>\nPlease provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.<|im_end|><|im_start|>assistant\n"

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 32).astype(int)
video = read_video_pyav(container, indices)

llm = LLM(model=MODEL, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.8,
                            top_p=0.95,
                            max_tokens=100)
outputs = llm.generate(
    {
        "prompt": text_prompt,
        "multi_modal_data": {
            "video": video
        }
    },
    sampling_params=sampling_params)

generated_text = ""
for o in outputs:
    generated_text += o.outputs[0].text
print(f"LLM output:{generated_text}")

Roadmap

  • Support the LLaVA-OneVision model with LlavaOnevisionForConditionalGeneration.
  • Support image and video inputs (see the image-inference sketch after this list).
  • Support image-video-mixed inputs.
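
A minimal sketch of single-image inference with the same model, for illustration only: the <image> placeholder and chat markers are assumed to match the HF chat template, and example.jpg is a placeholder path, not a file from this PR.

from PIL import Image
from vllm import LLM, SamplingParams

MODEL = "llava-hf/llava-onevision-qwen2-7b-ov-hf"

# Assumed prompt format: a single <image> placeholder inside the Qwen-style chat markers.
image_prompt = ("<|im_start|>user <image>\nWhat is shown in this image?"
                "<|im_end|><|im_start|>assistant\n")

image = Image.open("example.jpg").convert("RGB")  # any local test image

llm = LLM(model=MODEL, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)
outputs = llm.generate(
    {
        "prompt": image_prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)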

Notes

  • The LLaVA-OneVision repo supports more configs than HF; HF uses the default values in the released version for options such as spatial_pool_mode, spatial_pool_stride, and mm_newline_position. We will follow up with updates from Hugging Face.
  • The post_layer_norm issues in LLaVA-OneVision may be addressed in other PRs.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 DarkLight1337 changed the title from [Model][VLM] Add Qwen2-VL model support to [Model][VLM] Add LLaVA-OneVision model support Sep 14, 2024
@DarkLight1337 DarkLight1337 changed the title from [Model][VLM] Add LLaVA-OneVision model support to [Model][VLM] Add LLaVA-Onevision model support Sep 14, 2024
Member

@ywang96 ywang96 left a comment


@litianjian Thank you for making this contribution! I did a first pass so please take a look!

Comment on lines 190 to 199
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos > _MAX_NUM_VIDEOS:
    raise NotImplementedError(
        f"Only {_MAX_NUM_VIDEOS} videos are supported")

# TODO: support configuring the number of frames
frames_per_video = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(
    ctx, frames_per_video)
Member

This is a question for later if needed: now that I think about it, do you see any issue with basing this check on the total number of frames?

Author

In most VLMs, the maximum number of frames depends on the number of tokens per frame and the context length of the LLM.
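
For illustration only, a rough version of that budget; the per-frame token count and reserved text budget below are assumed numbers, not values taken from this PR:

# Assumed numbers for illustration; the real per-frame token count depends on
# the vision tower resolution and the spatial pooling settings.
tokens_per_frame = 196    # assumed pooled video tokens per frame
context_length = 32768    # Qwen2-7B context window
text_budget = 1024        # assumed tokens reserved for prompt + generation

max_frames = (context_length - text_budget) // tokens_per_frame
print(max_frames)  # 161 with these assumed numbers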

Member

That's right. I was just wondering whether it makes sense to do it by total_num_frames instead of num_videos and frames_per_video. This will indeed make the validation more complicated, and we can definitely discuss more later!

My understanding is that, in a multi-video setting, one should be able to send two videos of 5 and 8 frames or 4 and 9 frames, since to the sequence they're both equivalent to inference with 13 images. Is that correct?
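
A hypothetical sketch of what a total-frame-based check could look like; the names _MAX_TOTAL_FRAMES and validate_video_inputs are made up for illustration and are not part of this PR:

from typing import List

_MAX_TOTAL_FRAMES = 32  # assumed global frame budget across all videos

def validate_video_inputs(frames_per_video: List[int]) -> None:
    # Two videos of 5 and 8 frames or of 4 and 9 frames are treated the same,
    # since only the total number of frames fed to the sequence matters.
    total_num_frames = sum(frames_per_video)
    if total_num_frames > _MAX_TOTAL_FRAMES:
        raise NotImplementedError(
            f"At most {_MAX_TOTAL_FRAMES} total video frames are supported")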

@ywang96
Member

ywang96 commented Sep 16, 2024

@litianjian FYI I'm going to spend some time testing this PR this week. Overall it's in good shape, so we can probably get it merged by the end of this week or early next week!

@ywang96
Member

ywang96 commented Sep 19, 2024

Hey @litianjian! I have finished some testing on this PR. A few questions:

  1. When I remove the pytest.mark.skip decorator for the test and run it in my dev environment (H100), I get an error (RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same). I'm not sure why this happens when the video example works.
  2. I updated the example file so that the user can run image/video inference with the same model, but I don't seem to get image inference working. Could you take a look? (It would be great if you could also add an image test.)
  3. Do you plan to add image-video mixed input support in this PR, or in a later one? (Or is it supported by the model at all?)

@litianjian
Author

Thank you for your patience. I will solve problems 1 and 2 and update ASAP. The model doesn't support mixed image-video inputs for now; I can add that in another PR if support is added at some point in the future.

@litianjian
Author

1. The RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same issue is triggered by the HFRunner in conftest.
2. I have solved problem 2, and the example file can now run image/video inference with the same model correctly.
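
For readers hitting the same error elsewhere: it typically means float32 pixel values are being fed into a model loaded in float16, and the generic fix is to cast the inputs to the model's dtype before the forward pass. A minimal sketch (not the actual HFRunner code):

import torch

def cast_inputs_to_model_dtype(model: torch.nn.Module,
                               pixel_values: torch.Tensor) -> torch.Tensor:
    # Match the input dtype/device to the model weights so that
    # "Input type ... and weight type ... should be the same" cannot occur.
    param = next(model.parameters())
    return pixel_values.to(device=param.device, dtype=param.dtype)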

@ywang96
Member

ywang96 commented Sep 20, 2024

@litianjian Thanks! I can confirm the image example runs correctly too.

I just pushed a change 92c827f to clean up some code and add an image test in test_llava_onevision.py. For some reason there's an issue with the transformers SigLIPVisionModel that caused the dtype mismatch.

Given this is not an issue in vLLM, and this PR came directly from the model vendor, I'm okay with merging this PR as long as you think the model quality doesn't need to be checked for now.

@litianjian litianjian closed this Sep 20, 2024
@litianjian litianjian reopened this Sep 20, 2024
@litianjian
Author

Thank you for your patience. It's ok for me.

Development

Successfully merging this pull request may close these issues.

[New Model]: LLaVA-OneVision
3 participants