使用视觉和语言指令训练一个多模态聊天机器人!
基于开源多模态模型 OpenFlamingo,我们使用公开数据集创建了各种视觉指令数据,包括视觉问答、图像字幕、视觉推理、文本 OCR 和视觉对话。此外,我们还使用仅包含语言指令数据的语言模型组件进行了训练。
视觉和语言指令的联合训练有效提高了模型的性能!更多细节请参阅我们的技术报告。
欢迎加入我们!
English | 简体中文
- 支持各种视觉和语言指令数据
- 使用 LoRA 进行参数高效微调
- 同时调整视觉和语言,相互补充
在一个已有环境中安装依赖包,运行以下指令
git clone https://github.com/open-mmlab/Multimodal-GPT.git
cd Multimodal-GPT
pip install -r requirements.txt
pip install -v -e .
或者创建一个新的 conda 环境
conda env create -f environment.yml
-
下载预训练权重
使用这个脚本把 LLaMA 权重转换成 HuggingFace 格式。
从 openflamingo/OpenFlamingo-9B 下载 OpenFlamingo 预训练模型。
从这个链接 下载我们的 LoRA 权重。
然后把所有模型权重放到
checkpoints
文件夹下,目录结构如下:checkpoints ├── llama-7b_hf │ ├── config.json │ ├── pytorch_model-00001-of-00002.bin │ ├── ...... │ └── tokenizer.model ├── OpenFlamingo-9B │ └──checkpoint.pt ├──mmgpt-lora-v0-release.pt
-
启动 gradio demo
python app.py
-
从这个链接下载标注,解压到
data/aokvqa/annotations
路径下。同时还需要 coco 数据集的图像,可以从这里下载。
-
从这个链接,解压到
data/coco
路径下。同时还需要 coco 数据集的图像,可以从这里下载。
-
从 这个链接 下载数据集,放到
data/OCR_VQA/
路径下。 -
从 liuhaotian/LLaVA-Instruct-150K 下载数据集,放到
data/llava/
路径下。同时还需要 coco 数据集的图像,可以从这里下载。
-
从 Vision-CAIR/cc_sbu_align 下载数据集,放到
data/cc_sbu_align/
路径下。 -
从 databricks/databricks-dolly-15k 下载数据集,放到
data/dolly/databricks-dolly-15k.jsonl
路径下。 -
从这个链接 下载数据集,放到
data/alpaca_gpt4/alpaca_gpt4_data.json
路径下。
你也可以在 configs/dataset_config.py 文件中自定义数据集路径。
torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \
--lm_path checkpoints/llama-7b_hf \
--tokenizer_path checkpoints/llama-7b_hf \
--pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
--run_name train-my-gpt4 \
--learning_rate 1e-5 \
--lr_scheduler cosine \
--batch_size 1 \
--tuning_config configs/lora_config.py \
--dataset_config configs/dataset_config.py \
--report_to_wandb
如果你觉得我们的项目对你的研究和应用有帮助,请用以下 BibTeX 进行引用
@misc{gong2023multimodalgpt,
title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans},
author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen},
year={2023},
eprint={2305.04790},
archivePrefix={arXiv},
primaryClass={cs.CV}
}