Skip to content

Latest commit

 

History

History
127 lines (90 loc) · 4.15 KB

t5.md

File metadata and controls

127 lines (90 loc) · 4.15 KB

T5

模型描述

T5:全名Text-to-Text Transfer Transformer模型是谷歌在2019年基于C4数据集训练的Transformer模型。

论文C Raffel,N Shazeer,A Roberts,K Lee,S Narang,M Matena,Y Zhou,W Li,PJ Liu, 2020

数据集准备

使用的数据集:WMT16

对应的文件路径如下:

└── wmt_en_ro
    ├── test.source
    ├── test.target
    ├── train.source
    ├── train.target
    ├── val.source
    └── val.target

快速使用

脚本启动

需开发者提前clone工程。

示例命令如下,将会执行一个只有1层的T5模型训练

python run_mindformer.py --config configs/t5/run_t5_tiny_on_wmt16.yaml --run_mode train  \
                         --device_target Ascend \
                         --train_dataset_dir /your_path/wmt_en_ro

其中device_target根据用户的运行设备不同,可选GPU/Ascend/CPUconfig的入参还可以为configs/t5/run_t5_small.yaml,在 这个配置下将会加载t5_small的权重并且开始执行微调。

调用API启动

需开发者提前pip安装。具体接口说明请参考API接口

Model调用接口

  • 模型计算Loss
from mindformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained('t5_small')
tokenizer = T5Tokenizer.from_pretrained('t5_small')

src_output = tokenizer(["hello world"], padding='max_length', max_length=model.config.seq_length,
                       return_tensors='ms')

model_input = tokenizer(["So happy to see you!"], padding='max_length', max_length=model.config.max_decode_length,
                        return_tensors='ms')["input_ids"]
input_ids = src_output['input_ids']
attention_mask = src_output['attention_mask']
output = model(input_ids, attention_mask, model_input)
print(output)
# [5.64458]
  • 推理

执行下述的命令,可以自动云上拉取t5_small模型并且进行推理。

from mindformers import T5ForConditionalGeneration, T5Tokenizer

t5 = T5ForConditionalGeneration.from_pretrained("t5_small")
tokenizer = T5Tokenizer.from_pretrained("t5_small")
words = tokenizer("translate the English to the Romanian: UN Chief Says There Is No Military "
                  "Solution in Syria")['input_ids']
output = t5.generate(words, do_sample=False)
output = tokenizer.decode(output, skip_special_tokens=True)
print(output)
# "eful ONU declară că nu există o soluţie militară în Siri"
  • Trainer接口开启训练/预测:
import mindspore; mindspore.set_context(mode=0, device_id=0)
from mindformers.trainer import Trainer
# 初始化预训练任务
trainer = Trainer(task='translation', model='t5_small', train_dataset="your data file path")

# 方式1: 开启训练,并使用训练好的权重进行推理
trainer.train()
res = trainer.predict(predict_checkpoint=True, input_data="translate the English to Romanian: a good boy!")
print(res)
#[{'translation_text': ['un băiat bun!']}]

# 方式2: 从obs下载训练好的权重并进行推理
res = trainer.predict(input_data="translate the English to Romanian: a good boy!")
print(res)
#[{'translation_text': ['un băiat bun!']}]
  • pipeline接口开启快速推理
from mindformers.pipeline import pipeline
pipeline_task = pipeline("translation", model='t5_small')
pipeline_result = pipeline_task("translate the English to Romanian: a good boy!", top_k=3)
print(pipeline_result)
#[{'translation_text': ['un băiat bun!']}]

模型权重

本仓库中的t5_small来自于HuggingFace的t5_small, 基于下述的步骤获取:

  1. 从上述的链接中下载t5_small的HuggingFace权重,文件名为pytorch_model.bin

  2. 执行转换脚本,得到转换后的输出文件mindspore_t5.ckpt

python mindformers/models/t5/convert_weight.py --layers 6 --torch_path pytorch_model.bin --mindspore_path ./mindspore_t5.ckpt