From b9f27a10e266dbe8233be53a4c8e3af62b1c2cec Mon Sep 17 00:00:00 2001 From: cyita Date: Thu, 16 Jan 2025 19:28:12 +0800 Subject: [PATCH 1/6] init --- .../llm/src/ipex_llm/transformers/convert.py | 5 + .../ipex_llm/transformers/models/qwen2_vl.py | 93 ++++++++++++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 2979b1bdc8d..9ea02ccfcef 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1646,6 +1646,8 @@ def _optimize_post(model): from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward + from ipex_llm.transformers.models.qwen2_vl import qwen2_vit_pretrained_model_forward + from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_vision_block_forward convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward) convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward) model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual) @@ -1654,6 +1656,9 @@ def _optimize_post(model): convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward) convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward) convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward) + convert_forward(model, module.Qwen2VisionTransformerPretrainedModel, + qwen2_vit_pretrained_model_forward) + convert_forward(model, module.Qwen2VLVisionBlock, qwen2_vl_vision_block_forward) elif model.config.model_type == "aquila": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index d885e23abfa..b0fba14bba6 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -41,6 +41,7 @@ from typing import Optional, Tuple, Union, List import torch +import torch.nn.functional as F from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax from ipex_llm.transformers.models.common import scaled_dot_product_attention @@ -187,7 +188,8 @@ def qwen2_vision_attention_forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None + rotary_pos_emb: torch.Tensor = None, + output_attentions: bool = False, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1 @@ -195,12 +197,16 @@ def qwen2_vision_attention_forward( q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) # q, k, v: [seq_length, num_heads, head_dim] + + # TODO: before rope? 
+ # raw_key_states = k.clone() seq_lens = cu_seqlens.tolist() invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length, "unexpected input") - if use_sdp_non_causal(self.head_dim, q.device, q.dtype): + if use_sdp_non_causal(self.head_dim, q.device, q.dtype) and not output_attentions: + # TODO: return attn_weights & attn_output image_num = len(seq_lens) - 1 image_size = seq_lens[1] - seq_lens[0] guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size, @@ -261,7 +267,88 @@ def qwen2_vision_attention_forward( attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) - return attn_output + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights#, raw_key_states.mean(1) + + +def qwen2_vl_vision_block_forward( + self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + output_attentions: Optional[bool] = False, +) -> torch.Tensor: + residual = hidden_states + hidden_states, attn_weights = self.attn( + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, + output_attentions=output_attentions + ) + hidden_states = residual + hidden_states + + # TODO: uncomment & test + # r = self._info["r"].pop(0) + # if r > 0: + # self.metric = metric + + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +def qwen2_vit_pretrained_model_forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor +) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + use_visionzip = True + last_layer_attention = None + + total_blk_num = len(self.blocks) + for idx in range(total_blk_num): + output_attentions = idx == (total_blk_num - 1) and use_visionzip + layer_outputs = self.blocks[idx](hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, + output_attentions=output_attentions) + hidden_states = layer_outputs[0] # 1564, 1280 + + if output_attentions: + last_layer_attention = layer_outputs[1] # 16, 1564, 1564 + + # TODO: select visionzip hidden states + dominant_num = 512 + contextual_num = 10 + + # Dominant Visual Tokens + # TODO: batch dim + attention_mean = last_layer_attention.mean(0) # 1564, 1564 + attention_mean = attention_mean.mean(0) # 1564 + top_k_indices = attention_mean.topk(dominant_num, dim=0).indices + + mask = torch.ones_like( + hidden_states[:, 0], + dtype=torch.bool, + device=hidden_states.device).scatter_(0, top_k_indices, False) + + dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) + + hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) + + return self.merger(hidden_ststes_save) def qwen2_vl_attention_forward( From a9154132e16acb1df2713e20db27fc610a8bc984 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 17 Jan 2025 11:10:37 +0800 Subject: [PATCH 2/6] move condition forward --- .../llm/src/ipex_llm/transformers/convert.py | 3 + .../ipex_llm/transformers/models/qwen2_vl.py | 127 +++++++++++++++--- 2 files changed, 115 insertions(+), 15 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 9ea02ccfcef..780e5ebbaf9 100644 --- 
a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1648,6 +1648,7 @@ def _optimize_post(model): from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vit_pretrained_model_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_vision_block_forward + from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_conditional_generation_forward convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward) convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward) model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual) @@ -1659,6 +1660,8 @@ def _optimize_post(model): convert_forward(model, module.Qwen2VisionTransformerPretrainedModel, qwen2_vit_pretrained_model_forward) convert_forward(model, module.Qwen2VLVisionBlock, qwen2_vl_vision_block_forward) + convert_forward(model, module.Qwen2VLForConditionalGeneration, + qwen2_vl_conditional_generation_forward) elif model.config.model_type == "aquila": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index b0fba14bba6..27d69a90b4d 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -54,7 +54,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision -from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.cache_utils import Cache @@ -329,28 +329,125 @@ def qwen2_vit_pretrained_model_forward( if output_attentions: last_layer_attention = layer_outputs[1] # 16, 1564, 1564 - # TODO: select visionzip hidden states - dominant_num = 512 - contextual_num = 10 + # # TODO: select visionzip hidden states + # dominant_num = 512 + # contextual_num = 10 - # Dominant Visual Tokens - # TODO: batch dim - attention_mean = last_layer_attention.mean(0) # 1564, 1564 - attention_mean = attention_mean.mean(0) # 1564 - top_k_indices = attention_mean.topk(dominant_num, dim=0).indices + # # Dominant Visual Tokens + # # TODO: batch dim + # attention_mean = last_layer_attention.mean(0) # 1564, 1564 + # attention_mean = attention_mean.mean(0) # 1564 + # top_k_indices = attention_mean.topk(dominant_num, dim=0).indices - mask = torch.ones_like( - hidden_states[:, 0], - dtype=torch.bool, - device=hidden_states.device).scatter_(0, top_k_indices, False) + # mask = torch.ones_like( + # hidden_states[:, 0], + # dtype=torch.bool, + # device=hidden_states.device).scatter_(0, top_k_indices, False) - dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) + # dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) - hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) + # hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) + + hidden_ststes_save = hidden_states return self.merger(hidden_ststes_save) +def qwen2_vl_conditional_generation_forward( + self, + input_ids: 
torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, +) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + import time + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + # t3 = time.perf_counter() + # print(inputs_embeds.shape, "pixel_values_videos time2: ", (t3 - t2) * 1000, " ms") + + # if inputs_embeds is None: + # inputs_embeds = self.model.embed_tokens(input_ids) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + + # position_ids = position_ids[:, :, : inputs_embeds.shape[1]] + # attention_mask = attention_mask[:, : inputs_embeds.shape[1]] + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + # if labels is not None: + # # Shift so that tokens < n predict n + # shift_logits = logits[..., :-1, :].contiguous() + # shift_labels = labels[..., 1:].contiguous() + # # Flatten the tokens + # loss_fct = CrossEntropyLoss() + # shift_logits = shift_logits.view(-1, self.config.vocab_size) + # shift_labels = shift_labels.view(-1) + # # Enable model parallelism + # shift_labels = shift_labels.to(shift_logits.device) + # loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + def qwen2_vl_attention_forward( self, hidden_states: torch.Tensor, From 4a4e875f88b7288fd4beaac70265fbefaac78e25 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 17 Jan 2025 17:37:24 +0800 Subject: [PATCH 3/6] trail 1 assign random height & width, fix rope delta --- .../llm/src/ipex_llm/transformers/convert.py | 8 ++- .../ipex_llm/transformers/models/qwen2_vl.py | 65 ++++++++++++++++++- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 780e5ebbaf9..aa96865b1f3 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1182,9 +1182,6 @@ def replace_RotaryEmbed(m, target_m, replace_embed): def replace_func(m, target_m, func_name, new_func): for _, sub_m in m.named_children(): - if sub_m.__class__ == target_m: - bound_method = new_func.__get__(sub_m, sub_m.__class__) - setattr(sub_m, func_name, bound_method) replace_func(sub_m, target_m, func_name, new_func) @@ -1649,6 +1646,7 @@ def _optimize_post(model): from ipex_llm.transformers.models.qwen2_vl import qwen2_vit_pretrained_model_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_vision_block_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_conditional_generation_forward + from ipex_llm.transformers.models.qwen2_vl import _update_model_kwargs_for_generation convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward) convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward) model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual) @@ -1662,6 +1660,10 @@ def _optimize_post(model): convert_forward(model, module.Qwen2VLVisionBlock, qwen2_vl_vision_block_forward) convert_forward(model, module.Qwen2VLForConditionalGeneration, qwen2_vl_conditional_generation_forward) + import types + model._update_model_kwargs_for_generation = types.MethodType(_update_model_kwargs_for_generation, model) + # replace_func(model, module.Qwen2VLForConditionalGeneration, + # "_update_model_kwargs_for_generation", _update_model_kwargs_for_generation) elif model.config.model_type == "aquila": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index 27d69a90b4d..f28659a7db5 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -38,7 +38,7 @@ # import math -from typing import Optional, Tuple, Union, List +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -55,14 +55,40 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast -from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput from transformers.cache_utils import Cache +from transformers import GenerationMixin def merge_qkv(module: torch.nn.Module): merge_qkv_base(module, Qwen2VLAttention) +def _update_model_kwargs_for_generation( + self, + 
outputs: ModelOutput, + model_kwargs: Dict[str, Any] = None, + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, +) -> Dict[str, Any]: + model_kwargs = GenerationMixin._update_model_kwargs_for_generation( + self, + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + if model_kwargs.get("use_cache", True): + cache_num = outputs.past_key_values.seen_tokens + model_kwargs['cache_position'] = torch.tensor([cache_num]) + + if getattr(outputs, "rope_deltas", None) is not None: + model_kwargs["rope_deltas"] = outputs.rope_deltas + + return model_kwargs + + def qwen2_vl_model_forward( self, input_ids: torch.LongTensor = None, @@ -381,12 +407,47 @@ def qwen2_vl_conditional_generation_forward( if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: + t1 = time.perf_counter() pixel_values = pixel_values.type(self.visual.get_dtype()) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, + attention_mask=attention_mask) + + # # # # TODO: remove redundant image_pad + # new_image_pad_num = image_embeds.shape[0] + # image_pad_indices = torch.where(input_ids == self.config.image_token_id)[1] + # image_pad_start_idx = image_pad_indices[0] + # image_pad_end_idx = image_pad_indices[-1] + # new_image_pad_end_idx = image_pad_start_idx + new_image_pad_num + # input_ids = torch.cat([input_ids[:, : new_image_pad_end_idx], + # input_ids[:, image_pad_end_idx + 1:]], dim=1) + # # # inputs_embeds = torch.cat([inputs_embeds[:, :new_image_pad_end_idx, :], + # # # inputs_embeds[:, image_pad_end_idx + 1:, :]], dim=1) + # inputs_embeds = self.model.embed_tokens(input_ids) + # image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + # image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + # inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # new_grid_thw = torch.tensor([[1, 16, 32]]) + # attention_mask = attention_mask[:, :inputs_embeds.shape[1]] + # position_ids, rope_deltas = self.get_rope_index( + # input_ids, new_grid_thw, + # attention_mask=attention_mask) + + torch.xpu.synchronize() + t2 = time.perf_counter() + # # image_token_num = image_embeds.shape[0] + # # dominant = 128 + # # contexual = 10 + # # diff = image_token_num - dominant - contexual + # # inputs_embeds_token = inputs_embeds.shape[1] + # # inputs_embeds = inputs_embeds[:, : inputs_embeds_token - diff, :] + print(inputs_embeds.shape, "time1: ", (t2 - t1) * 1000, " ms") if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) From a1ba13399a664846898613a57d8dd8b76898cd43 Mon Sep 17 00:00:00 2001 From: cyita Date: Tue, 21 Jan 2025 11:58:50 +0800 Subject: [PATCH 4/6] fix output & support generate with dominant/contextual num --- .../llm/src/ipex_llm/transformers/convert.py | 4 + .../ipex_llm/transformers/models/qwen2_vl.py | 365 +++++++++++++++--- 2 files changed, 309 insertions(+), 60 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py 
b/python/llm/src/ipex_llm/transformers/convert.py index aa96865b1f3..b533273843e 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1647,6 +1647,8 @@ def _optimize_post(model): from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_vision_block_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_conditional_generation_forward from ipex_llm.transformers.models.qwen2_vl import _update_model_kwargs_for_generation + from ipex_llm.transformers.models.qwen2_vl import get_rope_index + from ipex_llm.transformers.models.qwen2_vl import prepare_inputs_for_generation convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward) convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward) model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual) @@ -1662,6 +1664,8 @@ def _optimize_post(model): qwen2_vl_conditional_generation_forward) import types model._update_model_kwargs_for_generation = types.MethodType(_update_model_kwargs_for_generation, model) + model.get_rope_index = types.MethodType(get_rope_index, model) + model.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, model) # replace_func(model, module.Qwen2VLForConditionalGeneration, # "_update_model_kwargs_for_generation", _update_model_kwargs_for_generation) elif model.config.model_type == "aquila": diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index f28659a7db5..723200a2fd6 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -55,8 +55,9 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast +from transformers.models.qwen2_vl.modeling_qwen2_vl import _prepare_4d_causal_attention_mask_with_cache_position from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput -from transformers.cache_utils import Cache +from transformers.cache_utils import Cache, StaticCache from transformers import GenerationMixin @@ -79,9 +80,9 @@ def _update_model_kwargs_for_generation( num_new_tokens=num_new_tokens, ) - if model_kwargs.get("use_cache", True): - cache_num = outputs.past_key_values.seen_tokens - model_kwargs['cache_position'] = torch.tensor([cache_num]) + # if model_kwargs.get("use_cache", True): + # cache_num = outputs.past_key_values.seen_tokens + # model_kwargs['cache_position'] = torch.tensor([cache_num]) if getattr(outputs, "rope_deltas", None) is not None: model_kwargs["rope_deltas"] = outputs.rope_deltas @@ -89,6 +90,99 @@ def _update_model_kwargs_for_generation( return model_kwargs +def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, +): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # 
Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas", None) + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + dominant_num = kwargs.get("dominant_num", None) + contextual_num = kwargs.get("contextual_num", None) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "rope_deltas": rope_deltas, + "dominant_num": dominant_num, + "contextual_num": contextual_num + } + ) + return model_inputs + + def qwen2_vl_model_forward( self, input_ids: torch.LongTensor = None, @@ -332,7 +426,9 @@ def qwen2_vl_vision_block_forward( def qwen2_vit_pretrained_model_forward( self, hidden_states: torch.Tensor, - grid_thw: torch.Tensor + grid_thw: torch.Tensor, + dominant_num: Optional[int] = None, + contextual_num: Optional[int] = None, ) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) @@ -342,7 +438,8 @@ def qwen2_vit_pretrained_model_forward( ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - use_visionzip = True + # TODO: contextual num + use_visionzip = dominant_num is not None and dominant_num > 1 last_layer_attention = None total_blk_num = len(self.blocks) @@ -355,29 +452,173 @@ def qwen2_vit_pretrained_model_forward( if output_attentions: last_layer_attention = layer_outputs[1] # 16, 1564, 1564 - # # TODO: select visionzip hidden states - # dominant_num = 512 - # contextual_num = 10 - - # # Dominant Visual Tokens - # # TODO: batch dim - # attention_mean = 
last_layer_attention.mean(0) # 1564, 1564 - # attention_mean = attention_mean.mean(0) # 1564 - # top_k_indices = attention_mean.topk(dominant_num, dim=0).indices - - # mask = torch.ones_like( - # hidden_states[:, 0], - # dtype=torch.bool, - # device=hidden_states.device).scatter_(0, top_k_indices, False) - - # dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) + if use_visionzip: + # Dominant Visual Tokens + # TODO: batch dim + # v1: simple select + # attention_mean = last_layer_attention.mean(0) # 1564, 1564 + # attention_mean = attention_mean.mean(0) # 1564 + # top_k_indices = attention_mean.topk(dominant_num, dim=0).indices + + # v2: select 4 token pairs + attention_mean = last_layer_attention.mean(0) # 1564, 1564 + attention_mean = attention_mean.reshape(attention_mean.shape[0] * self.spatial_merge_size ** 2, -1) + attention_mean = attention_mean.mean(0) # 391 + top_k_indices = attention_mean.topk(dominant_num // 4, dim=0).indices + # TODO: get height & width + # interval_size = 22 + # ranges = [(start, start + interval_size - 1) for start in range(0, 391, interval_size)] + + # # Count the elements in each range + # counts = [] + # for start, end in ranges: + # count = ((top_k_indices >= start) & (top_k_indices <= end)).sum().item() + # counts.append((start, end, count)) + + top_k_indices_copy = top_k_indices.clone() + + top_k_indices = top_k_indices * 4 + top_k_indices = torch.cat([top_k_indices, top_k_indices + 1, top_k_indices + 2, top_k_indices + 3]) + + # v3: select 4 token pairs, another dim (attention mean all equal) + # attention_mean = last_layer_attention.mean(0) # 1564, 1564 + # attention_mean = attention_mean.reshape(-1, attention_mean.shape[0] * self.spatial_merge_size ** 2) + # attention_mean = attention_mean.mean(1) # 1564 + # top_k_indices = attention_mean.topk(dominant_num // 4, dim=0).indices * 4 + # top_k_indices = torch.cat([top_k_indices, top_k_indices + 1, top_k_indices + 2, top_k_indices + 3]) + + mask = torch.ones_like( + hidden_states[:, 0], + dtype=torch.bool, + device=hidden_states.device).scatter_(0, top_k_indices, False) + + dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) + + hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) + else: + hidden_ststes_save = hidden_states + top_k_indices_copy = None - # hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) + return self.merger(hidden_ststes_save), top_k_indices_copy - hidden_ststes_save = hidden_states - return self.merger(hidden_ststes_save) +def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + selected_indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + 
image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + # TODO: selected indices batch + if selected_indices is not None: + mask = torch.ones_like( + t_index, + dtype=torch.bool, + device=t_index.device).scatter_(0, selected_indices.to(t_index.device), False) + t_index = t_index.masked_select(~mask) + h_index = h_index.masked_select(~mask) + w_index = w_index.masked_select(~mask) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + if selected_indices is not None: + position_ids = position_ids[:, :, :llm_positions.shape[1]] + position_ids = llm_positions.to(position_ids.device) + else: + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = 
torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas def qwen2_vl_conditional_generation_forward( self, @@ -396,6 +637,8 @@ def qwen2_vl_conditional_generation_forward( image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, + dominant_num: Optional[int] = None, + contextual_num: Optional[int] = None, ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -409,44 +652,46 @@ def qwen2_vl_conditional_generation_forward( if pixel_values is not None: t1 = time.perf_counter() pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, - attention_mask=attention_mask) - - # # # # TODO: remove redundant image_pad - # new_image_pad_num = image_embeds.shape[0] - # image_pad_indices = torch.where(input_ids == self.config.image_token_id)[1] - # image_pad_start_idx = image_pad_indices[0] - # image_pad_end_idx = image_pad_indices[-1] - # new_image_pad_end_idx = image_pad_start_idx + new_image_pad_num - # input_ids = torch.cat([input_ids[:, : new_image_pad_end_idx], - # input_ids[:, image_pad_end_idx + 1:]], dim=1) - # # # inputs_embeds = torch.cat([inputs_embeds[:, :new_image_pad_end_idx, :], - # # # inputs_embeds[:, image_pad_end_idx + 1:, :]], dim=1) - # inputs_embeds = self.model.embed_tokens(input_ids) - # image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - # image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - # inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - # new_grid_thw = torch.tensor([[1, 16, 32]]) - # attention_mask = attention_mask[:, :inputs_embeds.shape[1]] - # position_ids, rope_deltas = self.get_rope_index( - # input_ids, new_grid_thw, - # attention_mask=attention_mask) + image_embeds, selected_indices = self.visual(pixel_values, grid_thw=image_grid_thw, + dominant_num=dominant_num, + contextual_num=contextual_num) + if selected_indices is None: + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + else: + # Remove redundant |image_pad| and get selected position ids. 
+ # v2: previous position id + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, + attention_mask=attention_mask, + selected_indices=selected_indices) + attention_mask = attention_mask[:, :inputs_embeds.shape[1]] + + new_image_pad_num = image_embeds.shape[0] + image_pad_indices = torch.where(input_ids == self.config.image_token_id)[1] + image_pad_start_idx = image_pad_indices[0] + image_pad_end_idx = image_pad_indices[-1] + new_image_pad_end_idx = image_pad_start_idx + new_image_pad_num + input_ids = torch.cat([input_ids[:, : new_image_pad_end_idx], + input_ids[:, image_pad_end_idx + 1:]], dim=1) + # # inputs_embeds = torch.cat([inputs_embeds[:, :new_image_pad_end_idx, :], + # # inputs_embeds[:, image_pad_end_idx + 1:, :]], dim=1) + inputs_embeds = self.model.embed_tokens(input_ids) + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + + # v1: random height & width + # attention_mask = attention_mask[:, :inputs_embeds.shape[1]] + # new_grid_thw = torch.tensor([[1, 16, 32]]) + # position_ids, rope_deltas = self.get_rope_index( + # input_ids, new_grid_thw, + # attention_mask=attention_mask) torch.xpu.synchronize() t2 = time.perf_counter() - # # image_token_num = image_embeds.shape[0] - # # dominant = 128 - # # contexual = 10 - # # diff = image_token_num - dominant - contexual - # # inputs_embeds_token = inputs_embeds.shape[1] - # # inputs_embeds = inputs_embeds[:, : inputs_embeds_token - diff, :] print(inputs_embeds.shape, "time1: ", (t2 - t1) * 1000, " ms") if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) From 386dcec2f358072d9a9bde4c967484cd044ecab8 Mon Sep 17 00:00:00 2001 From: cyita Date: Tue, 21 Jan 2025 14:14:37 +0800 Subject: [PATCH 5/6] update replace_func --- python/llm/src/ipex_llm/transformers/convert.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index b533273843e..0ea5d78d9d9 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1181,6 +1181,9 @@ def replace_RotaryEmbed(m, target_m, replace_embed): def replace_func(m, target_m, func_name, new_func): + if m.__class__ == target_m: + bound_method = new_func.__get__(m, m.__class__) + setattr(m, func_name, bound_method) for _, sub_m in m.named_children(): replace_func(sub_m, target_m, func_name, new_func) @@ -1662,12 +1665,12 @@ def _optimize_post(model): convert_forward(model, module.Qwen2VLVisionBlock, qwen2_vl_vision_block_forward) convert_forward(model, module.Qwen2VLForConditionalGeneration, qwen2_vl_conditional_generation_forward) - import types - model._update_model_kwargs_for_generation = types.MethodType(_update_model_kwargs_for_generation, model) - model.get_rope_index = types.MethodType(get_rope_index, model) - model.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, model) - # replace_func(model, module.Qwen2VLForConditionalGeneration, - # "_update_model_kwargs_for_generation", _update_model_kwargs_for_generation) + replace_func(model, module.Qwen2VLForConditionalGeneration, + "_update_model_kwargs_for_generation", _update_model_kwargs_for_generation) + replace_func(model, module.Qwen2VLForConditionalGeneration, + 
"get_rope_index", get_rope_index) + replace_func(model, module.Qwen2VLForConditionalGeneration, + "prepare_inputs_for_generation", prepare_inputs_for_generation) elif model.config.model_type == "aquila": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) From 8ea2beedd374f019b3479486b240869e751e0a77 Mon Sep 17 00:00:00 2001 From: cyita Date: Tue, 21 Jan 2025 15:49:11 +0800 Subject: [PATCH 6/6] update --- .../llm/src/ipex_llm/transformers/convert.py | 3 -- .../ipex_llm/transformers/models/qwen2_vl.py | 31 +++---------------- 2 files changed, 4 insertions(+), 30 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 0ea5d78d9d9..1f6e840f9a5 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1649,7 +1649,6 @@ def _optimize_post(model): from ipex_llm.transformers.models.qwen2_vl import qwen2_vit_pretrained_model_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_vision_block_forward from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_conditional_generation_forward - from ipex_llm.transformers.models.qwen2_vl import _update_model_kwargs_for_generation from ipex_llm.transformers.models.qwen2_vl import get_rope_index from ipex_llm.transformers.models.qwen2_vl import prepare_inputs_for_generation convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward) @@ -1665,8 +1664,6 @@ def _optimize_post(model): convert_forward(model, module.Qwen2VLVisionBlock, qwen2_vl_vision_block_forward) convert_forward(model, module.Qwen2VLForConditionalGeneration, qwen2_vl_conditional_generation_forward) - replace_func(model, module.Qwen2VLForConditionalGeneration, - "_update_model_kwargs_for_generation", _update_model_kwargs_for_generation) replace_func(model, module.Qwen2VLForConditionalGeneration, "get_rope_index", get_rope_index) replace_func(model, module.Qwen2VLForConditionalGeneration, diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index 723200a2fd6..f712fb461eb 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -65,31 +65,6 @@ def merge_qkv(module: torch.nn.Module): merge_qkv_base(module, Qwen2VLAttention) -def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any] = None, - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, -) -> Dict[str, Any]: - model_kwargs = GenerationMixin._update_model_kwargs_for_generation( - self, - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - num_new_tokens=num_new_tokens, - ) - - # if model_kwargs.get("use_cache", True): - # cache_num = outputs.past_key_values.seen_tokens - # model_kwargs['cache_position'] = torch.tensor([cache_num]) - - if getattr(outputs, "rope_deltas", None) is not None: - model_kwargs["rope_deltas"] = outputs.rope_deltas - - return model_kwargs - - def prepare_inputs_for_generation( self, input_ids, @@ -464,7 +439,7 @@ def qwen2_vit_pretrained_model_forward( attention_mean = last_layer_attention.mean(0) # 1564, 1564 attention_mean = attention_mean.reshape(attention_mean.shape[0] * self.spatial_merge_size ** 2, -1) attention_mean = attention_mean.mean(0) # 391 - top_k_indices = attention_mean.topk(dominant_num // 4, dim=0).indices + top_k_indices = attention_mean.topk(dominant_num, 
dim=0).indices # TODO: get height & width # interval_size = 22 # ranges = [(start, start + interval_size - 1) for start in range(0, 391, interval_size)] @@ -492,7 +467,9 @@ def qwen2_vit_pretrained_model_forward( dtype=torch.bool, device=hidden_states.device).scatter_(0, top_k_indices, False) - dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(dominant_num, hidden_states.shape[1]) + dominant_tokens = hidden_states.masked_select( + ~mask.unsqueeze(-1) + ).view(dominant_num * self.spatial_merge_size ** 2, hidden_states.shape[1]) hidden_ststes_save = dominant_tokens.to(hidden_states.dtype) else:
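The core of this series is the VisionZip-style pruning inside qwen2_vit_pretrained_model_forward. Below is a minimal, self-contained sketch of the selection idea as it stands after patch 6 (score tokens by the attention they receive in the last vision block, rank 2x2 merge groups, keep the top groups). The sizes and random tensors are illustrative stand-ins, and the diff itself still carries several experimental scoring variants (v1/v2/v3) that this sketch does not reproduce.

import torch

num_heads, seq_len, hidden = 16, 1564, 1280       # shapes echoed in the diff comments
merge_group = 4                                    # spatial_merge_size ** 2 (2x2 merge)
dominant_num = 128                                 # merged groups to keep -> 512 patch tokens

hidden_states = torch.randn(seq_len, hidden)
attn = torch.randn(num_heads, seq_len, seq_len).softmax(dim=-1)   # last-block attention

# Score each key token by the attention it receives, averaged over heads and queries,
# then score each group of 4 consecutive tokens (one 2x2 merge window) by its mean.
token_scores = attn.mean(dim=0).mean(dim=0)                        # [seq_len]
group_scores = token_scores.view(seq_len // merge_group, merge_group).mean(dim=1)

# Keep the best groups and expand the group indices back to per-token indices.
top_groups = group_scores.topk(dominant_num).indices               # [dominant_num]
token_idx = (top_groups[:, None] * merge_group
             + torch.arange(merge_group)).reshape(-1).sort().values

dominant_tokens = hidden_states[token_idx]         # [dominant_num * 4, hidden]
print(dominant_tokens.shape)                       # torch.Size([512, 1280])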
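The cu_seqlens construction in the same function is easiest to read with concrete numbers: every image or video contributes t copies of h*w patch tokens, and the padded cumulative sum gives the [start, end) boundaries that the vision attention uses so separate images and frames do not attend to each other. The grid values here are made up for illustration.

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 16, 32],    # one image: 16 x 32 = 512 patch tokens
                         [2,  4,  6]])   # one video: 2 frames of 4 x 6 = 24 tokens each

cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                     grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)    # tensor([  0, 512, 536, 560], dtype=torch.int32)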
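Patch 5 moves the class check in replace_func from the children loop onto the module itself, because Qwen2VLForConditionalGeneration is the top-level model and never appears in named_children() of anything. A toy reproduction of the binding mechanism (the classes below are stand-ins, not the real ipex-llm helpers):

import torch

class Leaf(torch.nn.Module):
    def describe(self):
        return "original"

class Root(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.leaf = Leaf()
    def describe(self):
        return "original"

def new_describe(self):
    return f"patched {self.__class__.__name__}"

def replace_func(m, target_m, func_name, new_func):
    # Check the module itself first (the patch 5 fix), then recurse into children.
    if m.__class__ == target_m:
        setattr(m, func_name, new_func.__get__(m, m.__class__))   # bind as a method
    for _, sub_m in m.named_children():
        replace_func(sub_m, target_m, func_name, new_func)

root = Root()
replace_func(root, Root, "describe", new_describe)
print(root.describe())        # "patched Root"
print(root.leaf.describe())   # "original" -- Leaf was not targeted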
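Patch 4's subject, "support generate with dominant/contextual num", refers to the kwargs that prepare_inputs_for_generation now forwards into the model and then into the visual encoder. The intended call pattern is presumably along these lines; `model` and `inputs` are assumed to be an already optimized Qwen2VLForConditionalGeneration and its processor output, and whether generate() passes the extra kwargs through untouched depends on the installed transformers version.

generated_ids = model.generate(
    **inputs,             # input_ids, attention_mask, pixel_values, image_grid_thw
    max_new_tokens=64,
    dominant_num=128,     # merged 2x2 groups to keep (512 patch tokens before merging)
    contextual_num=10,    # plumbed through by the patches but not yet used for selection
)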