-
Notifications
You must be signed in to change notification settings - Fork 9
/
generate_dataset.py
135 lines (121 loc) · 4.78 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python
# coding: utf-8
import random
# The (currently disabled) sys.path.append below adds the parent directory to
# sys.path so that ocatari is importable as if it had been installed as a package.
import sys
from copy import deepcopy
from os import path, makedirs
import matplotlib.pyplot as plt
import pandas as pd
from numpy import random
# sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) # noqa
from ocatari.core import OCAtari
from ocatari.utils import load_agent, parser, make_deterministic
# from ocatari.vision.space_invaders import objects_colors
from ocatari.vision.pong import objects_colors
from ocatari.vision.utils import mark_bb, make_darker
import pickle
from tqdm import tqdm
# Command-line options, registered on the parser imported from ocatari.utils.
parser.add_argument("-g", "--game", type=str, default="Pong",
                    help="game to evaluate (e.g. 'Pong')")
# The help text previously advertised a default of 10 while the real default
# is 1000; the message now matches the code.
parser.add_argument("-i", "--interval", type=int, default=1000,
                    help="The frame interval (default 1000)")
# parser.add_argument("-m", "--mode", choices=["vision", "ram"],
#                     default="ram", help="The frame interval")
# NOTE(review): action="store_true" combined with default=True means these two
# options evaluate to True whether or not the flag is given on the command
# line. Left as is to keep current behaviour; confirm whether an opt-out
# switch is wanted.
parser.add_argument("-hud", "--hud", action="store_true", default=True, help="Detect HUD")
parser.add_argument("-dqn", "--dqn", action="store_true", default=True, help="Use DQN agent")
opts = parser.parse_args()
# Init the environment. The HUD setting is taken from the parsed --hud option
# instead of a hard-coded True (the option currently always evaluates to True,
# so behaviour is unchanged).
env = OCAtari(opts.game, mode="both", render_mode='rgb_array', hud=opts.hud)
env.reset()

# Set up an agent. dqn_agent is always bound so later code can refer to the
# name safely; it stays None when the --dqn option is disabled.
dqn_agent = None
if opts.dqn:
    # load_agent receives the whole opts namespace, with the model location
    # attached to it as opts.path.
    opts.path = f"../models/{opts.game}/dqn.gz"
    dqn_agent = load_agent(opts, env.action_space.n)

# make environment deterministic
# NOTE(review): one step with action index 2 is taken before seeding with 42;
# the reason for this warm-up step is not evident from this file.
env.step(2)
make_deterministic(42, env)
# Start from empty containers: the episode/step counters, the tabular dataset
# (one independent list per column) and the raw per-step records.
game_nr = turn_nr = 0
dataset = {column: [] for column in ("INDEX", "RAM", "VIS", "HUD")}
frames, r_objs, v_objs = [], [], []
# Generate 10,000 samples: step the environment once per iteration and record
# the frame plus the RAM-detected and vision-detected objects.
for i in tqdm(range(10000)):
    # dqn_agent is only loaded when opts.dqn is set; otherwise fall back to a
    # random action so the "<game>_random" output naming used when saving has
    # a matching code path instead of failing on a missing agent.
    if opts.dqn:
        action = dqn_agent.draw_action(env.dqn_obs)
    else:
        action = env.action_space.sample()
    obs, _, terminated, truncated, _ = env.step(action)

    # Row index "<game>_<turn>", each counter zero-padded to five digits.
    dataset["INDEX"].append(f"{game_nr:05d}_{turn_nr:05d}")

    # Take snapshots once and build the dataset rows from them rather than
    # from the live env objects, so that any later in-place update of those
    # objects by the environment cannot alter rows that were already recorded.
    ram_snapshot = deepcopy(env.objects)
    vis_snapshot = deepcopy(env.objects_v)
    frames.append(deepcopy(obs))
    r_objs.append(ram_snapshot)
    v_objs.append(vis_snapshot)

    # dataset["OBS"].append(obs.flatten().tolist())
    # Objects are ordered by their string form; the RAM-detected ones are
    # split into non-HUD ("RAM") and HUD ("HUD") columns.
    ram_sorted = sorted(ram_snapshot, key=str)
    dataset["RAM"].append([obj for obj in ram_sorted if not obj.hud])
    dataset["VIS"].append(sorted(vis_snapshot, key=str))
    dataset["HUD"].append([obj for obj in ram_sorted if obj.hud])
    turn_nr += 1

    # if a game is terminated, restart with a new game and update turn and game counter
    if terminated or truncated:
        env.reset()
        turn_nr = 0
        game_nr += 1

    # NOTE: the per-interval plotting that used to follow here was disabled
    # (it survived only as an inert string literal referencing the undefined
    # names obs2/obs3/obs4). Its only remaining effect was a ZeroDivisionError
    # for --interval 0, so the empty check was dropped; --interval currently
    # has no effect on this loop.
env.close()

# Persist the results under data/datasets/: the tabular dataset as CSV and the
# raw per-step records as pickle files, all sharing a "<game>_<agent>" prefix.
df = pd.DataFrame(dataset, columns=['INDEX', 'RAM', 'HUD', 'VIS'])
makedirs("data/datasets/", exist_ok=True)
prefix = f"{opts.game}_dqn" if opts.dqn else f"{opts.game}_random"
df.to_csv(f"data/datasets/{prefix}.csv", index=False)

# Use context managers so every file handle is flushed and closed
# deterministically; the handles were previously left open.
for suffix, payload in (("objects_v", v_objs),
                        ("objects_r", r_objs),
                        ("frames", frames)):
    with open(f"data/datasets/{prefix}_{suffix}.pkl", "wb") as pkl_file:
        pickle.dump(payload, pkl_file)

print(f"Finished {opts.game}")