-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathconfig.py
37 lines (31 loc) · 938 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import os
# Central experiment configuration, built up section by section so that each
# group of settings (data / model / training) is easy to locate and edit.
config = {}

# --- data -------------------------------------------------------------------
config.update(
    {
        # Resolved against the current working directory at import time.
        "data_dir": os.path.join(os.getcwd(), "data"),
        "dataset": "switchboard",
        "text_field": "Text",
        # Alternative label column previously used here: "act_label_1".
        "label_field": "DamslActTag",
        "max_len": 256,
        "batch_size": 64,
        "num_workers": 4,
    }
)

# --- model ------------------------------------------------------------------
config.update(
    {
        "model_name": "roberta-base",
        "hidden_size": 768,
        # There are 43 classes in the switchboard corpus.
        "num_classes": 43,
    }
)

# --- training ---------------------------------------------------------------
config.update(
    {
        "save_dir": "./",
        "project": "dialogue-act-classification",
        "run_name": "context-aware-attention-dac",
        "lr": 1e-5,
        "monitor": "val_accuracy",
        "min_delta": 0.001,
        # NOTE(review): ":4f" is a minimum field width of 4 with the default
        # six decimal places (not ".4f"); the restart checkpoint name below
        # carries six decimals, which is consistent with this template.
        "filepath": "./checkpoints/{epoch}-{val_accuracy:4f}",
        "precision": 32,
        "average": "micro",
        "epochs": 100,
        # Use CUDA when torch reports it as available, otherwise the CPU.
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "restart": False,
        "restart_checkpoint": "./checkpoints/epoch=10-val_accuracy=0.720291.ckpt",
    }
)