-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_loss.py
77 lines (72 loc) · 1.69 KB
/
plot_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import matplotlib.pyplot as plt
import numpy as np
# Runs to compare, one (CSV path, legend label) pair per row. Keeping path and
# label on the same row means a run is toggled by (un)commenting one line and
# the two derived lists can never drift out of alignment.
_RUNS = [
    # ('CELoss/t16k/m16-l5-lrd-res1.0/loss.csv', 'CE16r1.0'),
    # ('t16k/m16-l5-lr5d-res0.5/loss.csv', 'MSE16r0.5'),
    ('t48k/m16-l5-lr5d-res0.5/loss.csv', 'MSE48r0.5'),
    # ('t48k/m16-l5-lr5d-res1.0/loss.csv', 'MSE48r1.0'),
]
in_files = [path for path, _ in _RUNS]
in_labels = [label for _, label in _RUNS]

# One matplotlib single-letter colour per run, assigned in order ('r' used to
# lead this list and is currently disabled).
in_colors = list('bygcmk')

# Layout of each loss.csv: (column label, matplotlib format string), listed in
# column order. The t-/v- pairs (presumably train/validation -- confirm) are
# drawn solid vs. dashed and share a marker.
_COLUMNS = [
    ('epoch', ''),      # 0  -- x axis for every curve
    ('tloss', '-'),     # 1
    ('vloss', '--'),    # 2
    ('thit', '-'),      # 3
    ('vhit', '--'),     # 4
    ('teff', '-o'),     # 5
    ('veff', '--o'),    # 6
    ('tpur', '-^'),     # 7
    ('vpur', '--^'),    # 8
    ('tloose', '-'),    # 9
    ('vloose', '--'),   # 10
]
col_labels = [label for label, _ in _COLUMNS]
col_symbol = [fmt for _, fmt in _COLUMNS]
fontsize = 24


def _load_csv(path):
    """Read one comma-separated metrics file into a 2-D float array.

    ``np.genfromtxt`` returns a 1-D array when the file holds a single row,
    which would make the ``data[:, col]`` indexing below fail; the result is
    therefore promoted to 2-D (rows x columns).
    """
    return np.atleast_2d(np.genfromtxt(path, delimiter=','))


def _plot_columns(ax, runs, columns):
    """Plot the given columns of every run against column 0 on *ax*.

    Args:
        ax: matplotlib Axes to draw on.
        runs: iterable of ``(data, label, color)`` tuples, where ``data`` is the
            2-D array from ``_load_csv``.
        columns: column indices to draw; each curve uses ``col_symbol[icol]``
            as its format string and is labelled ``'<run label>:<column label>'``.
    """
    for data, run_label, run_color in runs:
        for icol in columns:
            ax.plot(data[:, 0], data[:, icol], col_symbol[icol], c=run_color,
                    label='{}:{}'.format(run_label, col_labels[icol]))


# Load every run once and reuse the arrays for both panels (previously each
# CSV was re-read per subplot). zip() stops at the shortest list, so surplus
# colours are simply unused.
runs = [
    (_load_csv(in_file), in_label, in_color)
    for in_file, in_label, in_color in zip(in_files, in_labels, in_colors)
]

# Left panel: tloss / vloss (columns 1 and 2).
ax = plt.subplot(121)
ax.set_title('loss', fontsize=fontsize)
_plot_columns(ax, runs, [1, 2])
ax.legend(loc='best', fontsize=fontsize)
ax.grid()
ax.set_xlabel("Epoch", fontsize=fontsize)
ax.set_ylabel("Mean Loss", fontsize=fontsize)

# Right panel. NOTE(review): titled 'hit rate' but it plots columns 9/10
# (tloose/vloose); thit/vhit are columns 3/4 -- confirm which is intended.
ax = plt.subplot(122)
ax.set_title('hit rate', fontsize=fontsize)
_plot_columns(ax, runs, [9, 10])
# ax.legend(loc='best', fontsize=fontsize)
ax.grid()
ax.set_ylim(0, 1)
ax.set_xlabel("Epoch", fontsize=fontsize)
# ax.set_ylabel("Hit rate", fontsize=fontsize)
plt.show()