BUG when enable LoRA in "./llava/train/train.py" #34

@Marsir04

Description

After setting:

lora_enable: bool = True

at line 118 of "./llava/train/train.py", the following error occurs:

ValueError: Target module LlamaDecoderLayer(
  (self_attn): LlamaFlashAttention2(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
    (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
    (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm()
  (post_attention_layernorm): LlamaRMSNorm()
) is not supported. Currently, only the following modules are supported: `torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`.

In fact, this is because the module 'video_tower' should not be adapted by LoRA, yet it is not filtered out by the function find_all_linear_names.

Therefore, it is necessary to change the code at line 185 of "./llava/train/train.py" from:

multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']

into:

multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler', 'video_tower']

so that the module 'video_tower' is ignored when collecting LoRA target modules.
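The effect of the extended keyword list can be sketched as follows. This is a minimal sketch, not the repository's exact code: find_all_linear_names here mirrors the general shape of the helper in train.py, and the Toy model is a hypothetical stand-in for the real multimodal model:

```python
import torch.nn as nn

# Keywords for submodules that must stay out of LoRA; 'video_tower' is the
# addition proposed here (the other three match the original list at line 185).
MULTIMODAL_KEYWORDS = ['mm_projector', 'vision_tower', 'vision_resampler', 'video_tower']

def find_all_linear_names(model):
    """Collect the leaf names of nn.Linear layers, skipping multimodal branches."""
    lora_module_names = set()
    for name, module in model.named_modules():
        # Skip anything under a multimodal submodule so LoRA never targets it.
        if any(keyword in name for keyword in MULTIMODAL_KEYWORDS):
            continue
        if isinstance(module, nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard('lm_head')  # keep the output head out of LoRA
    return sorted(lora_module_names)

# Hypothetical toy model standing in for the real model:
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)                      # a LoRA target
        self.video_tower = nn.Sequential(nn.Linear(8, 8))  # must be excluded

print(find_all_linear_names(Toy()))  # the video_tower Linear is filtered out
```

Without 'video_tower' in the keyword list, the Linear inside the tower would be picked up and passed to peft as a target, which is what triggers the ValueError above.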
