diff --git a/nes_py/app/play_human.py b/nes_py/app/play_human.py index 6ca16571..db58e83e 100644 --- a/nes_py/app/play_human.py +++ b/nes_py/app/play_human.py @@ -62,15 +62,15 @@ def play_human(env: gym.Env, callback=None): # reset if the environment is done if done: done = False - state = env.reset() + state, _ = env.reset() viewer.show(env.unwrapped.screen) # unwrap the action based on pressed relevant keys action = keys_to_action.get(viewer.pressed_keys, _NOP) - next_state, reward, done, _ = env.step(action) + next_state, reward, done, truncated, _ = env.step(action) viewer.show(env.unwrapped.screen) # pass the observation data through the callback if callback is not None: - callback(state, action, reward, done, next_state) + callback(state, action, reward, done, truncated, next_state) state = next_state # shutdown if the escape key is pressed if viewer.is_escape_pressed: diff --git a/nes_py/app/play_random.py b/nes_py/app/play_random.py index d2fc4d8e..1d3e5cd5 100644 --- a/nes_py/app/play_random.py +++ b/nes_py/app/play_random.py @@ -19,9 +19,9 @@ def play_random(env, steps): progress = tqdm(range(steps)) for _ in progress: if done: - _ = env.reset() + _, _ = env.reset() action = env.action_space.sample() - _, reward, done, info = env.step(action) + _, reward, done, _, info = env.step(action) progress.set_postfix(reward=reward, info=info) env.render() except KeyboardInterrupt: diff --git a/nes_py/nes_env.py b/nes_py/nes_env.py index 35333f5a..110bc8aa 100644 --- a/nes_py/nes_env.py +++ b/nes_py/nes_env.py @@ -4,13 +4,27 @@ import itertools import os import sys + import gym +from gym.core import ObsType, RenderFrame from gym.spaces import Box from gym.spaces import Discrete import numpy as np from ._rom import ROM from ._image_viewer import ImageViewer +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + SupportsFloat, + Tuple, + TypeVar, + Union, +) # the path to the directory this file is in _MODULE_PATH = os.path.dirname(__file__) @@ -24,7 +38,6 @@ except IndexError: raise OSError('missing static lib_nes_env*.so library!') - # setup the argument and return types for Width _LIB.Width.argtypes = None _LIB.Width.restype = ctypes.c_uint @@ -59,7 +72,6 @@ _LIB.Close.argtypes = [ctypes.c_void_p] _LIB.Close.restype = None - # height in pixels of the NES screen SCREEN_HEIGHT = _LIB.Height() # width in pixels of the NES screen @@ -71,11 +83,9 @@ # create a type for the screen tensor matrix from C++ SCREEN_TENSOR = ctypes.c_byte * int(np.prod(SCREEN_SHAPE_32_BIT)) - # create a type for the RAM vector from C++ RAM_VECTOR = ctypes.c_byte * 0x800 - # create a type for the controller buffers from C++ CONTROLLER_VECTOR = ctypes.c_byte * 1 @@ -94,10 +104,10 @@ class NESEnv(gym.Env): # observation space for the environment is static across all instances observation_space = Box( - low=0, - high=255, - shape=SCREEN_SHAPE_24_BIT, - dtype=np.uint8 + low=0, + high=255, + shape=SCREEN_SHAPE_24_BIT, + dtype=np.uint8 ) # action space is a bitmap of button press values for the 8 NES buttons @@ -145,6 +155,8 @@ def __init__(self, rom_path): self._has_backup = False # setup a done flag self.done = True + # truncated + self.truncated = False # setup the controllers, screen, and RAM buffers self.controllers = [self._controller_buffer(port) for port in range(2)] self.screen = self._screen_buffer() @@ -243,7 +255,7 @@ def seed(self, seed=None): # return the list of seeds used by RNG(s) in the environment return [seed] - def reset(self, seed=None, options=None, return_info=None): + def reset(self, seed=None, options=None, return_info=None) -> Tuple[ObsType, dict]: """ Reset the state of the environment and returns an initial observation. @@ -253,7 +265,9 @@ def reset(self, seed=None, options=None, return_info=None): return_info (any): unused Returns: - state (np.ndarray): next frame as a result of the given action + a tuple + state (np.ndarray): next frame as a result of the given action + info dict: Return the info after a step occurs """ # Set the seed. @@ -270,13 +284,13 @@ def reset(self, seed=None, options=None, return_info=None): # set the done flag to false self.done = False # return the screen from the emulator - return self.screen + return self.screen, self._get_info() def _did_reset(self): """Handle any RAM hacking after a reset occurs.""" pass - def step(self, action): + def step(self, action) -> Tuple[ObsType, float, bool, bool, dict]: """ Run one frame of the NES and return the relevant observation data. @@ -304,6 +318,7 @@ def step(self, action): self.done = bool(self._get_done()) # get the info for this step info = self._get_info() + self.truncated = self._get_truncated() # call the after step callback self._did_step(self.done) # bound the reward in [min, max] @@ -312,7 +327,7 @@ def step(self, action): elif reward > self.reward_range[1]: reward = self.reward_range[1] # return the screen from the emulator and other relevant data - return self.screen, reward, self.done, info + return self.screen, reward, self.done, self.truncated, info def _get_reward(self): """Return the reward after a step occurs.""" @@ -322,6 +337,10 @@ def _get_done(self): """Return True if the episode is over, False otherwise.""" return False + def _get_truncated(self): + """Return True if truncated """ + return False + def _get_info(self): """Return the info after a step occurs.""" return {} @@ -352,7 +371,7 @@ def close(self): if self.viewer is not None: self.viewer.close() - def render(self, mode='human'): + def render(self, mode='human') -> Optional[Union[RenderFrame, List[RenderFrame]]]: """ Render the environment. @@ -378,9 +397,9 @@ def render(self, mode='human'): caption = self.spec.id # create the ImageViewer to display frames self.viewer = ImageViewer( - caption=caption, - height=SCREEN_HEIGHT, - width=SCREEN_WIDTH, + caption=caption, + height=SCREEN_HEIGHT, + width=SCREEN_WIDTH, ) # show the screen on the image viewer self.viewer.show(self.screen) @@ -401,7 +420,7 @@ def get_keys_to_action(self): ord('a'), # left ord('s'), # down ord('w'), # up - ord('\r'), # start + ord('\r'), # start ord(' '), # select ord('p'), # B ord('o'), # A @@ -427,4 +446,4 @@ def get_action_meanings(self): # explicitly define the outward facing API of this module -__all__ = [NESEnv.__name__] +__all__ = [NESEnv.__name__] \ No newline at end of file diff --git a/nes_py/tests/test_multiple_makes.py b/nes_py/tests/test_multiple_makes.py index 8764e48d..2e578cb5 100755 --- a/nes_py/tests/test_multiple_makes.py +++ b/nes_py/tests/test_multiple_makes.py @@ -24,9 +24,9 @@ def play(steps): done = True for _ in range(steps): if done: - _ = env.reset() + _, _ = env.reset() action = env.action_space.sample() - _, _, done, _ = env.step(action) + _, _, done, _, _ = env.step(action) # close the environment env.close() @@ -45,7 +45,7 @@ class ShouldMakeMultipleEnvironmentsParallel(object): def test(self): procs = [None] * self.num_execs - args = (self.steps, ) + args = (self.steps,) # spawn the parallel instances for idx in range(self.num_execs): procs[idx] = self.parallel_initializer(target=play, args=args) @@ -82,6 +82,6 @@ def test(self): for _ in range(self.steps): for idx in range(self.num_envs): if dones[idx]: - _ = envs[idx].reset() + _, _ = envs[idx].reset() action = envs[idx].action_space.sample() - _, _, dones[idx], _ = envs[idx].step(action) + _, _, dones[idx], _, _ = envs[idx].step(action) diff --git a/nes_py/tests/test_nes_env.py b/nes_py/tests/test_nes_env.py index 3e24c6e9..b5995af2 100644 --- a/nes_py/tests/test_nes_env.py +++ b/nes_py/tests/test_nes_env.py @@ -79,7 +79,7 @@ def test(self): for _ in range(500): if done: # reset the environment and check the output value - state = env.reset() + state, _ = env.reset() self.assertIsInstance(state, np.ndarray) # sample a random action and check it action = env.action_space.sample() @@ -87,12 +87,13 @@ def test(self): # take a step and check the outputs output = env.step(action) self.assertIsInstance(output, tuple) - self.assertEqual(4, len(output)) + self.assertEqual(5, len(output)) # check each output - state, reward, done, info = output + state, reward, done, truncated, info = output self.assertIsInstance(state, np.ndarray) self.assertIsInstance(reward, float) self.assertIsInstance(done, bool) + self.assertIsInstance(truncated, bool) self.assertIsInstance(info, dict) # check the render output render = env.render('rgb_array') @@ -108,9 +109,9 @@ def test(self): for _ in range(250): if done: - state = env.reset() + state, _ = env.reset() done = False - state, _, done, _ = env.step(0) + state, _, done, _, _ = env.step(0) backup = state.copy() @@ -120,9 +121,9 @@ def test(self): if done: state = env.reset() done = False - state, _, done, _ = env.step(0) + state, _, done, _, _ = env.step(0) self.assertFalse(np.array_equal(backup, state)) env._restore() self.assertTrue(np.array_equal(backup, env.screen)) - env.close() + env.close() \ No newline at end of file diff --git a/scripts/run.py b/scripts/run.py index 90abfae9..ad8f94e7 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -7,9 +7,9 @@ try: for _ in tqdm.tqdm(range(5000)): if done: - state = env.reset() + state, _ = env.reset() done = False else: - state, reward, done, info = env.step(env.action_space.sample()) + state, reward, done, truncated, info = env.step(env.action_space.sample()) except KeyboardInterrupt: pass diff --git a/setup.py b/setup.py index d50a5f74..5f178d8f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setup( name='nes_py', - version='8.2.1', + version='8.2.2', description='An NES Emulator and OpenAI Gym interface', long_description=README, long_description_content_type='text/markdown',