-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels_evaluate.py
116 lines (102 loc) · 2.53 KB
/
models_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import model_selection
def plot_loss(trained_model, xlabel, *, ylim=(0.1, 1)):
    """Plot training and validation loss per epoch and mark the best epoch.

    Parameters
    ----------
    trained_model : object with a ``history`` dict, or that dict itself
        Presumably the ``History`` object returned by Keras ``model.fit``
        (assumption -- verify against callers). The history must contain
        equal-length per-epoch lists under ``'loss'`` and ``'val_loss'``.
    xlabel : str
        Label for the x axis.
    ylim : (float, float), optional
        y-axis limits. Defaults to the previously hard-coded ``(0.1, 1)``.
        Loss values outside this range are clipped from view.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Accept either a History-like object or its raw dict. Copy it so that
    # adding the 'epoch' column does not mutate the caller's history.
    history = dict(getattr(trained_model, 'history', trained_model))
    history['epoch'] = list(range(len(history['val_loss'])))
    model_history = pd.DataFrame.from_dict(history)

    # The best epoch is the one with the lowest validation loss.
    best_epoch = model_history.loc[model_history['val_loss'].idxmin(), 'epoch']

    # A single axes: both curves and the marker line share one plot.
    # The original requested a 1x2 grid, which yields an array of axes.
    _, ax = plt.subplots(1, 1, figsize=(10, 3))
    sns.lineplot(
        x='epoch',
        y='val_loss',
        ax=ax,
        data=model_history,
        label='Validation',
    )
    sns.lineplot(
        x='epoch',
        y='loss',
        ax=ax,
        data=model_history,
        label='Training',
    )
    ax.axvline(
        x=best_epoch,
        linestyle='--',
        color='green',
        label='Best Epoch',
    )
    ax.legend(loc=1)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss (Fraction)')
    plt.show()
def plot_accuracy(trained_model, xlabel, *, ylim=(0.5, 1), chance=0.5):
    """Plot training and validation accuracy per epoch, with chance and best-epoch lines.

    Parameters
    ----------
    trained_model : object with a ``history`` dict, or that dict itself
        Presumably the ``History`` object returned by Keras ``model.fit``
        (assumption -- verify against callers). Passing the raw history dict,
        as the original code required, still works. The history must contain
        equal-length per-epoch lists under ``'accuracy'`` and ``'val_accuracy'``.
    xlabel : str
        Label for the x axis.
    ylim : (float, float), optional
        y-axis limits. Defaults to the previously hard-coded ``(0.5, 1)``.
    chance : float, optional
        y position of the dashed "Chance" reference line. The default 0.5
        matches the original and suits a balanced binary task; adjust it for
        other class counts.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Accept either a History-like object or its raw dict. Copy it so that
    # adding the 'epoch' column does not mutate the caller's history.
    history = dict(getattr(trained_model, 'history', trained_model))
    history['epoch'] = list(range(len(history['val_accuracy'])))
    model_history = pd.DataFrame.from_dict(history)

    # The best epoch is the one with the highest validation accuracy.
    best_epoch = model_history.loc[model_history['val_accuracy'].idxmax(), 'epoch']

    _, ax = plt.subplots(1, 1)
    sns.lineplot(
        x='epoch',
        y='val_accuracy',
        ax=ax,
        data=model_history,
        label='Validation',
    )
    sns.lineplot(
        x='epoch',
        y='accuracy',
        ax=ax,
        data=model_history,
        label='Training',
    )
    ax.axhline(
        chance,
        linestyle='--',
        color='red',
        label='Chance',
    )
    ax.axvline(
        x=best_epoch,
        linestyle='--',
        color='green',
        label='Best Epoch',
    )
    ax.legend(loc=1)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Accuracy (Fraction)')
    plt.show()