-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo.py
117 lines (96 loc) · 3.45 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
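Demo: run a trained actor in the ENV simulation and plot its trajectory.

Loads a parameter matrix from a ``.npy`` file (``-Path``), runs ``-Ep``
episodes with a fixed discrete action set, and either animates each
episode live (``-Vis``) or plots the last episode's track at the end.

@leofansq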
"""
import time
import matplotlib.pyplot as plt
import numpy as np
import argparse
import sys
from env import ENV
from feature_encoder import FEATURE_ENCODER
from act import ACTOR
# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
# Simulated time per episode; main() runs T / INTERVAL_ENV environment steps.
T = 30
# Time between two action decisions; main() re-queries the actor every
# INTERVAL_A / INTERVAL_ENV environment steps and holds the action in between.
INTERVAL_A = 0.1
# Time represented by one environment update (units presumably seconds —
# confirm against ENV).
INTERVAL_ENV = 0.01
# Discrete action set shared by the feature encoder and the actor. Each entry
# is presumably an (x, y) command — confirm against ENV.update.
ACTION = [
    [5.0, 0.0],
    [-5.0, 0.0],
    [0.0, 5.0],
    [0.0, -5.0],
]
def main(EP, VIS, path, FAST):
    """Evaluate a trained actor in ENV for ``EP`` episodes and plot the result.

    Args:
        EP: Number of episodes to simulate.
        VIS: If true, animate every episode live with matplotlib and pause
            5 s after each one; otherwise record positions silently and plot
            the trajectory of the last episode once all episodes are done.
        path: Path of the ``.npy`` parameter matrix handed to ``actor.act``.
        FAST: In live mode, draw a frame only every ``FAST`` environment
            steps. Values below 1 are clamped to 1.

    Returns:
        0 if the parameter file cannot be loaded, otherwise None.
    """
    # Environment steps per action decision. Rounded to an int so that the
    # ``t % step_a`` test below is an exact integer check, not a float one.
    step_a = max(1, round(INTERVAL_A / INTERVAL_ENV))
    # Environment steps per episode (rounded for the same float-safety reason).
    n_steps = int(round(T / INTERVAL_ENV))
    # Clamp so ``t % FAST`` can never divide by zero.
    FAST = max(1, int(FAST))

    # Feature encoder & action generator, in evaluation (non-training) mode.
    encoder = FEATURE_ENCODER(ACTION)
    actor = ACTOR(encoder, ACTION, is_train=False)

    # Load the parameter matrix. Catch only load failures (missing or
    # unreadable file, empty or malformed content) rather than a bare
    # ``except`` that would also swallow KeyboardInterrupt and real bugs.
    try:
        w = np.load(path)
    except (OSError, EOFError, ValueError):
        print("Could not find {}".format(path))
        return 0
    print("Load {}".format(path))
    print("-"*30)

    # Live-visualization setup.
    if VIS:
        plt.ion()
        plt.figure(figsize=(5, 5))
        plt.axis([0, 100, 0, 100])

    # Offline progress indicator: one character of this word per chunk of
    # steps (300 steps per character for the default 3000-step episode).
    progress = "processing"
    progress_every = max(1, n_steps // len(progress))

    track_x, track_y = [], []
    for ep in range(EP):
        sys.stdout.write("EP:{} ".format(ep+1))
        # New environment every episode; only the 100 x 100 size is fixed
        # here. For a reproducible scenario ENV also accepts e.g.
        #   ENV(w=100, h=100, target=[85.0, 85.0], c_x=10.0, c_y=10.0,
        #       c_vx=0.0, c_vy=0.0)
        e = ENV(w=100, h=100)

        if VIS:
            plt.scatter(e.target[0], e.target[1], s=30, c='red')
        else:
            # Tracks are reset per episode, so only the last one is plotted.
            track_x = []
            track_y = []

        for t in range(n_steps):
            # Re-query the actor every step_a steps and hold the last action
            # in between. t == 0 always triggers, so ``a`` is always bound.
            if t % step_a == 0:
                a = actor.act([e.c.dx, e.c.dy, e.c.vx, e.c.vy], w)
            e.update(a)

            if VIS:
                if t % FAST == 0:
                    # Episode number is 1-based, matching the other messages.
                    sys.stdout.write("Ep:{}-{} Vx:{:.2f} Vy:{:.2f} Action:{} \r".format(ep+1, t+1, e.c.vx, e.c.vy, a))
                    sys.stdout.flush()
                    plt.scatter(e.c.x, e.c.y, s=10, c='blue', alpha=0.2)
                    plt.scatter(e.target[0], e.target[1], s=30, c='red')
                    plt.pause(0.01)
            else:
                track_x.append(e.c.x)
                track_y.append(e.c.y)
                if (t+1) % progress_every == 0:
                    idx = (t+1) // progress_every - 1
                    # Guard: a step count that is not an exact multiple of
                    # len(progress) would otherwise index past the word.
                    if idx < len(progress):
                        sys.stdout.write(progress[idx])
                        sys.stdout.flush()

        # ``-e.r`` is reported as the final distance (e.r presumably being a
        # negative-distance reward — confirm in ENV).
        print(" Final_distance:{:.2f} ".format(-e.r))
        if VIS:
            plt.scatter(e.c.x, e.c.y, s=30, c='orange')
            plt.text(e.c.x, e.c.y-1, "EP{} Dist:{:.2f}".format(ep+1, -e.r))
            plt.pause(5)

    # Offline mode: plot the last episode's trajectory. Skipped when nothing
    # was recorded (EP == 0 or zero-length episodes), which would otherwise
    # fail with NameError / IndexError.
    if not VIS and track_x:
        plt.scatter(track_x, track_y, s=5, c='blue', alpha=0.2)
        plt.scatter(e.target[0], e.target[1], s=30, c='red')
        plt.scatter(track_x[-1], track_y[-1], s=30, c='orange')
        plt.text(e.c.x, e.c.y-1, "Dist:{:.2f}".format(-e.r))
        plt.axis([0, 100, 0, 100])
        plt.show()
if __name__ == "__main__":
    # Command-line entry point. ``type=int`` lets argparse reject
    # non-integer -Ep / -Fast values with a usage message instead of a raw
    # ValueError traceback; string defaults are no longer needed.
    parser = argparse.ArgumentParser(
        description="Run a trained actor in the ENV simulation and plot it.")
    parser.add_argument('-Ep', type=int, default=5,
                        help="number of episodes to run")
    parser.add_argument('-Vis', action='store_true',
                        help="animate every episode live")
    parser.add_argument('-Path', default="./p_w_test.npy",
                        help="parameter matrix (.npy) to load")
    parser.add_argument('-Fast', type=int, default=30,
                        help="n times fast for visualization")
    args = parser.parse_args()

    EP = args.Ep
    VIS = args.Vis
    PATH = args.Path
    FAST = args.Fast

    print("-"*30)
    if VIS:
        print("EPISODE:{}\nVIS:{}\nPATH:{}\n{} times fast\n".format(EP, VIS, PATH, FAST))
    else:
        print("EPISODE:{}\nVIS:{}\nPATH:{}\n".format(EP, VIS, PATH))

    main(EP, VIS, PATH, FAST)