Skip to content

JunnYu/GPLinker_pytorch

Folders and files

NameName
Last commit message
Last commit date

Latest commit

e00545b · May 10, 2022

History

12 Commits
Mar 3, 2022
May 10, 2022
Mar 3, 2022
Mar 3, 2022
Mar 3, 2022
Mar 3, 2022
Feb 25, 2022
May 10, 2022
May 10, 2022
Mar 3, 2022
Mar 3, 2022

Repository files navigation

GPLinker_pytorch

GPLinker_pytorch

介绍

这是pytorch版本的GPLinker代码以及TPLinker_Plus代码。

更新

  • 2022/03/03 添加tplinker_plus+bert-base-chinese权重在duie_v1上的结果。添加duee_v1任务的训练代码,请查看duee_v1目录
  • 2022/03/01 添加tplinker_plus+hfl/chinese-roberta-wwm-ext权重在duie_v1上的结果。
  • 2022/02/25 现已在Dev分支更新最新的huggingface全家桶版本的代码,main分支是之前旧的代码(执行效率慢)

结果

Tips: 在RTX309020epoch的条件下,gplinker需要训练5-6htplinker_plus则需要训练16-17h

dataset method pretrained_model_name_or_path f1 precision recall
duie_v1 gplinker hfl/chinese-roberta-wwm-ext 0.8214065255731926 0.8250077498782166 0.8178366038895478
duie_v1 gplinker bert-base-chinese 0.8198087178424598 0.8146470447994109 0.8250362175688137
duie_v1 tplinker_plus hfl/chinese-roberta-wwm-ext 0.8256425523469291 0.8295114656031908 0.8218095614381671
duie_v1 tplinker_plus bert-base-chinese 0.8216261688290682 0.8076458240569943 0.8360990385881737

Tensorboard日志

gplinker训练日志

tplinker_plus训练日志

依赖

所需的依赖如下:

  • fastcore==1.3.29
  • datasets==1.18.3
  • transformers>=4.16.2
  • accelerate==0.5.1
  • chinesebert==0.2.1

安装依赖requirements.txt

pip install -r requirements.txt

准备数据

http://ai.baidu.com/broad/download?dataset=sked 下载数据。

train_data.jsondev_data.json压缩成spo.zip文件,并且放入data文件夹。

当前data/spo.zip文件是本人提供精简后的数据集,其中train_data.json只有2000条数据,dev_data.json只有200条数据。

运行

accelerate launch train.py \
    --model_type bert \
    --pretrained_model_name_or_path bert-base-chinese \
    --method gplinker \
    --logging_steps 200 \
    --num_train_epochs 20 \
    --learning_rate 3e-5 \
    --num_warmup_steps_or_radios 0.1 \
    --gradient_accumulation_steps 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --seed 42 \
    --save_steps 10804 \
    --output_dir ./outputs \
    --max_length 128 \
    --topk 1 \
    --num_workers 6

其中使用到参数介绍如下:

  • model_type: 表示模型架构类型,像bert-base-chinesehfl/chinese-roberta-wwm-ext模型都是基于bert架构,junnyu/roformer_chinese_char_base是基于roformer架构,可选择["bert", "roformer", "chinesebert"]
  • pretrained_model_name_or_path: 表示加载的预训练模型权重,可以是本地目录,也可以是huggingface.co的路径。
  • method: 表示使用的方法, 可选择["gplinker", "tplinker_plus"]
  • logging_steps: 日志打印的间隔,默认为200
  • num_train_epochs: 训练轮数,默认为20
  • learning_rate: 学习率,默认为3e-5
  • num_warmup_steps_or_radios: warmup步数或者比率,当为浮点类型时候表示的是radio,当为整型时候表示的是step,默认为0.1
  • gradient_accumulation_steps: 梯度累计的步数,默认为1
  • per_device_train_batch_size: 训练的batch_size,默认为16
  • per_device_eval_batch_size: 评估的batch_size,默认为32
  • seed: 随机种子,以便于复现,默认为42
  • save_steps: 保存步数,每隔多少步保存模型。
  • output_dir: 模型输出路径。
  • max_length: 句子的最大长度,当大于这个长度时候,tokenizer会进行截断处理。
  • topk: 保存topk个数模型,默认为1
  • num_workers: dataloadernum_workers参数,linux系统下发现GPU使用率不高的时候可以尝试设置这个参数大于0,而windows下最好设置为0,不然会报错。
  • use_efficient: 是否使用EfficientGlobalPointer,默认为False

Reference

About

GPLinker_pytorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published