|
 import torch
 import torch.nn as nn
-from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
-                          PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
+from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
+                          PretrainedConfig, PreTrainedModel,
+                          Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
     Qwen2_5_VisionTransformerPretrainedModel
@@ -333,9 +334,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         # Currently, copying vision encoder on all devices.
         # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM.
         hf_model_config = AutoConfig.from_pretrained(model_path)
-        vision_model = model_class(config=hf_model_config.vision_config,
-                                   torch_dtype=pretrained_config.torch_dtype,
-                                   attn_implementation='flash_attention_2')
+        vision_model = model_class._from_config(
+            hf_model_config.vision_config,
+            torch_dtype=pretrained_config.torch_dtype,
+            attn_implementation='flash_attention_2')
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
         self.visual = vision_model.to(self.device)
         self.post_config()
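For reference, a minimal standalone sketch of the construction path this diff switches to (the checkpoint name, dtype, and device below are placeholders, not taken from the diff): PreTrainedModel._from_config() builds the Qwen2.5-VL vision tower from its sub-config alone, without loading checkpoint weights, and applies torch_dtype and attn_implementation='flash_attention_2' at construction time instead of passing them to the class constructor.

import torch
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
    Qwen2_5_VisionTransformerPretrainedModel

# Placeholder checkpoint id; the diff resolves model_path elsewhere.
hf_model_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# Build only the vision encoder from its sub-config (no weight loading here);
# dtype and attention backend are set while the module is constructed.
vision_model = Qwen2_5_VisionTransformerPretrainedModel._from_config(
    hf_model_config.vision_config,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2')

# Placeholder device; the diff moves the module to self.device instead.
vision_model = vision_model.to('cuda')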
|