-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path env_test.py
More file actions
41 lines (31 loc) · 1.34 KB
/
env_test.py
File metadata and controls
41 lines (31 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from game.game import Game
from game.player import Player_Random, Player_RL
import gymnasium as gym
from gymnasium.utils import RecordConstructorArgs
from gymnasium.utils.env_checker import check_env
class SafeActionWrapper(gym.Wrapper, RecordConstructorArgs):
    """Wrapper that replaces invalid actions with a valid one before stepping.

    Validity is decided by the base environment's ``get_masked_action_space()``
    method; an action is valid when it is contained in the collection that
    method returns. When the agent proposes an invalid action, the first entry
    of that collection is substituted.
    """

    def __init__(self, env):
        """Wrap *env*.

        Args:
            env: The environment to wrap. Its base (unwrapped) environment must
                provide ``get_masked_action_space()``.
        """
        # gym.Wrapper.__init__ does not chain to RecordConstructorArgs.__init__,
        # so it must be called explicitly (gymnasium's own wrappers do the same).
        # Without it ``_saved_kwargs`` is never set and ``self.spec`` fails.
        RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

    def step(self, action):
        """Step the wrapped env, substituting a valid action if *action* is invalid.

        Args:
            action: The action proposed by the agent.

        Returns:
            Whatever the wrapped environment's ``step`` returns.
        """
        if not self.is_valid_action(action):
            action = self.get_valid_action(action)
        return self.env.step(action)

    def is_valid_action(self, action):
        """Return whether *action* is in the base env's masked action space."""
        # Go through .unwrapped: attribute forwarding via Wrapper.__getattr__
        # is deprecated (gymnasium 0.29) and removed (gymnasium 1.0).
        return action in self.env.unwrapped.get_masked_action_space()

    def get_valid_action(self, action):
        """Return a replacement for an invalid *action*.

        The current policy ignores *action* and returns the first entry of the
        base env's masked action space.

        Raises:
            IndexError: If the masked action space is empty.
        """
        valid_actions = self.env.unwrapped.get_masked_action_space()
        if len(valid_actions) == 0:
            raise IndexError("no valid actions available in the masked action space")
        return valid_actions[0]
# Smoke test: build a two-player game on "test_map_v0", create the registered
# "game/RiskEnv-V0" environment with p2 as the agent player, reset it once and
# print the flattened observation.
# NOTE(review): this file's name matches pytest's default ``*_test.py``
# collection pattern, so a pytest run would execute this code on import;
# consider an ``if __name__ == "__main__":`` guard.
p1 = Player_Random("p1")
p2 = Player_RL("p2")
game = Game("test_map_v0", [p1, p2])
env = gym.make("game/RiskEnv-V0", game=game, agent_player=p2, render_mode="human")
try:
    # Optional validation: uncomment to run gymnasium's env checker through
    # the action-masking wrapper.
    # w_env = SafeActionWrapper(env)
    # print("Check environment begin")
    # check_env(w_env)
    # print("Check environment end")
    obs, info = env.reset()
    # flatten_obs is called on the base environment, reached via .unwrapped.
    obs_fl = env.unwrapped.flatten_obs(obs)
    print(obs_fl)
finally:
    # Always release environment/render resources (render_mode="human"),
    # even if reset or flattening raises.
    env.close()