From a8d4dd7b147fc5a3cec995293bb656967d0ab60f Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 18 Aug 2022 15:25:46 +0100 Subject: [PATCH] Add testing for step api compatibility functions and wrapper (#3028) * Initial commit * Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated * Fix CI issues * pre-commit * Revert back to old language * Revert changes to step api wrapper --- gym/utils/env_checker.py | 2 +- gym/utils/step_api_compatibility.py | 172 +++++++++------------ gym/wrappers/time_limit.py | 7 +- tests/utils/test_step_api_compatibility.py | 166 ++++++++++++++++++++ tests/utils/test_terminated_truncated.py | 91 ----------- tests/wrappers/test_autoreset.py | 1 + 6 files changed, 247 insertions(+), 192 deletions(-) create mode 100644 tests/utils/test_step_api_compatibility.py delete mode 100644 tests/utils/test_terminated_truncated.py diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 95355e2a6e5..5a5564b571b 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool: return data_1.keys() == data_2.keys() and all( data_equivalence(data_1[k], data_2[k]) for k in data_1.keys() ) - elif isinstance(data_1, tuple): + elif isinstance(data_1, (tuple, list)): return len(data_1) == len(data_2) and all( data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) ) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 05986e0b76f..85d9b031147 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -36,66 +36,41 @@ def step_to_new_api( assert len(step_returns) == 4 observations, rewards, dones, infos = step_returns - terminateds = [] - truncateds = [] - if not is_vector_env: - dones = [dones] - - for i in range(len(dones)): - # For every condition, handling - info single env / info vector env (list) / info vector env (dict) - - # TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done` - if (not is_vector_env and "TimeLimit.truncated" not in infos) or ( - is_vector_env - and ( - ( - isinstance(infos, list) - and "TimeLimit.truncated" not in infos[i] - ) # vector env, list info api - or ( - "TimeLimit.truncated" not in infos - or ( - "TimeLimit.truncated" in infos - and not infos["TimeLimit.truncated"][i] - ) - ) - # vector env, dict info api, for env i, vector mask `_TimeLimit.truncated` is not considered, to be compatible with envpool - # For env i, `TimeLimit.truncated` not being present is treated same as being present and set to False. - # therefore, terminated=True, truncated=True simultaneously is not allowed while using compatibility functions - # with vector info - ) - ): - terminateds.append(dones[i]) - truncateds.append(False) - - # This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not. - elif ( - infos["TimeLimit.truncated"] - if not is_vector_env - else ( - infos["TimeLimit.truncated"][i] - if isinstance(infos, dict) - else infos[i]["TimeLimit.truncated"] - ) - ): - assert dones[i] - terminateds.append(False) - truncateds.append(True) - else: - # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated, - # but it also exceeded maximum timesteps at the same step. However to be compatible with envpool, and to be backward compatible - # truncated is set to False here. - assert dones[i] - terminateds.append(True) - truncateds.append(False) - - return ( - observations, - rewards, - np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0], - np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0], - infos, - ) + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + if is_vector_env is False: + truncated = infos.pop("TimeLimit.truncated", False) + return ( + observations, + rewards, + dones and not truncated, + dones and truncated, + infos, + ) + elif isinstance(infos, list): + truncated = np.array( + [info.pop("TimeLimit.truncated", False) for info in infos] + ) + return ( + observations, + rewards, + np.logical_and(dones, np.logical_not(truncated)), + np.logical_and(dones, truncated), + infos, + ) + elif isinstance(infos, dict): + num_envs = len(dones) + truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool)) + return ( + observations, + rewards, + np.logical_and(dones, np.logical_not(truncated)), + np.logical_and(dones, truncated), + infos, + ) + else: + raise TypeError( + f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}" + ) def step_to_old_api( @@ -111,44 +86,45 @@ def step_to_old_api( return step_returns else: assert len(step_returns) == 5 - observations, rewards, terminateds, truncateds, infos = step_returns - dones = [] - if not is_vector_env: - terminateds = [terminateds] - truncateds = [truncateds] - - n_envs = len(terminateds) - - for i in range(n_envs): - dones.append(terminateds[i] or truncateds[i]) - if truncateds[i]: - if is_vector_env: - # handle vector infos for dict and list - if isinstance(infos, dict): - if "TimeLimit.truncated" not in infos: - # TODO: This should ideally not be done manually and should use vector_env's _add_info() - infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) - infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) - - infos["TimeLimit.truncated"][i] = ( - not terminateds[i] or infos["TimeLimit.truncated"][i] - ) - infos["_TimeLimit.truncated"][i] = True - else: - # if vector info is a list - infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[ - i - ].get("TimeLimit.truncated", False) - else: - infos["TimeLimit.truncated"] = not terminateds[i] or infos.get( - "TimeLimit.truncated", False - ) - return ( - observations, - rewards, - np.array(dones, dtype=np.bool_) if is_vector_env else dones[0], - infos, - ) + observations, rewards, terminated, truncated, infos = step_returns + + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + if is_vector_env is False: + if truncated or terminated: + infos["TimeLimit.truncated"] = truncated and not terminated + return ( + observations, + rewards, + terminated or truncated, + infos, + ) + elif isinstance(infos, list): + for info, env_truncated, env_terminated in zip( + infos, truncated, terminated + ): + if env_truncated or env_terminated: + info["TimeLimit.truncated"] = env_truncated and not env_terminated + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + elif isinstance(infos, dict): + if np.logical_or(np.any(truncated), np.any(terminated)): + infos["TimeLimit.truncated"] = np.logical_and( + truncated, np.logical_not(terminated) + ) + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + else: + raise TypeError( + f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}" + ) def step_api_compatibility( diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 8e9f67f4ae9..17481d68070 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -34,7 +34,7 @@ def __init__( Args: env: The environment to apply the wrapper - max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) + max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used) new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ super().__init__(env, new_step_api) @@ -63,7 +63,10 @@ def step(self, action): self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: - truncated = True + if self.new_step_api is True or terminated is False: + # As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both. + # Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding + truncated = True return step_api_compatibility( (observation, reward, terminated, truncated, info), diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py new file mode 100644 index 00000000000..ade45892072 --- /dev/null +++ b/tests/utils/test_step_api_compatibility.py @@ -0,0 +1,166 @@ +import numpy as np +import pytest + +from gym.utils.env_checker import data_equivalence +from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api + + +@pytest.mark.parametrize( + "is_vector_env, done_returns, expected_terminated, expected_truncated", + ( + # Test each of the permutations for single environments with and without the old info + (False, (0, 0, False, {"Test-info": True}), False, False), + (False, (0, 0, False, {"TimeLimit.truncated": False}), False, False), + (False, (0, 0, True, {}), True, False), + (False, (0, 0, True, {"TimeLimit.truncated": True}), False, True), + (False, (0, 0, True, {"Test-info": True}), True, False), + # Test vectorise versions with both list and dict infos testing each permutation for sub-environments + ( + True, + ( + 0, + 0, + np.array([False, True, True]), + [{}, {}, {"TimeLimit.truncated": True}], + ), + np.array([False, True, False]), + np.array([False, False, True]), + ), + ( + True, + ( + 0, + 0, + np.array([False, True, True]), + {"TimeLimit.truncated": np.array([False, False, True])}, + ), + np.array([False, True, False]), + np.array([False, False, True]), + ), + # empty truncated info + ( + True, + ( + 0, + 0, + np.array([False, True]), + {}, + ), + np.array([False, True]), + np.array([False, False]), + ), + ), +) +def test_to_done_step_api( + is_vector_env, done_returns, expected_terminated, expected_truncated +): + _, _, terminated, truncated, info = step_to_new_api( + done_returns, is_vector_env=is_vector_env + ) + assert np.all(terminated == expected_terminated) + assert np.all(truncated == expected_truncated) + + if is_vector_env is False: + assert "TimeLimit.truncated" not in info + elif isinstance(info, list): + assert all("TimeLimit.truncated" not in sub_info for sub_info in info) + else: # isinstance(info, dict) + assert "TimeLimit.truncated" not in info + + roundtripped_returns = step_to_old_api( + (0, 0, terminated, truncated, info), is_vector_env=is_vector_env + ) + assert data_equivalence(done_returns, roundtripped_returns) + + +@pytest.mark.parametrize( + "is_vector_env, terminated_truncated_returns, expected_done, expected_truncated", + ( + (False, (0, 0, False, False, {"Test-info": True}), False, False), + (False, (0, 0, True, False, {}), True, False), + (False, (0, 0, False, True, {}), True, True), + # (False, (), True, True), # Not possible to encode in the old step api + # Test vector dict info + ( + True, + (0, 0, np.array([False, True, False]), np.array([False, False, True]), {}), + np.array([False, True, True]), + np.array([False, False, True]), + ), + # Test vector dict info with no truncation + ( + True, + (0, 0, np.array([False, True]), np.array([False, False]), {}), + np.array([False, True]), + np.array([False, False]), + ), + # Test vector list info + ( + True, + ( + 0, + 0, + np.array([False, True, False]), + np.array([False, False, True]), + [{"Test-Info": True}, {}, {}], + ), + np.array([False, True, True]), + np.array([False, False, True]), + ), + ), +) +def test_to_terminated_truncated_step_api( + is_vector_env, terminated_truncated_returns, expected_done, expected_truncated +): + _, _, done, info = step_to_old_api( + terminated_truncated_returns, is_vector_env=is_vector_env + ) + assert np.all(done == expected_done) + + if is_vector_env is False: + if expected_done: + assert info["TimeLimit.truncated"] == expected_truncated + else: + assert "TimeLimit.truncated" not in info + elif isinstance(info, list): + for sub_info, env_done, env_truncated in zip( + info, expected_done, expected_truncated + ): + if env_done: + assert sub_info["TimeLimit.truncated"] == env_truncated + else: + assert "TimeLimit.truncated" not in sub_info + else: # isinstance(info, dict) + if np.any(expected_done): + assert np.all(info["TimeLimit.truncated"] == expected_truncated) + else: + assert "TimeLimit.truncated" not in info + + roundtripped_returns = step_to_new_api( + (0, 0, done, info), is_vector_env=is_vector_env + ) + assert data_equivalence(terminated_truncated_returns, roundtripped_returns) + + +def test_edge_case(): + # When converting between the two-step APIs this is not possible in a single case + # terminated=True and truncated=True -> done=True and info={} + # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail + _, _, done, info = step_to_old_api((0, 0, True, True, {})) + assert done is True + assert info == {"TimeLimit.truncated": False} + + # Test with vector dict info + _, _, done, info = step_to_old_api( + (0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True + ) + assert np.all(done) + assert info == {"TimeLimit.truncated": np.array([False])} + + # Test with vector list info + _, _, done, info = step_to_old_api( + (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]), + is_vector_env=True, + ) + assert np.all(done) + assert info == [{"Test-Info": True, "TimeLimit.truncated": False}] diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py deleted file mode 100644 index e74fdc85378..00000000000 --- a/tests/utils/test_terminated_truncated.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest - -import gym -from gym.spaces import Discrete -from gym.vector import AsyncVectorEnv, SyncVectorEnv -from gym.wrappers import TimeLimit - - -# An environment where termination happens after 20 steps -class DummyEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - self.terminal_timestep = 20 - - self.timestep = 0 - - def step(self, action): - self.timestep += 1 - terminated = True if self.timestep >= self.terminal_timestep else False - truncated = False - - return 0, 0, terminated, truncated, {} - - def reset(self): - self.timestep = 0 - return 0 - - -@pytest.mark.parametrize("time_limit", [10, 20, 30]) -def test_terminated_truncated(time_limit): - test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True) - - terminated = False - truncated = False - test_env.reset() - while not (terminated or truncated): - _, _, terminated, truncated, _ = test_env.step(0) - - if test_env.terminal_timestep < time_limit: - assert terminated - assert not truncated - elif test_env.terminal_timestep == time_limit: - assert ( - terminated - ), "`terminated` should be True even when termination and truncation happen at the same step" - assert ( - truncated - ), "`truncated` should be True even when termination and truncation occur at same step " - else: - assert not terminated - assert truncated - - -def test_terminated_truncated_vector(): - env0 = TimeLimit(DummyEnv(), 10, new_step_api=True) - env1 = TimeLimit(DummyEnv(), 20, new_step_api=True) - env2 = TimeLimit(DummyEnv(), 30, new_step_api=True) - - async_env = AsyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) - async_env.reset() - terminateds = [False, False, False] - truncateds = [False, False, False] - counter = 0 - while not all([x or y for x, y in zip(terminateds, truncateds)]): - counter += 1 - _, _, terminateds, truncateds, _ = async_env.step( - async_env.action_space.sample() - ) - print(counter) - assert counter == 20 - assert all(terminateds == [False, True, True]) - assert all(truncateds == [True, True, False]) - - sync_env = SyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) - sync_env.reset() - terminateds = [False, False, False] - truncateds = [False, False, False] - counter = 0 - while not all([x or y for x, y in zip(terminateds, truncateds)]): - counter += 1 - _, _, terminateds, truncateds, _ = sync_env.step( - async_env.action_space.sample() - ) - assert counter == 20 - assert all(terminateds == [False, True, True]) - assert all(truncateds == [True, True, False]) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index e4ed3f9b593..6598d919c64 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -138,6 +138,7 @@ def test_autoreset_wrapper_autoreset(): "count": 0, "final_observation": np.array([3]), "final_info": {"count": 3}, + "TimeLimit.truncated": False, } obs, reward, done, info = env.step(action)