[Feature] Support llava onevision #2783

Open · wants to merge 10 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -152,6 +152,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
<td>
<ul>
<li>LLaVA(1.5,1.6) (7B-34B)</li>
<li>LLaVA-OneVision (0.5B, 7B, 72B)</li>
<li>InternLM-XComposer2 (7B, 4khd-7B)</li>
<li>InternLM-XComposer2.5 (7B)</li>
<li>Qwen-VL (7B)</li>
2 changes: 2 additions & 0 deletions README_ja.md
@@ -149,9 +149,11 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
<td>
<ul>
<li>LLaVA(1.5,1.6) (7B-34B)</li>
<li>LLaVA-OneVision (0.5B, 7B, 72B)</li>
<li>InternLM-XComposer2 (7B, 4khd-7B)</li>
<li>InternLM-XComposer2.5 (7B)</li>
<li>Qwen-VL (7B)</li>
<li>Qwen2-VL (2B, 7B, 72B)</li>
<li>DeepSeek-VL (7B)</li>
<li>InternVL-Chat (v1.1-v1.5)</li>
<li>InternVL2 (1B-76B)</li>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -153,6 +153,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
<td>
<ul>
<li>LLaVA(1.5,1.6) (7B-34B)</li>
<li>LLaVA-OneVision (0.5B, 7B, 72B)</li>
<li>InternLM-XComposer2 (7B, 4khd-7B)</li>
<li>InternLM-XComposer2.5 (7B)</li>
<li>Qwen-VL (7B)</li>
17 changes: 9 additions & 8 deletions docs/en/multi_modal/llava.md
@@ -2,14 +2,15 @@

LMDeploy supports the following llava series of models, which are detailed in the table below:

| Model | Size | Supported Inference Engine |
| :----------------------------------: | :--: | :------------------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
| Model | Size | Supported Inference Engine |
| :----------------------------------: | :---------: | :------------------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
| lmms-lab/llava-onevision-qwen2-7b-ov | 0.5B,7B,72B | TurboMind |

The next chapter demonstrates how to deploy an Llava model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example.
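(Aside, not part of this diff: a minimal sketch of what that deployment looks like with LMDeploy's VLM pipeline, assuming the standard `pipeline`/`load_image` entry points; the model id and image URL below are only placeholders.)

```python
# Minimal sketch (not part of the diff): serving one of the models listed above
# with LMDeploy's VLM pipeline. Model id and image URL are placeholders.
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf')
image = load_image('https://example.com/some_image.jpg')  # any reachable image URL
response = pipe(('describe this image', image))
print(response.text)
```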

17 changes: 9 additions & 8 deletions docs/zh_cn/multi_modal/llava.md
@@ -2,14 +2,15 @@

LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示:

| 模型 | 大小 | 支持的推理引擎 |
| :----------------------------------: | :--: | :----------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
| 模型 | 大小 | 支持的推理引擎 |
| :----------------------------------: | :---------: | :----------------: |
| llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch |
| llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch |
| llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch |
| liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind |
| liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind |
| lmms-lab/llava-onevision-qwen2-7b-ov | 0.5B,7B,72B | TurboMind |

接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。

12 changes: 6 additions & 6 deletions lmdeploy/archs.py
@@ -117,12 +117,12 @@ def check_vl_llm(config: dict) -> bool:
arch = config['architectures'][0]
supported_archs = set([
'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM',
'CogVLMForCausalLM', 'InternLMXComposer2ForCausalLM',
'InternVLChatModel', 'MiniGeminiLlamaForCausalLM',
'MGMLlamaForCausalLM', 'MiniCPMV', 'LlavaForConditionalGeneration',
'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM',
'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration',
'MolmoForCausalLM'
'LlavaQwenForCausalLM', 'CogVLMForCausalLM',
'InternLMXComposer2ForCausalLM', 'InternVLChatModel',
'MiniGeminiLlamaForCausalLM', 'MGMLlamaForCausalLM', 'MiniCPMV',
'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration',
'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration',
'MllamaForConditionalGeneration', 'MolmoForCausalLM'
])
if arch == 'QWenLMHeadModel' and 'visual' in config:
return True
6 changes: 5 additions & 1 deletion lmdeploy/turbomind/supported_models.py
@@ -26,6 +26,7 @@
# llava
LlavaLlamaForCausalLM='llama',
LlavaMistralForCausalLM='llama',
LlavaQwenForCausalLM='qwen2',
LlavaForConditionalGeneration='llava',
# xcomposer2
InternLMXComposer2ForCausalLM='xcomposer2',
@@ -95,7 +96,10 @@ def _is_head_dim_supported(cfg):
if num_attn_head == 40:
# baichuan-13B, baichuan2-13B not supported by turbomind
support_by_turbomind = False
elif arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
elif arch in [
'Qwen2ForCausalLM', 'LlamaForCausalLM',
'LlavaQwenForCausalLM'
]:
support_by_turbomind = _is_head_dim_supported(cfg)
elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'):
# chatglm1/2/3 is not working yet
85 changes: 69 additions & 16 deletions lmdeploy/vl/model/llava.py
@@ -2,12 +2,14 @@
# Modified from https://github.com/haotian-liu/LLaVA.git
import ast
import math
import re
import warnings
from contextlib import contextmanager
from typing import Dict, List

import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
@@ -24,7 +26,7 @@ def check_llava_install():
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
'`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`' # noqa: E501
'pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git --no-deps' # noqa: E501
)


@@ -159,11 +161,19 @@ def process_anyres_image(image, processor, grid_pinpoints):
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)

patches = divide_to_patches(image_padded, processor.crop_size['height'])

image_original_resize = image.resize(
(processor.size['shortest_edge'], processor.size['shortest_edge']))
if processor.__class__.__name__ == 'SiglipImageProcessor':
patches = divide_to_patches(image_padded, processor.size['height'])
image_original_resize = image.resize(
(processor.size['height'], processor.size['width']))
elif processor.__class__.__name__ == 'CLIPImageProcessor':
patches = divide_to_patches(image_padded,
processor.crop_size['height'])

image_original_resize = image.resize(
(processor.size['shortest_edge'], processor.size['shortest_edge']))
else:
raise NotImplementedError(
f'Not support image_processor:{processor.image_processor_type}')

image_patches = [image_original_resize] + patches
image_patches = [
@@ -198,7 +208,7 @@ def process_images(images, image_processor, model_cfg):
image = image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
elif image_aspect_ratio == 'anyres':
elif image_aspect_ratio == 'anyres' or 'anyres_max' in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor,
model_cfg.image_grid_pinpoints)
@@ -218,7 +228,10 @@ class LlavaVisionModel(LlavaHfVisionModel):
def match(cls, config: AutoConfig):
"""check whether the config match the model."""
arch = config.architectures[0]
if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:
if arch in [
'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM',
'LlavaQwenForCausalLM'
]:
# internvl-llava has vision_tower of OpenGVLab/xxx
mm_vision_tower = getattr(config, 'mm_vision_tower', '')
# yi-vl has projector type of xxx_Norm
@@ -231,8 +244,8 @@ def match(cls, config: AutoConfig):
return False

def build_preprocessor(self):
from transformers import CLIPImageProcessor
self.image_processor = CLIPImageProcessor.from_pretrained(
from transformers import AutoImageProcessor
self.image_processor = AutoImageProcessor.from_pretrained(
self.hf_config.mm_vision_tower)
config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower)
image_size = config.vision_config.image_size
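(Aside, not part of the diff: the switch to `AutoImageProcessor` lets a SigLIP vision tower — the kind LLaVA-OneVision checkpoints point to via `mm_vision_tower` — resolve to its own processor class instead of being forced through `CLIPImageProcessor`. A quick check, with an example repo id:)

```python
# Quick check (example repo id; the returned class name may vary with the
# transformers version, e.g. a *Fast variant): AutoImageProcessor picks the
# processor class from the vision tower's own config rather than assuming CLIP.
from transformers import AutoImageProcessor

proc = AutoImageProcessor.from_pretrained('google/siglip-so400m-patch14-384')
print(type(proc).__name__)  # SiglipImageProcessor (or SiglipImageProcessorFast)
print(proc.size)            # {'height': 384, 'width': 384} -- no crop_size / shortest_edge
```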
@@ -258,10 +271,13 @@ def build_model(self):
from llava.model.language_model.llava_mistral import \
LlavaMistralConfig
self.config = LlavaMistralConfig.from_pretrained(self.model_path)
elif self.arch == 'LlavaQwenForCausalLM':
from llava.model.language_model.llava_qwen import LlavaQwenConfig
self.config = LlavaQwenConfig.from_pretrained(self.model_path)
else:
assert 0, f'unsupported arch {self.arch}'

from accelerate import init_empty_weights
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# init empty model, skip layer initialization
with init_empty_weights(), warnings.catch_warnings(), \
@@ -289,14 +305,19 @@ def build_model(self):
# for llava-v1.5, the vit is not in llm ckpt
vision_tower.to(dtype=torch.half)

from accelerate import load_checkpoint_and_dispatch
setattr(model.config, 'tie_word_embeddings', False)
no_split_module_classes = ['CLIPEncoderLayer', 'SiglipEncoderLayer']
same_device_keys = [('mm_projector', 'vision_resampler',
'image_newline', 'rotary_emb')]
device_map = self.get_vision_encoder_device_map(
model, self.max_memory, no_split_module_classes, same_device_keys)
with disable_logging():
load_checkpoint_and_dispatch(
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=['CLIPEncoderLayer'],
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)

self.model = model.model.eval()
@@ -374,7 +395,16 @@ def forward(self,
height = self.vision_tower.num_patches_per_side
width = self.vision_tower.num_patches_per_side
assert height * width == base_feat.shape[0]
if image_aspect_ratio == 'anyres':
# https://github.com/LLaVA-VL/LLaVA-NeXT/blob/79ef45a6d8b89b92d7a8525f077c3a3a9894a87d/llava/model/llava_arch.py#L357-L410
if 'anyres_max' in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r'anyres_max_(\d+)', image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(
matched_anyres_max_num_patches.group(
1))

if image_aspect_ratio == 'anyres' or 'anyres_max' in image_aspect_ratio: # noqa: E501
num_patch_width, num_patch_height = \
get_anyres_image_grid_shape(
image_sizes[img_idx],
Expand All @@ -385,7 +415,30 @@ def forward(self,
width, -1)
else:
raise NotImplementedError
if 'unpad' in mm_patch_merge_type:
if 'unpad' in mm_patch_merge_type \
and 'anyres_max' in image_aspect_ratio \
and matched_anyres_max_num_patches:
unit = feat.shape[2]
feat = feat.permute(4, 0, 2, 1, 3).contiguous()
feat = feat.flatten(1, 2).flatten(2, 3)
feat = unpad_image(feat, image_sizes[img_idx])
c, h, w = feat.shape
times = math.sqrt(h * w /
(max_num_patches * unit**2))
if times > 1.1:
feat = feat[None]
feat = nn.functional.interpolate(
feat,
[int(h // times),
int(w // times)],
mode='bilinear')[0]
feat = torch.cat(
(feat, self.model.
image_newline[:, None, None].expand(
*feat.shape[:-1], 1).to(feat.device)),
dim=-1)
feat = feat.flatten(1, 2).transpose(0, 1)
elif 'unpad' in mm_patch_merge_type:
feat = feat.permute(4, 0, 2, 1, 3).contiguous()
feat = feat.flatten(1, 2).flatten(2, 3)
feat = unpad_image(feat, image_sizes[img_idx])
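(Aside, not part of the diff: a standalone sketch, with invented sizes, of the decision the `anyres_max` branch above makes — parse the patch budget out of the aspect-ratio string, then bilinearly downscale the unpadded feature map only when it exceeds that budget by more than ~10%.)

```python
import math
import re

# Standalone sketch of the anyres_max handling above; all numbers are invented.
image_aspect_ratio = 'anyres_max_9'   # hypothetical config value
unit = 27                             # feature-map side length of one vision tile
h, w = 108, 135                       # unpadded feature-map height/width after tiling

match = re.match(r'anyres_max_(\d+)', image_aspect_ratio)
max_num_patches = int(match.group(1))            # -> 9

# same criterion as the branch above: shrink only if the map is >10% over budget
times = math.sqrt(h * w / (max_num_patches * unit ** 2))
if times > 1.1:
    new_h, new_w = int(h // times), int(w // times)
    print(f'interpolate {h}x{w} -> {new_h}x{new_w}')   # here: 108x135 -> 72x90
else:
    print('feature map kept at original resolution')
```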
39 changes: 35 additions & 4 deletions lmdeploy/vl/model/llava_hf.py
@@ -46,15 +46,16 @@ def build_model(self):

# fix for llava-hf/llava-interleave-qwen-7b-hf
setattr(model.config, 'tie_word_embeddings', False)
no_split_module_classes = ['CLIPEncoderLayer', 'SiglipEncoderLayer']
device_map = self.get_vision_encoder_device_map(
model, self.max_memory, no_split_module_classes)
with disable_logging():
load_checkpoint_and_dispatch(
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=[
'CLIPEncoderLayer', 'SiglipEncoderLayer'
],
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)
model.eval()
self.model = model
@@ -152,3 +153,33 @@ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
sequence_start)
return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
sequence_start)

@staticmethod
def get_vision_encoder_device_map(
model,
max_memory,
no_split_module_classes=['CLIPEncoderLayer', 'SiglipEncoderLayer'],
same_device_keys=None):
"""map vision_encoder to same device."""
from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
model,
max_memory=max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
model,
no_split_module_classes=no_split_module_classes,
max_memory=max_memory,
dtype=torch.half)

if not same_device_keys:
return device_map

for keys in same_device_keys:
fuzzy_keys = [kk for kk in device_map for k in keys if kk.find(k)]
Collaborator: fuzzy_keys = [kk for kk in device_map for k in keys if kk.find(k) != -1] ?

Collaborator: After changing it to fuzzy_keys = [kk for kk in device_map for k in keys if k in kk], it suffers RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:2! I will fix it later on.

if len(fuzzy_keys) <= 1:
continue
for k in fuzzy_keys[1:]:
device_map[k] = device_map[fuzzy_keys[0]]
return device_map
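(Aside on the review thread above, not part of the diff: a tiny standalone check of why the `str.find` filter misfires — `find()` returns an index, so a miss (-1) is truthy while a match at the start of a key (0) is falsy.)

```python
# Standalone illustration of the fuzzy_keys review comments; device_map is invented.
device_map = {
    'vision_tower.layers.0': 0,
    'language_model.layers.0': 1,
    'mm_projector': 1,
}
keys = ('mm_projector', 'image_newline')

# original filter: find() is -1 (truthy) on a miss and 0 (falsy) on a prefix match,
# so unrelated keys leak in and the intended match is skipped for its own substring
buggy = [kk for kk in device_map for k in keys if kk.find(k)]

# substring-membership filter suggested in the review
fixed = [kk for kk in device_map for k in keys if k in kk]

print(buggy)  # ['vision_tower.layers.0', 'vision_tower.layers.0',
              #  'language_model.layers.0', 'language_model.layers.0', 'mm_projector']
print(fixed)  # ['mm_projector']
```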
21 changes: 3 additions & 18 deletions lmdeploy/vl/model/llava_next.py
@@ -36,27 +36,12 @@ def build_model(self):
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
from accelerate import load_checkpoint_and_dispatch
from accelerate.utils import get_balanced_memory, infer_auto_device_map

no_split_module_classes = ['CLIPEncoderLayer']
max_memory = get_balanced_memory(
self.model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
self.model,
no_split_module_classes=no_split_module_classes,
max_memory=max_memory,
dtype=torch.half)

same_device_keys = [('multi_modal_projector', 'image_newline')]
for keys in same_device_keys:
keys = [k for k in keys if k in device_map]
if len(keys) <= 1:
continue
for k in keys[1:]:
device_map[k] = device_map[keys[0]]
device_map = self.get_vision_encoder_device_map(
self.model, self.max_memory, no_split_module_classes,
same_device_keys)

with disable_logging():
load_checkpoint_and_dispatch(