-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_onedemoPERenv_system.py
More file actions
122 lines (95 loc) · 3.84 KB
/
test_onedemoPERenv_system.py
File metadata and controls
122 lines (95 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os, time
import pickle
import numpy as np
from system3 import *
# Single demonstration
demos_water = pickle.load(open("demos_water_gold.pk", "rb"))
demo = demos_water['1layer'][1]
# Test environments
test_env = pickle.load(open("maps__test.pk", "rb"))
train_env = pickle.load(open("maps__train.pk", "rb"))
# Our method (using demo)
system1 = System1Adapted()
environment_handler = EnvironmentHandler()
system1.environment_handler = environment_handler
system2 = System2()
# Dict type
dict_type = "oracle"
if dict_type == "oracle":
system2.rule_dict = system2.rule_dict_oracle
elif dict_type == "demo":
demo_type_strings = ["1layer", "2layer", "3layer", "gem_gold", "grass_gold", "iron_gold", "stone_gold", "water_gold", "wood_gold"]
for demo_string in demo_type_strings:
demos_rule_dict = pickle.load(open("demos_" + demo_string + ".pk", "rb"))
demo_rule_dict = np.random.choice(demos_rule_dict['1layer'])
rule_sequence, reachability_set_sequence, event_position_sequence = system2.use_demo(demo_rule_dict, system1)
#correct_rules, compounded_rules, incorrect_rules = system2.analyse_rule_base()
#print(correct_rules, compounded_rules, incorrect_rules )
elif dict_type == "demo_explore":
if os.path.exists("rule_dict_demo_explore_multiple_3_100_20.pk"):
system2.rule_dict = pickle.load(open("rule_dict_demo_explore_multiple_3_100_20.pk", "rb"))
else:
demo_type_strings = ["1layer", "2layer", "3layer", "gem_gold", "grass_gold", "iron_gold", "stone_gold", "water_gold", "wood_gold"]
for demo_string in demo_type_strings:
demos_rule_dict = pickle.load(open("demos_" + demo_string + ".pk", "rb"))
demo_rule_dict = np.random.choice(demos_rule_dict['1layer'])
rule_sequence, reachability_set_sequence, event_position_sequence = system2.use_demo(demo_rule_dict, system1)
correct, compounded, incorrect, total = system2.explore_env(pickle.load(open("custom_maps.pk", "rb")), system1, \
num_unique_envs = 3, num_envs = 100, max_skills_per_env = 20)
pickle.dump(system2.rule_dict, open("rule_dict_demo_explore_multiple_3_100_20.pk", "wb"))
else:
pass
# Add random exploration here
system3 = System3(system2.rule_dict)
# System 3, infers objective, generates graph guide, and outputs skill sequence for the new environment
rule_sequence, reachability_set_sequence, event_position_sequence = system2.use_demo(demo, system1)
objective = system3.infer_objective(rule_sequence, reachability_set_sequence, event_position_sequence)
success = 0
success_cases = []
failure = 0
failure_cases = []
total_time = 0
for i, env in enumerate(train_env):
#for i, env in enumerate(test_env):
start = time.time()
state = env
observable_env = system1.observation_function(fullstate(state))
try:
graph_guide = system3.get_dependency_graph_guide(observable_env)
except:
failure += 1
failure_cases.append(i)
continue
state.render()
state.render()
print("\n\n\n\nEnvironment number: {}\n\n\n\n\n".format(i))
possible_skill_sequences = system3.play(observable_env)
sequence_length = 0
try:
for skill_params, obj in possible_skill_sequences[0].skills_so_far:
observable_env = system1.observation_function(fullstate(state))
pos_x, pos_y = np.where(observable_env == 1)
action_seq = system1.use_object(observable_env, (pos_x[0], pos_y[0]), skill_params)
for a in action_seq:
_, state = state.step(a)
sequence_length += 1
if state.inventory[10] > 0:
end = time.time()
success += 1
success_cases.append((i, sequence_length))
total_time += end - start
else:
failure += 1
failure_cases.append(i)
except:
failure += 1
failure_cases.append(i)
state.render()
state.render()
print("\n\n\n\n")
for s in success_cases: print(s)
if success > 0:
print("Avg. time taken: {}, Success:{}, Failure:{}".format(total_time/success, success, failure))
else:
print("Success:{}, Failure:{}".format(success, failure))
import ipdb; ipdb.set_trace()