-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobserve.py
60 lines (53 loc) · 1.86 KB
/
observe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from model import NetworkSmallDuell
from mechanics import Game
from agent import Agent
import numpy as np
import pygame as pg
import os
import sys
import torch
# Arrow label for each action index of the Q-value vector (index 0..3);
# print_q_values looks these up by index. Assumes the up/down/left/right
# order matches Game.move_player's action encoding -- TODO confirm.
DIR = list("↑↓←→")
def resume(fp):
if os.path.isfile(fp):
cp = torch.load(fp, map_location={'cuda:0': 'cpu'})
model = NetworkSmallDuell(32, 4)
model2 = NetworkSmallDuell(32, 4)
model.load_state_dict(cp['state_dict'])
model2.load_state_dict(cp['state_dict2'])
return model, model2
else:
raise FileNotFoundError("File {} not found.".format(fp))
def print_q_values(a, aa):
s = lambda i: "\33[31m{:.2f}\33[m".format(a[i]) if aa == i else "{:.2f}".format(a[i])
print(" " + " ".join(["{}:{}".format(DIR[x], s(x)) for x in range(len(a))]), end="\r")
if __name__ == "__main__":
game = Game(easy=True, size=28)
model, model2 = resume(sys.argv[1])
model.eval()
model2.eval()
ag = Agent(model, cuda=False) # for to_var
resolution = game.get_visual().shape[:2]
screen = pg.display.set_mode(resolution)
while(True):
if np.random.randint(0,2):
model, model2 = model2, model
state = game.get_visual(hud=False)
field = game.get_visual(hud=True)
pg.surfarray.blit_array(screen, field)
a = model(ag.to_var(state)).data[0] # Tensor dim=(4)
m = False
while not m:
aa = a.max(0)[1][0] # argmax as scalar
print_q_values(a, aa)
pg.display.flip()
if input() == "x":
game.game_over()
game.move_player(None)
print("-"*35)
break
if a.max() == float("-inf"):
#this should never happen!
print("No valid moves. Exiting.")
exit(123)
m = game.move_player(aa)
a[aa] = float("-inf")