diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..9dc18adfb8 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -348,6 +348,7 @@ from .models.cohere import CohereForCausalLM, CohereModel, CoherePreTrainedModel from .models.cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel from .models.colpali import ColPaliForRetrieval, ColPaliPreTrainedModel, ColPaliProcessor +from .models.colqwen2 import ColQwen2ForRetrieval, ColQwen2PreTrainedModel, ColQwen2Processor from .models.convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 7fc96f444f..8fd90631b6 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -1445,9 +1445,10 @@ def _initialize_weights(self, module): self._init_weights(module) module._is_hf_initialized = True - def tie_weights(self): + def tie_embeddings_and_encoder_decoder(self): """ - Tie the weights between the input embeddings and the output embeddings. + If set in the config, tie the weights between the input embeddings and the output embeddings, + and the encoder and decoder. If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the weights instead. @@ -1468,7 +1469,16 @@ def tie_weights(self): # Leading to issues on subsequent calls by different tests or subsequent calls. self._dynamic_tied_weights_keys = tied_weights + def tie_weights(self): + """ + Recursively (for all submodels) tie all the weights of the model. + """ + # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call for name, module in self.cells_and_names(): + # If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights + if isinstance(module, PreTrainedModel): + module.tie_embeddings_and_encoder_decoder() + # Additionally, if it has a custom `_tie_weights`, honor it if hasattr(module, "_tie_weights"): module._tie_weights() diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..da3f49af0b 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -49,6 +49,7 @@ cohere, cohere2, colpali, + colqwen2, convbert, convnext, convnextv2, diff --git a/mindone/transformers/models/auto/processing_auto.py b/mindone/transformers/models/auto/processing_auto.py index 88df3d8aee..0a4ff117d6 100644 --- a/mindone/transformers/models/auto/processing_auto.py +++ b/mindone/transformers/models/auto/processing_auto.py @@ -55,6 +55,7 @@ ("chameleon", "ChameleonProcessor"), ("chinese_clip", "ChineseCLIPProcessor"), ("colpali", "ColPaliProcessor"), + ("colqwen2", "ColQwen2Processor"), ("flava", "FlavaProcessor"), ("idefics", "IdeficsProcessor"), ("instructblip", "InstructBlipProcessor"), diff --git a/mindone/transformers/models/colqwen2/__init__.py b/mindone/transformers/models/colqwen2/__init__.py new file mode 100644 index 0000000000..b24c24be70 --- /dev/null +++ b/mindone/transformers/models/colqwen2/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2PreTrainedModel +from .processing_colqwen2 import ColQwen2Processor diff --git a/mindone/transformers/models/colqwen2/modeling_colqwen2.py b/mindone/transformers/models/colqwen2/modeling_colqwen2.py new file mode 100644 index 0000000000..0d942a7908 --- /dev/null +++ 
b/mindone/transformers/models/colqwen2/modeling_colqwen2.py
@@ -0,0 +1,250 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+from transformers import ColQwen2Config
+from transformers.utils import ModelOutput, auto_docstring, can_return_tuple
+
+import mindspore as ms
+from mindspore import mint
+
+from ...cache_utils import Cache
+from ...modeling_utils import PreTrainedModel
+from ..auto import AutoModelForImageTextToText
+
+
+@auto_docstring
+class ColQwen2PreTrainedModel(PreTrainedModel):
+    config: ColQwen2Config
+    base_model_prefix = "model"
+    _no_split_modules = []
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_flex_attn = True
+
+    def _init_weights(self, module):
+        std = (
+            self.config.initializer_range
+            if hasattr(self.config, "initializer_range")
+            else self.config.vlm_config.text_config.initializer_range
+        )
+
+        if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
+            from mindone.models.utils import normal_, zeros_
+
+            normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                zeros_(module.bias)
+        elif isinstance(module, mint.nn.Embedding):
+            from mindone.models.utils import normal_, zeros_
+
+            # mint.nn.Embedding exposes `weight` directly (MindSpore's nn.Embedding uses `embedding_table` instead)
+            normal_(module.weight, mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for ColQwen2 embeddings output.
+    """
+)
+class ColQwen2ForRetrievalOutput(ModelOutput):
+    r"""
+    loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    embeddings (`ms.Tensor` of shape `(batch_size, sequence_length, embedding_dim)`):
+        The embeddings of the model.
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+    """
+
+    loss: Optional[ms.Tensor] = None
+    embeddings: Optional[ms.Tensor] = None
+    past_key_values: Optional[Union[list[ms.Tensor], Cache]] = None
+    hidden_states: Optional[tuple[ms.Tensor]] = None
+    attentions: Optional[tuple[ms.Tensor]] = None
+
+
+@auto_docstring(
+    custom_intro="""
+    Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
+    from document images ("screenshots") for document retrieval. The model is trained to maximize the similarity
+    between these document embeddings and the corresponding query embeddings, using the late interaction method
+    introduced in ColBERT.
+
+    Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
+    a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
+
+    ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
+    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
+    """
+)
+class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
+    _checkpoint_conversion_mapping = {}
+
+    def __init__(self, config: ColQwen2Config):
+        super().__init__(config)
+        self.config = config
+        self.vocab_size = config.vlm_config.text_config.vocab_size
+
+        self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
+
+        self.embedding_dim = self.config.embedding_dim
+        self.embedding_proj_layer = mint.nn.Linear(
+            self.config.vlm_config.text_config.hidden_size,
+            self.embedding_dim,
+        )
+        self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
+
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def construct(
+        self,
+        input_ids: Optional[ms.Tensor] = None,
+        attention_mask: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        labels: Optional[ms.Tensor] = None,
+        inputs_embeds: Optional[ms.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values: Optional[ms.Tensor] = None,
+        image_grid_thw: Optional[ms.Tensor] = None,
+        cache_position: Optional[ms.Tensor] = None,
+    ) -> ColQwen2ForRetrievalOutput:
+        r"""
+        image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        """
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=self.dtype)  # (batch_size, max_num_patches, pixel_values)
+
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if pixel_values is not None and image_grid_thw is not None:
+            # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_t, num_patches_h, num_patches_w)
+            offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,), i.e. num_patches_h * num_patches_w per image
+            pixel_values = mint.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
+                dim=0,
+            )  # (total_num_patches, pixel_values)
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        position_ids, rope_deltas = self.vlm.model.get_rope_index(
+            input_ids=input_ids,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=None,
+            attention_mask=attention_mask,
+        )
+
+        # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
+ if inputs_embeds is None: + inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) + + if pixel_values is not None: + pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) + image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).broadcast_to(inputs_embeds.shape) + ) + image_embeds = image_embeds.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + pass + + vlm_output = self.vlm.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None + + last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size) + embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim) + + # L2 normalization + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + return ColQwen2ForRetrievalOutput( + embeddings=embeddings, + past_key_values=vlm_output.past_key_values, + hidden_states=vlm_hidden_states, + attentions=vlm_output.attentions, + ) + + def get_input_embeddings(self): + return self.vlm.get_input_embeddings() + + def set_input_embeddings(self, value): + self.vlm.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.vlm.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.vlm.set_output_embeddings(new_embeddings) + + def tie_weights(self): + return self.vlm.tie_weights() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> mint.nn.Embedding: + model_embeds = self.vlm.resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + + self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vlm_config.vocab_size = model_embeds.num_embeddings + self.vlm.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + + return model_embeds + + +__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"] diff --git a/mindone/transformers/models/colqwen2/processing_colqwen2.py b/mindone/transformers/models/colqwen2/processing_colqwen2.py new file mode 100644 index 0000000000..2d3fcd36d3 --- /dev/null +++ b/mindone/transformers/models/colqwen2/processing_colqwen2.py @@ -0,0 +1,432 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_colqwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +import mindspore as ms +from mindspore import mint + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack + + +class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "longest", + }, + "images_kwargs": { + "data_format": "channels_first", + "do_convert_rgb": True, + }, + "common_kwargs": {"return_tensors": "np"}, + } + + +class ColQwen2Processor(ProcessorMixin): + r""" + Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and special methods to process images and queries, as + well as to compute the late-interaction retrieval score. + + [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`] + for more information. + + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens. + query_prefix (`str`, *optional*): A prefix to be used for the query. + """ + + attributes = ["image_processor", "tokenizer"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + visual_prompt_prefix: Optional[str] = None, + query_prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + + if visual_prompt_prefix is None: + visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + self.visual_prompt_prefix = visual_prompt_prefix + + if query_prefix is None: + query_prefix = "Query: " + self.query_prefix = query_prefix + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom + wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. 
It cannot process
+        both text and images at the same time.
+
+        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
+        [`~Qwen2TokenizerFast.__call__`].
+        When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
+        [`~Qwen2VLImageProcessor.__call__`].
+        Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[ms.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or MindSpore
+                tensor. In case of a NumPy array/MindSpore tensor, each image should be of shape (C, H, W), where C is a
+                number of channels, H and W are image height and width.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'ms'`: Return MindSpore `ms.Tensor` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """ + output_kwargs = self._merge_kwargs( + ColQwen2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = True if suffix is not None else False + + if text is None and images is None: + raise ValueError("Either text or images must be provided") + if text is not None and images is not None: + raise ValueError("Only one of text or images can be processed at a time") + + if images is not None: + if is_valid_image(images): + images = [images] + elif isinstance(images, list) and is_valid_image(images[0]): + pass + elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): + raise ValueError("images must be an image, list of images or list of list of images") + + texts_doc = [self.visual_prompt_prefix] * len(images) + + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(texts_doc)): + while self.image_token in texts_doc[i]: + texts_doc[i] = texts_doc[i].replace( + self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1 + ) + index += 1 + texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token) + + text_inputs = self.tokenizer( + texts_doc, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return_data = BatchFeature(data={**text_inputs, **image_inputs}) + + # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs. + offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,) + + # Split the pixel_values tensor into a list of tensors, one per image + pixel_values = list( + mint.split(ms.Tensor(return_data["pixel_values"]), offsets.tolist()) + ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)] + + # Pad the list of pixel_value tensors to the same length along the sequence dimension + # Implement pad_sequence equivalent in MindSpore + max_len = max(t.shape[0] for t in pixel_values) + feat_dim = pixel_values[0].shape[1] + batch = [] + for t in pixel_values: + pad_len = max_len - t.shape[0] + if pad_len > 0: + pad = mint.zeros((pad_len, feat_dim), dtype=t.dtype) + batch.append(mint.cat([t, pad], dim=0)) + else: + batch.append(t) + return_data["pixel_values"] = mint.stack(batch, dim=0) # (batch_size, max_num_patches, pixel_values) + + if return_token_type_ids: + labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + + for item in return_data: + if not isinstance(return_data[item], ms.Tensor): + return_data[item] = ms.Tensor(return_data[item]) + return return_data + + elif text is not None: + if isinstance(text, str): + text = [text] + elif not (isinstance(text, list) and isinstance(text[0], str)): + raise ValueError("Text must be a string or a list of strings") + + if suffix is None: + suffix = self.query_augmentation_token * 10 + + texts_query: list[str] = [] + + for query in text: + augmented_query = self.query_prefix + query + suffix + texts_query.append(augmented_query) + + batch_query = self.tokenizer( + texts_query, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return batch_query + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes 
the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (list[list[str]], *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + dict[str, list[int]]: A dictionary mapping each modality ("image", "video", "audio") + to a list containing the number of placeholder tokens required. If the model doesn't accept + a certain modality or no input sizes are provided, the dict value is set to an empty list. + """ + vision_data = {} + if image_sizes is not None: + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + return MultiModalData(**vision_data) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + @property + def query_augmentation_token(self) -> str: + """ + Return the query augmentation token. + + Query augmentation buffers are used as reasoning buffers during inference. + """ + return self.tokenizer.pad_token + + def process_images( + self, + images: ImageInput = None, + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's + [`ColQwen2Processor.__call__`]. + + This method forwards the `images` and `kwargs` arguments to the image processor. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[ms.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or MindSpore + tensor. In case of a NumPy array/MindSpore tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'ms'`: Return MindSpore `ms.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
+ """ + return self.__call__(images=images, **kwargs) + + def process_queries( + self, + text: Union[TextInput, list[TextInput]], + **kwargs: Unpack[ColQwen2ProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's + [`ColQwen2Processor.__call__`]. + + This method forwards the `text` and `kwargs` arguments to the tokenizer. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'ms'`: Return MindSpore `ms.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + """ + return self.__call__(text=text, **kwargs) + + def score_retrieval( + self, + query_embeddings: Union[ms.Tensor, list[ms.Tensor]], + passage_embeddings: Union[ms.Tensor, list[ms.Tensor]], + batch_size: int = 128, + output_dtype: Optional[ms.Type] = None, + output_device: Union[str, None] = None, + ) -> ms.Tensor: + """ + Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector + query embeddings (`qs`) and passage embeddings (`ps`). For ColQwen2, a passage is the + image of a document page. + + Because the embedding tensors are multi-vector and can thus have different shapes, they + should be fed as: + (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim) + (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually + obtained by padding the list of tensors. + + Args: + query_embeddings (`Union[ms.Tensor, list[ms.Tensor]]`): Query embeddings. + passage_embeddings (`Union[ms.Tensor, list[ms.Tensor]]`): Passage embeddings. + batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. + output_dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): The dtype of the output tensor. + If `None`, the dtype of the input embeddings is used. + output_device (`str`, *optional*): Unused in MindSpore; kept for API compatibility. + + Returns: + `ms.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. 
+ """ + + if len(query_embeddings) == 0: + raise ValueError("No queries provided") + if len(passage_embeddings) == 0: + raise ValueError("No passages provided") + + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + + scores: list[ms.Tensor] = [] + + for i in range(0, len(query_embeddings), batch_size): + batch_scores: list[ms.Tensor] = [] + # pad queries + q_chunk = query_embeddings[i : i + batch_size] + q_max = max(t.shape[0] for t in q_chunk) + q_feat = q_chunk[0].shape[1] + q_batch = [] + for t in q_chunk: + pad = q_max - t.shape[0] + if pad > 0: + q_batch.append(mint.cat([t, mint.zeros((pad, q_feat), dtype=t.dtype)], dim=0)) + else: + q_batch.append(t) + batch_queries = mint.stack(q_batch, dim=0) + for j in range(0, len(passage_embeddings), batch_size): + # pad passages + p_chunk = passage_embeddings[j : j + batch_size] + p_max = max(t.shape[0] for t in p_chunk) + p_feat = p_chunk[0].shape[1] + p_batch = [] + for t in p_chunk: + pad = p_max - t.shape[0] + if pad > 0: + p_batch.append(mint.cat([t, mint.zeros((pad, p_feat), dtype=t.dtype)], dim=0)) + else: + p_batch.append(t) + batch_passages = mint.stack(p_batch, dim=0) + batch_scores.append( + mint.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) + ) + cur = mint.cat(batch_scores, dim=1).to(output_dtype) + scores.append(cur) + + return mint.cat(scores, dim=0) + + +__all__ = ["ColQwen2Processor"] diff --git a/mindone/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/mindone/transformers/models/qwen2_vl/modeling_qwen2_vl.py index caf6a8472f..0d4ec62b93 100644 --- a/mindone/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/mindone/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1157,16 +1157,10 @@ def __init__(self, config: Qwen2VLConfig): self.post_init() def get_input_embeddings(self): - return self.language_model.embed_tokens + return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): - self.language_model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder @@ -1333,6 +1327,78 @@ def get_rope_index( return position_ids, mrope_position_deltas + def get_video_features(self, pixel_values_videos: ms.Tensor, video_grid_thw: Optional[ms.Tensor] = None): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`ms.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`ms.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (mint.prod(video_grid_thw, dim=-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = mint.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features(self, pixel_values: ms.Tensor, image_grid_thw: Optional[ms.Tensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. 
+ + Args: + pixel_values (`ms.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (mint.prod(image_grid_thw, dim=-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = mint.split(image_embeds, split_sizes) + return image_embeds + + def get_placeholder_mask( + self, + input_ids: ms.Tensor, + inputs_embeds: ms.Tensor, + image_features: Optional[ms.Tensor] = None, + video_features: Optional[ms.Tensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + ms.Tensor(self.config.image_token_id, dtype=ms.int64) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + ms.Tensor(self.config.video_token_id, dtype=ms.int64) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + def construct( self, input_ids: ms.Tensor = None, diff --git a/tests/transformers_tests/models/colqwen2/__init__.py b/tests/transformers_tests/models/colqwen2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/colqwen2/test_modeling_colqwen2.py b/tests/transformers_tests/models/colqwen2/test_modeling_colqwen2.py new file mode 100644 index 0000000000..bee1c1dca9 --- /dev/null +++ b/tests/transformers_tests/models/colqwen2/test_modeling_colqwen2.py @@ -0,0 +1,257 @@ +"""Adapted from https://github.com/huggingface/transformers/tree/main/tests/models/colqwen2/test_modeling_colqwen2.py.""" + +# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like +# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map]. +# +# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective +# initialization parameters and inputs for the forward. The testing framework adopted here is designed to generically +# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks. 
+ +import inspect + +import numpy as np +import pytest +import torch +from transformers import ColQwen2Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] # 1: pynative mode (graph mode not supported yet) + + +class ColQwen2ModelTester: + config_class = ColQwen2Config + + def __init__( + self, + batch_size=2, + seq_length=7, + image_size=224, + num_channels=3, + is_training=False, + use_input_mask=True, + use_labels=True, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + # VLM config (Qwen2VLConfig with reduced size) + vocab_size=99, + hidden_size=128, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=8, + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, + vision_start_token_id=3, + image_token_id=4, + video_token_id=5, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "silu", + "intermediate_size": 32, + "out_hidden_size": 128, + "hidden_size": 128, + "num_heads": 8, + "patch_size": 14, + "spatial_patch_size": 14, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + }, + # ColQwen2 specific + embedding_dim=64, + initializer_range=0.02, + ): + self.batch_size = batch_size + self.seq_length = seq_length + self.image_size = image_size + self.num_channels = num_channels + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + + # Qwen2VL config parameters (reduced for testing) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.rope_scaling = { + "mrope_section": [2, 3, 3], + "type": "mrope", + } # sum*2=16 = head_dim = hidden_size//num_attention_heads = 128//8=16 + self.vision_config = vision_config + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.vision_start_token_id = vision_start_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.num_image_tokens = 32 + self.seq_length = self.seq_length + self.num_image_tokens + # ColQwen2 specific + self.embedding_dim = embedding_dim + self.initializer_range = initializer_range + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = self.vision_config["patch_size"] + temporal_patch_size = self.vision_config["temporal_patch_size"] + pixel_values = floats_numpy( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = np.ones(input_ids.shape, dtype=np.int64) + + input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.video_token_id] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + 
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id + input_ids[:, self.num_image_tokens] = self.image_token_id + input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id + inputs_dict = { + # "pixel_values": pixel_values, + "image_grid_thw": np.array([[1, 1, 1]] * self.batch_size), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def get_config(self): + from transformers.models.qwen2_vl import Qwen2VLConfig + + vlm_config = Qwen2VLConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + rope_scaling=self.rope_scaling, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + vision_start_token_id=self.vision_start_token_id, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_config=self.vision_config, + ) + + return self.config_class( + vlm_config=vlm_config, + embedding_dim=self.embedding_dim, + initializer_range=self.initializer_range, + ) + + +model_tester = ColQwen2ModelTester() +config, inputs_dict = model_tester.prepare_config_and_inputs_for_common() + +COLQWEN2_CASES = [ + [ + "ColQwen2ForRetrieval", + "transformers.ColQwen2ForRetrieval", + "mindone.transformers.ColQwen2ForRetrieval", + (config,), + {}, + (), + inputs_dict, + {"embeddings": "embeddings"}, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in COLQWEN2_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )
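Usage sketch (not part of the patch): the snippet below is a rough illustration of how the new ColQwen2 classes are expected to fit together for end-to-end document retrieval. The checkpoint id, image file names, the dict-style output access (`.embeddings`), and the assumption that the query branch returns NumPy arrays are illustrative assumptions, not something this diff guarantees.

    from PIL import Image

    import mindspore as ms

    from mindone.transformers import ColQwen2ForRetrieval, ColQwen2Processor

    model_name = "vidore/colqwen2-v1.0-hf"  # assumed checkpoint id, adjust as needed

    processor = ColQwen2Processor.from_pretrained(model_name)
    model = ColQwen2ForRetrieval.from_pretrained(model_name)

    # Documents are fed as page screenshots, queries as plain text.
    images = [Image.open("page_0.png"), Image.open("page_1.png")]
    queries = ["When was the model released?", "What objective is used for training?"]

    # The image branch of the processor already converts its outputs to ms.Tensor.
    image_inputs = processor.process_images(images=images)
    # The query branch returns framework-agnostic arrays, so convert before the forward pass.
    query_inputs = {k: ms.Tensor(v) for k, v in processor.process_queries(text=queries).items()}

    # Multi-vector embeddings: (n_pages, seq_len, dim) and (n_queries, seq_len, dim).
    image_embeddings = model(**image_inputs).embeddings
    query_embeddings = model(**query_inputs).embeddings

    # Late-interaction (MaxSim) scores with shape (n_queries, n_pages).
    scores = processor.score_retrieval(query_embeddings, image_embeddings)
    print(scores)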