diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 102842b0a188d..7f1b2443824a2 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -3,7 +3,7 @@ Supported Models ================ -vLLM supports a variety of generative Transformer models in `HuggingFace Transformers `_. +vLLM supports a variety of generative Transformer models in `HuggingFace (HF) Transformers `_. The following is the list of model architectures that are currently supported by vLLM. Alongside each architecture, we include some popular models that use it. @@ -19,7 +19,7 @@ Text Generation * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`AquilaForCausalLM` @@ -280,7 +280,7 @@ Text Embedding * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Gemma2Model` @@ -303,7 +303,7 @@ Reward Modeling * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Qwen2ForRewardModel` @@ -316,7 +316,14 @@ Reward Modeling As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. Multimodal Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following modalities are supported depending on the model: + +- **T**\ ext +- **I**\ mage +- **V**\ ideo +- **A**\ udio .. _supported_vlms: @@ -324,78 +331,78 @@ Text Generation --------------- .. list-table:: - :widths: 25 25 25 25 5 5 + :widths: 25 25 15 25 5 5 :header-rows: 1 * - Architecture - Models - - Modalities - - Example HuggingFace Models + - Inputs + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - - Image\ :sup:`E` + - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ * - :code:`ChameleonForConditionalGeneration` - Chameleon - - Image + - T + I - :code:`facebook/chameleon-7b` etc. - - ✅︎ * - :code:`FuyuForCausalLM` - Fuyu - - Image + - T + I - :code:`adept/fuyu-8b` etc. - - ✅︎ * - :code:`ChatGLMModel` - GLM-4V - - Image + - T + I - :code:`THUDM/glm-4v-9b` etc. - - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - - Video + - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - ✅︎ * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - - Image\ :sup:`+` / Video + - T + I\ :sup:`+` + V - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - ✅︎ * - :code:`MiniCPMV` - MiniCPM-V - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. 
- ✅︎ - ✅︎ * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - - Image + - T + I - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - @@ -407,43 +414,43 @@ Text Generation - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. - - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - - Image\ :sup:`E` + - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - ✅︎ * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - - Image\ :sup:`+` + - T + I\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409` - - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - - ✅︎ * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - - Image\ :sup:`E+` / Video\ :sup:`+` + - T + I\ :sup:`E+` + V\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - - ✅︎ * - :code:`UltravoxModel` - Ultravox - - Audio\ :sup:`E+` + - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - ✅︎ @@ -455,6 +462,26 @@ Text Generation For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 +Multimodal Embedding +-------------------- + +.. list-table:: + :widths: 25 25 15 25 5 5 + :header-rows: 1 + + * - Architecture + - Models + - Inputs + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`Phi3VForCausalLM` + - Phi-3-Vision-based + - T + I + - :code:`TIGER-Lab/VLM2Vec-Full` + - 🚧 + - ✅︎ + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py new file mode 100644 index 0000000000000..8e62199e1db7b --- /dev/null +++ b/examples/offline_inference_vision_language_embedding.py @@ -0,0 +1,21 @@ +from vllm import LLM +from vllm.assets.image import ImageAsset + +image = ImageAsset("cherry_blossom").pil_image.convert("RGB") +prompt = "<|image_1|> Represent the given image with the following question: What is in the image" # noqa: E501 + +# Create an LLM. +llm = LLM( + model="TIGER-Lab/VLM2Vec-Full", + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={"num_crops": 16}, +) + +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}}) + +# Print the outputs. 
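+# Each input yields one EmbeddingRequestOutput; output.outputs.embedding is the
+# pooled, L2-normalized hidden state of the last token (3072 floats for this
+# Phi-3-Vision-based model).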
+for output in outputs: + print(output.outputs.embedding) # list of 3072 floats diff --git a/tests/conftest.py b/tests/conftest.py index baa6bae03a451..5df7da9ee64e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -262,7 +262,7 @@ def __init__( dtype: str = "half", *, model_kwargs: Optional[Dict[str, Any]] = None, - is_embedding_model: bool = False, + is_sentence_transformer: bool = False, auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, @@ -271,7 +271,7 @@ def __init__( self.model_name = model_name - if is_embedding_model: + if is_sentence_transformer: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer self.model = self.wrap_device( @@ -307,17 +307,23 @@ def __init__( self.postprocess_inputs = postprocess_inputs - def generate( + def get_inputs( self, prompts: List[str], images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, - **kwargs: Any, - ) -> List[Tuple[List[List[int]], List[str]]]: - if images: + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[BatchEncoding]: + if images is not None: assert len(prompts) == len(images) - outputs: List[Tuple[List[List[int]], List[str]]] = [] + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + + all_inputs: List[BatchEncoding] = [] for i, prompt in enumerate(prompts): processor_kwargs: Dict[str, Any] = { "text": prompt, @@ -327,10 +333,33 @@ def generate( processor_kwargs["images"] = images[i] if videos is not None and videos[i] is not None: processor_kwargs["videos"] = videos[i] + if audios is not None and audios[i] is not None: + audio, sr = audios[i] + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr inputs = self.processor(**processor_kwargs) inputs = self.postprocess_inputs(inputs) + all_inputs.append(inputs) + + return all_inputs + + def generate( + self, + prompts: List[str], + images: Optional[PromptImageInput] = None, + videos: Optional[List[np.ndarray]] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for inputs in all_inputs: output_ids = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -350,12 +379,16 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, + videos: Optional[List[np.ndarray]] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, images=images, + videos=videos, + audios=audios, **kwargs) return [(output_ids[0], output_str[0]) @@ -388,22 +421,16 @@ def generate_greedy_logprobs( max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[List[np.ndarray]] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[List[torch.Tensor]]: - all_logprobs: List[List[torch.Tensor]] = [] - for i, prompt in enumerate(prompts): - processor_kwargs: Dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - if videos is not None and videos[i] is not None: - 
processor_kwargs["videos"] = videos[i] - - inputs = self.processor(**processor_kwargs) - inputs = self.postprocess_inputs(inputs) + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + all_logprobs: List[List[torch.Tensor]] = [] + for inputs in all_inputs: output = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -475,28 +502,16 @@ def generate_greedy_logprobs_limit( videos: Optional[List[np.ndarray]] = None, **kwargs: Any, ) -> List[TokensTextLogprobs]: + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] - for i, prompt in enumerate(prompts): - processor_kwargs: Dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - - if audios is not None: - audio, sr = audios[i] - processor_kwargs["audio"] = audio - processor_kwargs["sampling_rate"] = sr - - if videos is not None: - processor_kwargs["videos"] = videos[i] - inputs = self.processor(**processor_kwargs) - inputs = self.postprocess_inputs(inputs) - + for inputs in all_inputs: output = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -632,20 +647,50 @@ def __init__( **kwargs, ) - def generate( + def get_inputs( self, prompts: List[str], - sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, - ) -> List[Tuple[List[List[int]], List[str]]]: + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[TextPrompt]: if images is not None: assert len(prompts) == len(images) + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] if images is not None: for i, image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} + if videos is not None: + for i, video in enumerate(videos): + inputs[i]["multi_modal_data"] = {"video": video} + + if audios is not None: + for i, audio in enumerate(audios): + inputs[i]["multi_modal_data"] = {"audio": audio} + + return inputs + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) @@ -687,24 +732,10 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - if images is not None: - assert len(prompts) == len(images) - - if videos is not None: - assert len(prompts) == len(videos) - - inputs = [TextPrompt(prompt=prompt) for prompt in prompts] - if images is not None: - for i, image in enumerate(images): - inputs[i]["multi_modal_data"] = {"image": image} - - if audios is not None: - for i, audio in enumerate(audios): - inputs[i]["multi_modal_data"] = {"audio": audio} - - if videos is not None: - for i, video in enumerate(videos): - inputs[i]["multi_modal_data"] = {"video": video} + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + 
audios=audios) req_outputs = self.model.generate(inputs, sampling_params=sampling_params) @@ -741,9 +772,15 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, greedy_params, images=images) + outputs = self.generate(prompts, + greedy_params, + images=images, + videos=videos, + audios=audios) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index be316c6e12da1..5f704d854e5dc 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -1,10 +1,10 @@ -"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. +"""Compare the embedding outputs of HF and vLLM models. Run `pytest tests/models/embedding/language/test_embedding.py`. """ import pytest -import torch -import torch.nn.functional as F + +from ..utils import check_embeddings_close MODELS = [ "intfloat/e5-mistral-7b-instruct", @@ -12,14 +12,6 @@ ] -def compare_embeddings(embeddings1, embeddings2): - similarities = [ - F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) - for e1, e2 in zip(embeddings1, embeddings2) - ] - return similarities - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models( @@ -37,15 +29,17 @@ def test_models( # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) - similarities = compare_embeddings(hf_outputs, vllm_outputs) - all_similarities = torch.stack(similarities) - tolerance = 1e-2 - assert torch.all((all_similarities <= 1.0 + tolerance) - & (all_similarities >= 1.0 - tolerance) - ), f"Not all values are within {tolerance} of 1.0" + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py new file mode 100644 index 0000000000000..2fcc2013d91ef --- /dev/null +++ b/tests/models/embedding/utils.py @@ -0,0 +1,29 @@ +from typing import List, Sequence + +import torch +import torch.nn.functional as F + + +def check_embeddings_close( + *, + embeddings_0_lst: Sequence[List[float]], + embeddings_1_lst: Sequence[List[float]], + name_0: str, + name_1: str, + tol: float = 1e-3, +) -> None: + assert len(embeddings_0_lst) == len(embeddings_1_lst) + + for prompt_idx, (embeddings_0, embeddings_1) in enumerate( + zip(embeddings_0_lst, embeddings_1_lst)): + assert len(embeddings_0) == len(embeddings_1) + + sim = F.cosine_similarity(torch.tensor(embeddings_0), + torch.tensor(embeddings_1), + dim=0) + + fail_msg = (f"Test{prompt_idx}:" + f"\n{name_0}:\t{embeddings_0!r}" + f"\n{name_1}:\t{embeddings_1!r}") + + assert sim >= 1 - tol, fail_msg diff --git a/tests/models/embedding/vision_language/__init__.py 
b/tests/models/embedding/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py new file mode 100644 index 0000000000000..ea6b56cd02625 --- /dev/null +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -0,0 +1,62 @@ +import pytest +import torch.nn.functional as F + +from ....conftest import IMAGE_ASSETS +from ..utils import check_embeddings_close + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 + "cherry_blossom": + "<|image_1|> Represent the given image with the following question: What is in the image", # noqa: E501 +}) + +MODELS = ["TIGER-Lab/VLM2Vec-Full"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, + max_model_len=4096, + max_num_seqs=2, + dtype=dtype, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner(model, dtype=dtype) as hf_model: + all_inputs = hf_model.get_inputs(example_prompts) + + all_outputs = [] + for inputs in all_inputs: + # Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py + outputs = hf_model.model( + **hf_model.wrap_device(inputs, + device=hf_model.model.device.type), + return_dict=True, + output_hidden_states=True, + ) + last_hidden_state = outputs.hidden_states[-1][0] + reps = last_hidden_state[inputs.attention_mask[0].sum() - 1] + pooled_output = F.normalize(reps, p=2, dim=-1) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 1d61f6b74f520..21958b1640204 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -3,7 +3,7 @@ import torch from vllm.attention import AttentionMetadata -from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel +from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel from vllm.sequence import IntermediateTensors diff --git a/vllm/config.py b/vllm/config.py index 33005ebbd5219..614cacd51fb27 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -237,7 +237,16 @@ def _verify_tokenizer_mode(self) -> None: def _verify_embedding_mode(self) -> None: architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = ModelRegistry.is_embedding_model(architectures) + + # TODO: Allow the same model architecture to be specified as either + # generation or embedding model + if "Phi3VForCausalLM" in architectures: + # Match both remote and local names + embedding_mode = "/VLM2Vec" in self.model 
+ else: + embedding_mode = ModelRegistry.is_embedding_model(architectures) + + self.embedding_mode = embedding_mode def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index bcb03ef55ef94..f958268741cd5 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -31,14 +31,16 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, @@ -461,3 +463,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings else None), ) loader.load_weights(weights) + + +class Gemma2EmbeddingModel(nn.Module, SupportsPP): + """ + A model that uses Gemma2 with additional embedding functionalities. + + This class encapsulates the Gemma2Model and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of Gemma2Model used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
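+
+    Pooling: the hidden state of the last token is used and L2-normalized
+    (PoolingType.LAST with normalize=True), as configured in __init__ below.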
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + + self.model = Gemma2Model(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py deleted file mode 100644 index e8e10598c1644..0000000000000 --- a/vllm/model_executor/models/gemma2_embedding.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable, List, Optional, Tuple, Union - -import torch -from torch import nn - -from vllm.attention import AttentionMetadata -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput - -from .gemma2 import Gemma2Model -from .interfaces import SupportsPP - - -class Gemma2EmbeddingModel(nn.Module, SupportsPP): - """A model that uses Gemma2 with additional embedding functionalities. - - This class encapsulates the Gemma2Model and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of Gemma2Model used for forward operations. - _pooler: An instance of Pooler used for pooling operations. 
- """ - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__() - self.model = Gemma2Model(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ad5cfcc44022f..fd88ae8b50402 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,6 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) @@ -47,8 +48,9 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP @@ -615,3 +617,52 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight + + +class LlamaEmbeddingModel(nn.Module, SupportsPP): + """ + A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
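+
+    Registered in the model registry under the MistralModel architecture,
+    e.g. for intfloat/e5-mistral-7b-instruct.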
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + self.model.load_weights(weights) + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py deleted file mode 100644 index 13574e84d7aa2..0000000000000 --- a/vllm/model_executor/models/llama_embedding.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Iterable, List, Optional, Tuple, Union - -import torch -from torch import nn - -from vllm.attention import AttentionMetadata -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput - -from .interfaces import SupportsPP -from .llama import LlamaModel - - -class LlamaEmbeddingModel(nn.Module, SupportsPP): - """A model that uses Llama with additional embedding functionalities. - - This class encapsulates the LlamaModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of LlamaModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. 
- """ - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__() - self.model = LlamaModel(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - self.model.load_weights(weights) - - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 00a04dac88789..bcd5cd2154e66 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,14 +29,18 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip @@ -289,10 +293,6 @@ def add_image_newline(self, image_features_hd): dim=2).reshape(num_images, -1, hid_dim) return image_features_hd_newline - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) - loader.load_weights(weights) - # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): @@ -385,23 +385,28 @@ def dummy_data_for_phi3v(ctx: InputContext, return seq_data, mm_data -# Reserve this function to also handle placeholders for additional images -# [ref: PR #5820] @lru_cache -def _get_image_placeholder_token_ids(model_config: ModelConfig, - idx: int) -> List[int]: +def _get_image_placeholder_token_id_candidates( + model_config: ModelConfig, + idx: int, +) -> List[List[int]]: assert idx > 0 tokenizer = cached_get_tokenizer(model_config.tokenizer) + # This is used when the image token is at the start of the string + start_candidate = tokenizer.encode(f"<|image_{idx}|>", + add_special_tokens=False) + + # This is used when the image token is in the middle 
of the string
     # We need to get the token for "<", not "▁<"
     # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
     a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
-        f"a<|image_{idx}|>", add_special_tokens=False)
+    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
+                                                      add_special_tokens=False)
     assert a_token_id == a_token_id_
 
-    return image_placeholder_token_ids
+    return [start_candidate, middle_candidate]
 
 
 def input_processor_for_phi3v(ctx: InputContext,
@@ -461,16 +466,20 @@ def input_processor_for_phi3v(ctx: InputContext,
 
     prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
 
-    # masked place_holder with image token id
+    # Mask the placeholder tokens with the image token id
     for idx in image_idx:
-        image_token_ids = _get_image_placeholder_token_ids(model_config,
-                                                           idx=idx)
-        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
-            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
-                prompt_token_ids[i:i + len(image_token_ids)] = [
-                    _IMAGE_TOKEN_ID
-                ] * len(image_token_ids)
-                break
+        candidates = _get_image_placeholder_token_id_candidates(model_config,
+                                                                idx=idx)
+
+        for candidate in candidates:
+            for i in range(len(prompt_token_ids) - len(candidate) + 1):
+                if prompt_token_ids[i:i + len(candidate)] == candidate:
+                    prompt_token_ids[i:i +
+                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
+                                                        len(candidate))
+                    break
 
     # merge consecutive tag ids
     merged_token_ids: List[int] = []
@@ -520,12 +529,23 @@ def __init__(self,
         self.multimodal_config = multimodal_config
         self.image_token_id = _IMAGE_TOKEN_ID
 
-        # TODO: Optionally initializes this for supporting embeddings.
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            quant_config=quant_config,
+        )
+
+        # TODO: Optionally initialize this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config) self.language_model = LlamaForCausalLM(config, cache_config, quant_config) + # The same model class supports both language generation and embedding + # because the architecture name is the same + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -649,8 +669,7 @@ def forward(self, if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.image_token_id) @@ -682,13 +701,27 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ + "model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.": "vision_embed_tokens.", "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", }) loader = AutoWeightsLoader(self) - loader.load_weights(weights, mapper=hf_to_vllm_mapper) + autoloaded_weights = loader.load_weights(weights, + mapper=hf_to_vllm_mapper) + + # The HF config doesn't specify whether these are tied, + # so we detect it this way + if "embed_tokens" not in autoloaded_weights: + self.embed_tokens = self.language_model.model.embed_tokens diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b06d3d612dbcc..03a67e3712d72 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -86,9 +86,12 @@ } _EMBEDDING_MODELS = { - "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + # [Text-only] + "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), + "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), - "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), + # [Multimodal] + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), } _MULTIMODAL_MODELS = { diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 89b64ba2fd43c..8aac9c0eb3a0e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -124,7 +124,7 @@ def _load_param( base_prefix: str, param: nn.Parameter, weights: Iterable[Tuple[str, torch.Tensor]], - ) -> None: + ) -> Iterable[str]: for weight_name, weight_data in weights: weight_qualname = self._get_qualname(base_prefix, weight_name) @@ -143,12 +143,14 @@ def _load_param( default_weight_loader) weight_loader(param, weight_data) + yield weight_qualname + def _load_module( self, base_prefix: str, module: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]], - ) -> None: + ) -> Iterable[str]: if isinstance(module, PPMissingLayer): return @@ -170,14 +172,16 @@ def _load_module( continue if child_prefix in child_modules: - self._load_module(prefix, child_modules[child_prefix], - child_weights) + yield from self._load_module(prefix, + child_modules[child_prefix], + child_weights) elif child_prefix in child_params: - self._load_param(prefix, child_params[child_prefix], - 
child_weights) + yield from self._load_param(prefix, child_params[child_prefix], + child_weights) else: if not self._can_ignore_unexpected(prefix): - msg = f"There is no module or parameter named '{prefix}'" + msg = (f"There is no module or parameter named '{prefix}' " + f"in {type(self.module).__name__}") raise ValueError(msg) def load_weights( @@ -185,11 +189,12 @@ def load_weights( weights: Iterable[Tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, - ) -> None: + ) -> List[str]: if mapper is not None: weights = mapper.apply(weights) - self._load_module("", self.module, weights) + autoloaded_weights = list(self._load_module("", self.module, weights)) + return autoloaded_weights def init_vllm_registered_model(
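A minimal usage sketch of the interfaces added above, assuming only what this diff exercises (LLM.encode with either a multi-modal prompt dict or plain text prompts, and EmbeddingRequestOutput.outputs.embedding). Since VLM2Vec is meant to embed text-only and image+text inputs into a shared space, the same encode path can be used for simple image-text matching; the candidate captions below are illustrative only:

import torch
import torch.nn.functional as F

from vllm import LLM
from vllm.assets.image import ImageAsset

# Same engine settings as examples/offline_inference_vision_language_embedding.py
llm = LLM(
    model="TIGER-Lab/VLM2Vec-Full",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)

image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

# Embed one image+text query and two text-only candidate captions.
image_output = llm.encode({
    "prompt": "<|image_1|> Represent the given image with the following question: What is in the image",  # noqa: E501
    "multi_modal_data": {"image": image},
})[0]
captions = [
    "A photo of cherry blossoms in bloom",
    "A diagram of a computer network",
]
caption_outputs = llm.encode(captions)

image_emb = torch.tensor(image_output.outputs.embedding)
for caption, caption_output in zip(captions, caption_outputs):
    caption_emb = torch.tensor(caption_output.outputs.embedding)
    # The pooler already L2-normalizes the embeddings, so cosine similarity
    # reduces to a dot product; F.cosine_similarity is used here for clarity.
    score = F.cosine_similarity(image_emb, caption_emb, dim=0)
    print(f"{caption!r}: {score.item():.3f}")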