-
Notifications
You must be signed in to change notification settings - Fork 166
/
Copy pathtrain_atepc_chinese.py
32 lines (27 loc) · 1.42 KB
/
train_atepc_chinese.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding: utf-8 -*-
# file: train_atepc_chinese.py
# time: 2021/5/21 0021
# author: yangheng <[email protected]>
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
########################################################################################################################
# ATEPC training script #
########################################################################################################################
from pyabsa.functional import ATEPCModelList
from pyabsa.functional import Trainer, ATEPCTrainer
from pyabsa.functional import ABSADatasetList
from pyabsa.functional import ATEPCConfigManager
def _build_chinese_atepc_config():
    """Return PyABSA's Chinese ATEPC preset with this script's overrides applied.

    The preset is obtained from ``ATEPCConfigManager.get_atepc_config_chinese()``.
    Each (field, value) pair in ``overrides`` is then assigned onto it in the
    listed order. This is equivalent to a sequence of ``config.field = value``
    statements.
    """
    config = ATEPCConfigManager.get_atepc_config_chinese()
    # Field semantics are defined by PyABSA. NOTE(review): confirm the effect
    # of values such as log_step=-1 and evaluate_begin=0 against the installed
    # PyABSA version.
    overrides = (
        ('model', ATEPCModelList.FAST_LCF_ATEPC),
        ('evaluate_begin', 0),
        ('pretrained_bert', 'bert-base-chinese'),  # pretrained encoder identifier
        ('log_step', -1),
        ('l2reg', 1e-5),  # presumably the L2 regularization weight
        ('num_epoch', 30),
        ('cache_dataset', False),
    )
    for field_name, field_value in overrides:
        setattr(config, field_name, field_value)
    return config


atepc_config_chinese = _build_chinese_atepc_config()

# Dataset collection handed to the trainer below.
chinese_sets = ABSADatasetList.Chinese

# Run training with the configuration above and keep whatever
# load_trained_model() returns (presumably the trained aspect extractor).
# NOTE(review): the meaning of checkpoint_save_mode=1 and auto_device=True is
# defined by PyABSA's Trainer; confirm there.
aspect_extractor = Trainer(
    config=atepc_config_chinese,
    dataset=chinese_sets,
    checkpoint_save_mode=1,
    auto_device=True,
).load_trained_model()