-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_stats.py
64 lines (52 loc) · 2.3 KB
/
predict_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import Functions as fn
import matplotlib.pyplot as plt
import h5py
import importlib
importlib.reload(fn)
# Script: run the trained model over slices of the 'VN_coeff' dataset stored in
# an HDF5 file, compare the complex-valued predictions with the stored values,
# and plot the magnitude and angle of the difference.
#
# NOTE(review): `classifier` is not defined anywhere in the visible part of this
# file; it presumably has to exist already (e.g. from an interactive session or
# another module) before this script runs -- confirm.

# transfer = 'reformed_spectra_densesapce_safe.hdf5'
transfer = 'reformed_TF_train_mp_1.hdf5'

# Row ranges [start, stop] of the dataset to predict on, one prediction run per
# range. Earlier settings that were overwritten by the active one, kept for
# reference:
# cuts = [[0, 6000], [6000, 12000], [12000, 18000], [18000, 24000], [24000, 30000]]
# cuts = [[0, 6000], [6000, 12000]]
cuts = [[0, 100]]

# The context manager guarantees the file is closed even when the dataset is
# missing or a prediction run fails part-way through.
with h5py.File(transfer, 'r') as h5_reformed:
    if 'VN_coeff' not in h5_reformed:
        raise Exception('No "VN_coeff" in file.')
    VN_coeff = h5_reformed['VN_coeff']

    for key, item in h5_reformed.items():
        print('shape of {} is {}'.format(key, item.shape))

    # Error maps cover the whole dataset; rows outside `cuts` stay zero.
    mag_error = np.zeros(shape=VN_coeff.shape)
    phase_error = np.zeros(shape=VN_coeff.shape)

    for run, cut in enumerate(cuts):
        predictions = classifier.predict(
            input_fn=lambda: fn.predict_hdf5_functor(transfer=transfer, select=cut,
                                                     batch_size=1))
        ground_truther = VN_coeff[cut[0]:cut[1], ...]
        predict_truther = np.zeros_like(ground_truther, dtype='complex64')

        # Each prediction's 'output' holds the real parts in entries 0..99 and
        # the imaginary parts in entries 100..199; recombine them per row.
        for i, predict in enumerate(predictions):
            output = predict['output']
            predict_truther[i, ...] = output[0:100] + 1j * output[100:200]

        # Only referenced by the disabled relative-error formulas below;
        # retained so those can be re-enabled without further changes.
        abs_ground_truther = np.abs(ground_truther)
        abs_predict_truther = np.abs(predict_truther)
        theta_ground_truther = np.angle(ground_truther)
        theta_predict_truther = np.angle(predict_truther)

        # Disabled alternative (relative) error metrics:
        # mag_error[cut[0]:cut[1], ...] = 2*(np.abs(ground_truther + predict_truther) - np.abs(
        #     ground_truther - predict_truther)) / (np.abs(ground_truther + predict_truther) + np.abs(
        #     ground_truther - predict_truther))
        # phase_error[cut[0]:cut[1], ...] = 2 * (np.abs(theta_predict_truther - theta_ground_truther)) / (np.abs(
        #     theta_predict_truther) + np.abs(theta_ground_truther))

        # Active metrics: magnitude and angle of the complex difference.
        # Note the second one is the angle of the difference vector, not the
        # difference between the two phases.
        difference = ground_truther - predict_truther
        mag_error[cut[0]:cut[1], ...] = np.abs(difference)
        phase_error[cut[0]:cut[1], ...] = np.angle(difference)
        print('completed run {}'.format(run))

# Plot the rows of the last processed cut only (the original relied on the loop
# variable leaking out of the loop for this; it is now explicit).
# NOTE(review): with several cuts the earlier ranges are computed but never
# plotted -- confirm that is intended.
plot_start, plot_stop = cuts[-1]
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 20))
im = ax[0].pcolormesh(mag_error[plot_start:plot_stop])
fig.colorbar(im, ax=ax[0])
im = ax[1].pcolormesh(phase_error[plot_start:plot_stop])
fig.colorbar(im, ax=ax[1])
# fig.savefig('Images/percent_errors2.png', dpi= 700)