-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
74 lines (66 loc) · 3 KB
/
run.py
File metadata and controls
74 lines (66 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from main import WuPreTrainer, CantoPreTrainer, CantoNLIFineTuner, CantoPOSFineTuner, CantoDEPSFineTuner, CantoTokenClassificationFineTuner, CantoAcceptabilityFineTuner
from argparse import ArgumentParser
from transformers import Trainer
def run(args):
    """Run the pre-training, fine-tuning and/or evaluation stages selected by ``args``.

    The three stages are checked independently, in the order pre-train,
    fine-tune, evaluate; each one runs only when its flag is truthy.
    Unsupported languages or tasks are reported on stdout and skipped.

    Args:
        args: Parsed command-line namespace. ``pretrain``, ``finetune`` and
            ``eval_only`` are always read; ``lang``, ``model_dir``,
            ``scratch``, ``data`` and ``task`` are read by the stages that
            need them.

    Raises:
        ValueError: If Cantonese ("yue") pre-training is requested with a
            ``data`` value other than "wiki" or "cantonese-sentences".
    """
    if args.pretrain:
        if args.lang == "yue":
            # Validate with an explicit raise rather than ``assert``:
            # assertions are removed under ``python -O``, which would let an
            # invalid dataset name through to the trainer.
            if args.data not in ("wiki", "cantonese-sentences"):
                raise ValueError(
                    f"{args.data} is not a valid dataset. Choose between 'wiki', 'cantonese-sentences'"
                )
            model = CantoPreTrainer(model_dir=args.model_dir, scratch=args.scratch, data=args.data)
            model.train()
        elif args.lang == "wuu":
            model = WuPreTrainer(model_dir=args.model_dir)
            model.train()
        else:
            print(f"{args.lang} pre-training is not supported. Please choose from: yue, wuu")

    if args.finetune:
        if args.lang == "yue":
            # Task name -> fine-tuner class; every entry is constructed and
            # driven the same way, so a table replaces the if/elif chain.
            fine_tuners = {
                "nli": CantoNLIFineTuner,
                "pos": CantoPOSFineTuner,
                "deps": CantoDEPSFineTuner,
                "accept": CantoAcceptabilityFineTuner,
            }
            tuner_cls = fine_tuners.get(args.task)
            if tuner_cls is None:
                print(f"{args.task} fine-tuning is not supported. Please choose from: pos, nli, deps, accept")
            else:
                model = tuner_cls(args.lang, model_dir=args.model_dir)
                model.finetune()
        else:
            print(f"{args.lang} fine-tuning is not supported. Please choose from: yue")

    if args.eval_only:
        if args.lang == "yue":
            # NOTE(review): evaluation always uses the NLI fine-tuner;
            # ``args.task`` is not consulted here - confirm this is intended.
            model = CantoNLIFineTuner(args.lang, model_dir=args.model_dir, eval_only=True)
            trainer = Trainer(
                model=model.model,
                args=model.training_args,
                eval_dataset=model.finetune_dataset["test"],
            )
            model.eval(trainer)
        else:
            print(f"{args.lang} evaluating is not supported. Please choose from: yue")
if __name__ == "__main__":
    # Script entry point: parse the command line, derive implied flags,
    # echo the final namespace and hand it to run().
    parser = ArgumentParser()
    parser.add_argument("--lang", default="yue")
    parser.add_argument("--model_dir", default="./models/bert-base-chinese-local")
    # Boolean switches; ``store_true`` already defaults each one to False.
    for switch in ("--pretrain", "--scratch", "--finetune", "--eval_only"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--data", type=str, default="wiki")
    parser.add_argument("--task", type=str, default="")
    args = parser.parse_args()

    # Add your custom arguments for IDE tests here

    # --scratch implies pre-training.
    if args.scratch:
        args.pretrain = True
    # --eval_only switches fine-tuning off, even when a task was given;
    # otherwise a non-empty --task implies fine-tuning.
    if args.eval_only:
        args.finetune = False
    elif args.task:
        args.finetune = True

    print(args)
    run(args)