-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_utils.py
More file actions
74 lines (64 loc) · 2.54 KB
/
plot_utils.py
File metadata and controls
74 lines (64 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import matplotlib.pyplot as plt
import numpy as np
def plot_all_and_first_points(X, N):
    """Show ``X`` in full next to a zoomed-in view of its first ``N`` rows.

    A 1x2 figure (15x5 inches) is created. The left panel plots all of ``X``;
    the right panel plots ``X[:N]``. For a 2-D array matplotlib draws one
    line per column. Presumably ``X`` is laid out as (time, series), since
    both x-axes are labelled 'T'; TODO confirm with callers. The layout is
    tightened and the figure is displayed with ``plt.show()``.

    Args:
        X: Data to plot; anything ``Axes.plot`` accepts and that supports slicing.
        N: Number of leading entries of ``X`` shown in the right panel.
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))
    panels = (
        (X, 'All time series'),
        (X[:N], "First {} time points".format(N)),
    )
    for axis, (data, title) in zip(axes, panels):
        axis.plot(data)
        axis.set_xlabel('T')
        axis.set_title(title)
    plt.tight_layout()
    plt.show()
def plot_gc_graphs(actual, est):
    """Print agreement metrics for two Granger-causality matrices and plot them.

    Entries are compared one by one:

    - true positive: ``actual == 1`` and ``est == 1``
    - false positive: ``actual == 0`` and ``est == 1``
    - false negative: ``actual == 1`` and ``est == 0``

    Presumably both matrices are binary (0/1); TODO confirm with callers.

    Printed metrics, all as percentages:

    - true and estimated variable usage (the mean of each matrix)
    - accuracy (the fraction of equal entries)
    - precision, recall and F1

    When a ratio's denominator is zero, that ratio is reported as 0.00%
    instead of raising ``ZeroDivisionError``. This happens, for example, with
    an all-zero ``est`` or ``actual``, or with no true positives.

    Two heat maps are then shown side by side. Rows are labelled
    "Affected series" and columns "Causal series". Entries where ``est``
    disagrees with ``actual`` are outlined in red on the estimated panel.

    Args:
        actual: Ground-truth matrix; array-like, converted with ``np.asarray``.
        est: Estimated matrix with the same shape as ``actual``. The plot
            extent uses ``len(est)`` for both axes, so a square matrix is
            assumed.

    Raises:
        ValueError: If ``actual`` and ``est`` have different shapes.
    """
    actual = np.asarray(actual)
    est = np.asarray(est)
    if actual.shape != est.shape:
        raise ValueError(
            'actual and est must have the same shape, got {} and {}'.format(
                actual.shape, est.shape))

    # Confusion counts, computed once with vectorised masks rather than
    # re-scanning the matrices in nested Python loops for every metric.
    true_pos = int(np.sum((actual == 1) & (est == 1)))
    false_pos = int(np.sum((actual == 0) & (est == 1)))
    false_neg = int(np.sum((actual == 1) & (est == 0)))

    print('True variable usage = %.2f%%' % (100 * np.mean(actual)))
    print('Estimated variable usage = %.2f%%' % (100 * np.mean(est)))
    print('Accuracy = %.2f%%' % (100 * np.mean(actual == est)))
    # Guard every ratio: an empty denominator previously raised ZeroDivisionError.
    pre = 100 * true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
    rec = 100 * true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0
    f1 = (2 * pre * rec) / (pre + rec) if pre + rec else 0.0
    print('Precision = %.2f%%' % pre)
    print('Recall = %.2f%%' % rec)
    print('F1 Score = %.2f%%' % f1)

    # Make figures
    n = len(est)
    _, axarr = plt.subplots(1, 2, figsize=(16, 5))
    # Pin the colour range to [0, 1] on both panels. Without it, an all-0 or
    # all-1 ground truth is auto-scaled to a single (lowest) colour and looks
    # identical to "no edges".
    axarr[0].imshow(actual, cmap='Blues', vmin=0, vmax=1)
    axarr[1].imshow(est, cmap='Blues', vmin=0, vmax=1, extent=(0, n, n, 0))
    for ax, title in zip(axarr, ('GC actual', 'GC estimated')):
        ax.set_title(title)
        ax.set_ylabel('Affected series')
        ax.set_xlabel('Causal series')
        ax.set_xticks([])
        ax.set_yticks([])

    # Mark disagreements. With extent (0, n, n, 0), cell (i, j) spans
    # x in [j, j+1] and y in [i, i+1]. The small -0.05 vertical nudge is kept
    # from the original; presumably it is a visual alignment tweak (TODO confirm).
    for i, j in np.argwhere(actual != est):
        rect = plt.Rectangle((j, i - 0.05), 1, 1, facecolor='none', edgecolor='red', linewidth=1)
        axarr[1].add_patch(rect)
    plt.show()