diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py index d51a5aca9a6..d8b03ef1bf0 100644 --- a/vllm_ascend/models/qwen2_5_vl_without_padding.py +++ b/vllm_ascend/models/qwen2_5_vl_without_padding.py @@ -19,6 +19,7 @@ from functools import partial from typing import Callable, Optional +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -444,14 +445,21 @@ def cal_cos_sin(self, rotary_pos_emb): def forward( self, x: torch.Tensor, - grid_thw: list[list[int]], + grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: hidden_states = x.to(device=self.device, dtype=self.dtype) hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = np.array(grid_thw, dtype=np.int32) + else: + grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw_list) grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)