Skip to content

Commit 8ddbdba

Browse files
committed
fix arena bugs
1 parent 9a43301 commit 8ddbdba

File tree

10 files changed

+99
-8
lines changed

10 files changed

+99
-8
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,7 @@ conda-upload:
4949
./scripts/conda_upload.sh
5050

5151
doc:
52-
./scripts/gen_api_docs.sh
52+
./scripts/gen_api_docs.sh
53+
54+
upload-codecov:
55+
codecov --file coverage.xml -t $(CODECOV_TOKEN)

examples/arena/run_arena.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ def run_arena(
2626
seed=0,
2727
total_games: int = 10,
2828
max_game_onetime: int = 5,
29+
use_tqdm: bool = True,
2930
):
3031
env_wrappers = [RecordWinner]
3132
if render:
3233
from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender
3334

3435
env_wrappers.append(TictactoeRender)
3536

36-
arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=True)
37+
arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm)
3738

3839
agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent")
3940
agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent")
@@ -52,4 +53,4 @@ def run_arena(
5253

5354
if __name__ == "__main__":
5455
run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
55-
# run_arena(render=True, parallel=True, seed=1, total_games=10, max_game_onetime=2)
56+
# run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False)

examples/snake/jidi_random_vs_openrl_random.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def run_arena(
2828
seed=0,
2929
total_games: int = 10,
3030
max_game_onetime: int = 5,
31+
use_tqdm: bool = True,
3132
):
3233
env_wrappers = [RecordWinner]
3334

@@ -36,7 +37,7 @@ def run_arena(
3637
f"snakes_{player_num}v{player_num}",
3738
env_wrappers=env_wrappers,
3839
render=render,
39-
use_tqdm=True,
40+
use_tqdm=use_tqdm,
4041
)
4142

4243
agent1 = JiDiAgent("./submissions/random_agent", player_num=player_num)
@@ -55,4 +56,12 @@ def run_arena(
5556

5657

5758
if __name__ == "__main__":
58-
run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
59+
# run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
60+
run_arena(
61+
render=False,
62+
parallel=False,
63+
seed=0,
64+
total_games=1,
65+
max_game_onetime=1,
66+
use_tqdm=False,
67+
)

openrl/envs/PettingZoo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs):
3535
from pettingzoo.classic import tictactoe_v3
3636

3737
env = tictactoe_v3.env(render_mode=render_mode)
38+
3839
else:
3940
raise NotImplementedError
4041
return env

openrl/envs/snake/snake_pettingzoo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def action_space(self, agent):
8383
return deepcopy(self._action_spaces[agent])
8484

8585
def observe(self, agent):
86-
return self.raw_obs[self.agent_name_to_slice[agent]]
86+
obs = self.raw_obs[self.agent_name_to_slice[agent]]
87+
return obs
8788

8889
def reset(
8990
self,

openrl/selfplay/opponents/random_opponent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def _sample_random_action(
4141
):
4242
action_space = self.env.action_space(player_name)
4343
if isinstance(action_space, list):
44+
if not isinstance(observation, list):
45+
observation = [observation]
46+
4447
action = []
48+
4549
for obs, space in zip(observation, action_space):
4650
mask = obs.get("action_mask", None)
4751
action.append(space.sample(mask))

openrl/supports/opendata/utils/opendata_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def data_server_wrapper(fp):
4848

4949
def load_dataset(data_path: str, split: str):
5050
from datasets import load_from_disk
51+
5152
if Path(data_path).exists():
5253
dataset = load_from_disk("{}/{}".format(data_path, split))
5354
elif "data_server:" in data_path:

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_extra_requires() -> dict:
6969
"retro": ["gym-retro"],
7070
"super_mario": ["gym-super-mario-bros"],
7171
}
72+
req["test"].extend(req["selfplay"])
7273
return req
7374

7475

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import os
19+
import sys
20+
21+
import pytest
22+
23+
from openrl.arena import make_arena
24+
from openrl.arena.agents.local_agent import LocalAgent
25+
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
26+
27+
28+
def run_arena(
    render: bool = False,
    parallel: bool = True,
    seed=0,
    total_games: int = 10,
    max_game_onetime: int = 5,
    use_tqdm: bool = False,
):
    """Run tic-tac-toe games between two random-opponent agents in an arena.

    Args:
        render: When true, ``TictactoeRender`` is appended to the environment
            wrappers (imported lazily so it is only needed when rendering).
        parallel: Forwarded to ``arena.run``.
        seed: Forwarded to ``arena.reset``.
        total_games: Forwarded to ``arena.reset``.
        max_game_onetime: Forwarded to ``arena.reset``.
        use_tqdm: Forwarded to ``make_arena``. Defaults to ``False``, the value
            that used to be hard-coded here.

    Returns:
        The value returned by ``arena.run``; it is also printed.
    """
    env_wrappers = [RecordWinner]
    if render:
        from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender

        env_wrappers.append(TictactoeRender)

    arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm)

    # NOTE(review): these are relative paths; presumably they are resolved
    # against the current working directory (the repository root) -- confirm.
    agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
    agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")

    arena.reset(
        agents={"agent1": agent1, "agent2": agent2},
        total_games=total_games,
        max_game_onetime=max_game_onetime,
        seed=seed,
    )
    try:
        result = arena.run(parallel=parallel)
    finally:
        # Close the arena even when the run raises, so a failing run does not
        # leave it open.
        arena.close()
    print(result)
    return result
56+
57+
58+
@pytest.mark.unittest
def test_seed():
    """Check that arena runs with a fixed seed keep producing the same result.

    ``run_arena`` is called three times with ``parallel=False`` and then three
    times with ``parallel=True``. Every result is compared with the one from
    the immediately preceding call, including across the serial-to-parallel
    switch.
    """
    seed = 0
    test_time = 3
    # Flat schedule of the ``parallel`` flag: False x test_time, then True x test_time.
    schedule = [mode for mode in (False, True) for _ in range(test_time)]

    previous = None
    for parallel in schedule:
        current = run_arena(seed=seed, parallel=parallel, total_games=20)
        # The very first run has nothing to be compared against.
        assert previous is None or previous == current, f"parallel={parallel}, seed={seed}"
        previous = current
69+
70+
71+
if __name__ == "__main__":
72+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

tests/test_supports/test_opendata/test_opendata.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,5 @@ def test_data_abs_path():
3030
assert data_abs_path(data_path) == data_path
3131

3232

33-
34-
3533
if __name__ == "__main__":
3634
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)