Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mindone/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@
from .models.cohere import CohereForCausalLM, CohereModel, CoherePreTrainedModel
from .models.cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel
from .models.colpali import ColPaliForRetrieval, ColPaliPreTrainedModel, ColPaliProcessor
from .models.colqwen2 import ColQwen2ForRetrieval, ColQwen2PreTrainedModel, ColQwen2Processor
from .models.convbert import (
ConvBertForMaskedLM,
ConvBertForMultipleChoice,
Expand Down
14 changes: 12 additions & 2 deletions mindone/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,9 +1445,10 @@ def _initialize_weights(self, module):
self._init_weights(module)
module._is_hf_initialized = True

def tie_weights(self):
def tie_embeddings_and_encoder_decoder(self):
"""
Tie the weights between the input embeddings and the output embeddings.
If set in the config, tie the weights between the input embeddings and the output embeddings,
and the encoder and decoder.

If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
weights instead.
Expand All @@ -1468,7 +1469,16 @@ def tie_weights(self):
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

def tie_weights(self):
"""
Recursively (for all submodels) tie all the weights of the model.
"""
# Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call
for name, module in self.cells_and_names():
# If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights
if isinstance(module, PreTrainedModel):
module.tie_embeddings_and_encoder_decoder()
# Additionally, if it has a custom `_tie_weights`, honor it
if hasattr(module, "_tie_weights"):
module._tie_weights()

Expand Down
1 change: 1 addition & 0 deletions mindone/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
cohere,
cohere2,
colpali,
colqwen2,
convbert,
convnext,
convnextv2,
Expand Down
1 change: 1 addition & 0 deletions mindone/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
("chameleon", "ChameleonProcessor"),
("chinese_clip", "ChineseCLIPProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2", "ColQwen2Processor"),
("flava", "FlavaProcessor"),
("idefics", "IdeficsProcessor"),
("instructblip", "InstructBlipProcessor"),
Expand Down
2 changes: 2 additions & 0 deletions mindone/transformers/models/colqwen2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2PreTrainedModel
from .processing_colqwen2 import ColQwen2Processor
250 changes: 250 additions & 0 deletions mindone/transformers/models/colqwen2/modeling_colqwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Union

from transformers import ColQwen2Config
from transformers.utils import ModelOutput, auto_docstring, can_return_tuple

import mindspore as ms
from mindspore import mint

from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ..auto import AutoModelForImageTextToText


@auto_docstring
class ColQwen2PreTrainedModel(PreTrainedModel):
config: ColQwen2Config
base_model_prefix = "model"
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True

def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.vlm_config.text_config.initializer_range
)

if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
from mindone.models.utils import normal_, zeros_

normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
zeros_(module.bias)
elif isinstance(module, mint.nn.Embedding):
from mindone.models.utils import normal_, zeros_

# Embedding uses `embedding_table` in MS nn.Embedding
normal_(module.weight, mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


@dataclass
@auto_docstring(
custom_intro="""
Base class for ColQwen2 embeddings output.
"""
)
class ColQwen2ForRetrievalOutput(ModelOutput):
r"""
loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
embeddings (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The embeddings of the model.
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)

Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
"""

loss: Optional[ms.Tensor] = None
embeddings: Optional[ms.Tensor] = None
past_key_values: Optional[Union[list[ms.Tensor], Cache]] = None
hidden_states: Optional[tuple[ms.Tensor]] = None
attentions: Optional[tuple[ms.Tensor]] = None


@auto_docstring(
custom_intro="""
Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
between these document embeddings and the corresponding query embeddings, using the late interaction method
introduced in ColBERT.

Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.

ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
"""
)
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
_checkpoint_conversion_mapping = {}

def __init__(self, config: ColQwen2Config):
super().__init__(config)
self.config = config
self.vocab_size = config.vlm_config.text_config.vocab_size

self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)

self.embedding_dim = self.config.embedding_dim
self.embedding_proj_layer = mint.nn.Linear(
self.config.vlm_config.text_config.hidden_size,
self.embedding_dim,
)
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]

self.post_init()

@can_return_tuple
@auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
attention_mask: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
past_key_values: Optional[Cache] = None,
labels: Optional[ms.Tensor] = None,
inputs_embeds: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[ms.Tensor] = None,
image_grid_thw: Optional[ms.Tensor] = None,
cache_position: Optional[ms.Tensor] = None,
) -> ColQwen2ForRetrievalOutput:
r"""
image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)

# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
if pixel_values is not None and image_grid_thw is not None:
# NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (num_patches_h, num_patches_w)
pixel_values = mint.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
dim=0,
) # (num_patches_h * num_patches_w, pixel_values)

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

position_ids, rope_deltas = self.vlm.model.get_rope_index(
input_ids=input_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=None,
attention_mask=attention_mask,
)

# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
if inputs_embeds is None:
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)

if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).broadcast_to(inputs_embeds.shape)
)
image_embeds = image_embeds.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if attention_mask is not None:
pass

vlm_output = self.vlm.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)

# L2 normalization
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
if attention_mask is not None:
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)

return ColQwen2ForRetrievalOutput(
embeddings=embeddings,
past_key_values=vlm_output.past_key_values,
hidden_states=vlm_hidden_states,
attentions=vlm_output.attentions,
)

def get_input_embeddings(self):
return self.vlm.get_input_embeddings()

def set_input_embeddings(self, value):
self.vlm.set_input_embeddings(value)

def get_output_embeddings(self):
return self.vlm.get_output_embeddings()

def set_output_embeddings(self, new_embeddings):
self.vlm.set_output_embeddings(new_embeddings)

def tie_weights(self):
return self.vlm.tie_weights()

def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> mint.nn.Embedding:
model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)

self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vlm_config.vocab_size = model_embeds.num_embeddings
self.vlm.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings

return model_embeds


__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"]
Loading