-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotter3D.py
58 lines (48 loc) · 1.84 KB
/
plotter3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
def update_graph(graph, step: int, frequency: int, total_steps: int, new_state: np.ndarray, filename):
    """Scatter the per-row minimum of ``new_state`` (scaled by 1000) onto ``graph``.

    Nothing is drawn unless ``step`` is a multiple of ``frequency``.

    Args:
        graph: Object exposing ``scatter3D(x, y, z, color=...)`` (e.g. a
            Matplotlib 3-D axes).
        step: Current step index; used as the y coordinate of every point.
        frequency: Only steps divisible by this value are plotted.
        total_steps: Used to detect the final plotted step, i.e.
            ``step == total_steps - frequency``.
        new_state: 2-D array. One point is plotted per row (x is the row
            index) and the minimum is taken over the second axis.
            Presumably a Q-table of shape (states, actions) -- confirm
            against the caller.
        filename: Destination passed to ``savefig`` on the final step.

    Note:
        The final-step branch requires ``step == total_steps - frequency``
        with ``step`` divisible by ``frequency``, so it can only trigger when
        ``total_steps`` is itself a multiple of ``frequency``.
    """
    if step % frequency != 0:
        return

    values = np.asarray(new_state)
    n_rows = values.shape[0]
    x = np.arange(n_rows)
    y = np.full(n_rows, step)
    # Scale first, then reduce, to mirror the original element-wise computation.
    z = (1000 * values).min(axis=1)

    if step != total_steps - frequency:
        graph.scatter3D(x, y, z, color="green")
        return

    graph.scatter3D(x, y, z, color="red")
    # Save BEFORE show(): in non-interactive mode show() blocks until the
    # window is closed, after which the figure may already be torn down and
    # savefig would write a blank image.
    plt.gcf().savefig(filename, dpi=100)
    plt.show()
def update_graph_iteration(graph, step: int, frequency: int, total_steps: int, new_state: np.ndarray, filename):
    """Scatter the entries of ``new_state`` onto ``graph`` for one step.

    Nothing is drawn unless ``step`` is a multiple of ``frequency``.

    Args:
        graph: Object exposing ``scatter3D(x, y, z, color=...)`` (e.g. a
            Matplotlib 3-D axes).
        step: Current step index; used as the y coordinate of every point.
        frequency: Only steps divisible by this value are plotted.
        total_steps: Used to detect the final plotted step, i.e.
            ``step == total_steps - frequency``.
        new_state: Array whose first-axis entries are plotted as z values,
            with x being the entry index. Presumably a 1-D vector of
            per-state values -- confirm against the caller.
        filename: Destination passed to ``savefig`` on the final step.

    Note:
        The final-step branch requires ``step == total_steps - frequency``
        with ``step`` divisible by ``frequency``, so it can only trigger when
        ``total_steps`` is itself a multiple of ``frequency``.
    """
    if step % frequency != 0:
        return

    z = np.asarray(new_state)
    n_entries = z.shape[0]
    x = np.arange(n_entries)
    y = np.full(n_entries, step)

    if step != total_steps - frequency:
        graph.scatter3D(x, y, z, color="green")
        return

    graph.scatter3D(x, y, z, color="red")
    # Save BEFORE show(): in non-interactive mode show() blocks until the
    # window is closed, after which the figure may already be torn down and
    # savefig would write a blank image.
    plt.gcf().savefig(filename, dpi=100)
    plt.show()
def init_graph():
    """Create a labelled 3-D axes via pyplot and return it.

    The axes gets x/y/z labels and a caption placed in axes coordinates;
    the current pyplot figure is then resized to 18.5 x 10.5 inches at
    100 dpi.

    Returns:
        The axes object produced by ``plt.axes(projection='3d')``.
    """
    axes3d = plt.axes(projection='3d')

    # Apply the three axis labels in x, y, z order.
    for apply_label, label_text in (
        (axes3d.set_xlabel, "states"),
        (axes3d.set_ylabel, "episode"),
        (axes3d.set_zlabel, "value function x1e-3"),
    ):
        apply_label(label_text)

    # Caption anchored near the top-left corner, in axes-relative coordinates.
    axes3d.text2D(
        0.05,
        0.95,
        "min value of q function according to action",
        transform=axes3d.transAxes,
    )

    current_figure = plt.gcf()
    current_figure.set_size_inches(18.5, 10.5, forward=True)
    current_figure.set_dpi(100)

    return axes3d