diff --git a/.clang_complete b/.clang_complete index 14e577af..9e2a3043 100644 --- a/.clang_complete +++ b/.clang_complete @@ -1 +1,2 @@ -Ines_py/nes/include +-Ines_py/nes/src diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..429359ca --- /dev/null +++ b/Makefile @@ -0,0 +1,41 @@ +UV ?= $(shell which uv) +PYTHON ?= $(shell which python3) + +# build everything +all: test deployment + +# build the SimpleNES C++ code +lib_emu: + $(MAKE) -C nes_py/nes $(MAKEFLAGS) + mv nes_py/nes/libemulator.so nes_py/emulator.so + +install: + $(UV) pip install . + +# run the Python test suite +test: install + (cd nes_py/tests && $(PYTHON) -m unittest discover .) + + +# clean the build directory +clean: + $(MAKE) -C nes_py/nes clean + rm -rf build/ .eggs/ *.egg-info/ || true + find . -name "*.pyc" -delete + find . -name "__pycache__" -delete + find . -name ".sconsign.dblite" -delete + find . -name "build" | rm -rf + find . -name "emulator.so" -delete + +# build the deployment package +deployment: clean test + $(UV) build --sdist --wheel + +# ship the deployment package to PyPi +ship: test deployment + twine upload dist/* + +# Show configuration +show-config: + @echo "Python path: $(PYTHON)" + @echo "UV command: $(UV)" diff --git a/build_all_py_ver b/build_all_py_ver new file mode 100755 index 00000000..9e4e2f82 --- /dev/null +++ b/build_all_py_ver @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +export UV_PYTHON_PREFERENCE=only-managed + +if command -v uv &> /dev/null +then + curl -LsSf https://astral.sh/uv/install.sh | sh +fi + +for py_ver in 3.{8,9,10,11,12,13} +do + env_name=".pyenvs/${py_ver//.}" + uv venv -p $py_ver $env_name + source $env_name/bin/activate + uv pip install pip -r requirements.txt + make deployment + deactivate +done diff --git a/makefile b/makefile deleted file mode 100644 index 66895c12..00000000 --- a/makefile +++ /dev/null @@ -1,28 +0,0 @@ -# build everything -all: test deployment - -# build the LaiNES CPP code -lib_nes_env: - scons -C nes_py/nes - mv nes_py/nes/lib_nes_env*.so nes_py - -# run the Python test suite -test: lib_nes_env - python3 -m unittest discover . - -# clean the build directory -clean: - rm -rf build/ dist/ .eggs/ *.egg-info/ || true - find . -name "*.pyc" -delete - find . -name "__pycache__" -delete - find . -name ".sconsign.dblite" -delete - find . -name "build" | rm -rf - find . -name "lib_nes_env.so" -delete - -# build the deployment package -deployment: clean - python3 setup.py sdist bdist_wheel - -# ship the deployment package to PyPi -ship: test deployment - twine upload dist/* diff --git a/nes_py/__init__.py b/nes_py/__init__.py index 5d92ec4a..d5807de7 100644 --- a/nes_py/__init__.py +++ b/nes_py/__init__.py @@ -1,6 +1,5 @@ -"""The nes-py NES emulator for Python 2 & 3.""" -from .nes_env import NESEnv - +"""The nes-py NES emulator for Python 3.""" +from nes_py.nes_env import NESEnv # explicitly define the outward facing API of this package -__all__ = [NESEnv.__name__] +__all__ = [NESEnv.__name__] \ No newline at end of file diff --git a/__main__.py b/nes_py/__main__.py similarity index 100% rename from __main__.py rename to nes_py/__main__.py diff --git a/nes_py/_image_viewer.py b/nes_py/_image_viewer.py index 5492a0d2..a18b7732 100644 --- a/nes_py/_image_viewer.py +++ b/nes_py/_image_viewer.py @@ -1,103 +1,161 @@ """A simple class for viewing images using pyglet.""" - - -class ImageViewer(object): - """A simple class for viewing images using pyglet.""" - - def __init__(self, caption, height, width, - monitor_keyboard=False, - relevant_keys=None - ): - """ - Initialize a new image viewer. +import threading +from typing import Dict +from typing import List +from typing import Tuple +from typing import Optional +from typing import ClassVar +from dataclasses import dataclass + +import numpy as np + + + +@dataclass(init=False) +class ImageViewer: + """A class for displaying images using pyglet window system. + + This class provides functionality to create and manage a window for displaying + RGB image arrays. It supports keyboard monitoring and window management. + + Attributes: + caption (str): The title of the window. + height (int): The height of the window in pixels. + width (int): The width of the window in pixels. + monitor_keyboard (bool): Whether to monitor keyboard events. + relevant_keys (Optional[List[int]]): List of key codes to monitor. If None, all keys are monitored. + _pressed_keys (List[int]): Internal list of currently pressed keys. + _is_escape_pressed (bool): Internal flag for escape key state. + _window (Optional[BaseWindow]): Internal pyglet window instance. + """ + + caption: str + height: int + width: int + monitor_keyboard: bool + relevant_keys: Optional[List[int]] + _pressed_keys: List[int] + _is_escape_pressed: bool + _window: Optional + _pyglet: Optional + + # Map pyglet key codes to their native equivalents + KEY_MAP: Dict[int, int] + + def __init__( + self, + caption: str, + height: int, + width: int, + monitor_keyboard: bool = False, + relevant_keys: Optional[List[int]] = None, + ) -> None: + """Initialize a new image viewer instance. Args: - caption (str): the caption/title for the window - height (int): the height of the window - width (int): the width of the window - monitor_keyboard: whether to monitor events from the keyboard - relevant_keys: the relevant keys to monitor events from - - Returns: - None + caption: The title of the window. + height: The height of the window in pixels. + width: The width of the window in pixels. + monitor_keyboard: Whether to monitor keyboard events. + relevant_keys: List of key codes to monitor. If None, all keys are monitored. + + Raises: + RuntimeError: If initialized from a non-main thread. """ - # detect if rendering from python threads and fail - import threading if threading.current_thread() is not threading.main_thread(): - msg = 'rendering from python threads is not supported' - raise RuntimeError(msg) - # import pyglet within class scope to resolve issues with how pyglet - # interacts with OpenGL while using multiprocessing + raise RuntimeError('rendering from python threads is not supported') + import pyglet - self.pyglet = pyglet - # a mapping from pyglet key identifiers to native identifiers - self.KEY_MAP = { - self.pyglet.window.key.ENTER: ord('\r'), - self.pyglet.window.key.SPACE: ord(' '), - } - self.caption = caption + self._pyglet = pyglet + + self.KEY_MAP: Dict[int, int] = { + pyglet.window.key.ENTER: ord('\r'), + pyglet.window.key.SPACE: ord(' '), + } + + self._window = None self.height = height self.width = width - self.monitor_keyboard = monitor_keyboard - self.relevant_keys = relevant_keys - self._window = None - self._pressed_keys = [] + self.caption = caption + self._pressed_keys = list() self._is_escape_pressed = False + self.relevant_keys = relevant_keys + self.monitor_keyboard = monitor_keyboard @property - def is_open(self): - """Return a boolean determining if this window is open.""" + def is_open(self) -> bool: + """Check if the window is currently open. + + Returns: + bool: True if window is open, False otherwise. + """ return self._window is not None @property - def is_escape_pressed(self): - """Return True if the escape key is pressed.""" + def is_escape_pressed(self) -> bool: + """Check if the escape key is currently pressed. + + Returns: + bool: True if escape key is pressed, False otherwise. + """ return self._is_escape_pressed @property - def pressed_keys(self): - """Return a sorted list of the pressed keys.""" - return tuple(sorted(self._pressed_keys)) + def pressed_keys(self) -> Tuple[int, ...]: + """Get currently pressed keys. - def _handle_key_event(self, symbol, is_press): + Returns: + Tuple[int, ...]: A sorted tuple of key codes currently being pressed. """ - Handle a key event. - - Args: - symbol: the symbol in the event - is_press: whether the event is a press or release + return tuple(sorted(self._pressed_keys)) - Returns: - None + def _handle_key_event(self, symbol: int, is_press: bool) -> None: + """Handle keyboard press/release events. + Args: + symbol: The key code of the pressed/released key. + is_press: True if key was pressed, False if released. """ - # remap the key to the expected domain symbol = self.KEY_MAP.get(symbol, symbol) - # check if the symbol is the escape key - if symbol == self.pyglet.window.key.ESCAPE: + + if symbol == key.ESCAPE: self._is_escape_pressed = is_press return - # make sure the symbol is relevant + if self.relevant_keys is not None and symbol not in self.relevant_keys: return - # handle the press / release by appending / removing the key to pressed + if is_press: self._pressed_keys.append(symbol) else: self._pressed_keys.remove(symbol) - def on_key_press(self, symbol, modifiers): - """Respond to a key press on the keyboard.""" + def on_key_press(self, symbol: int, modifiers: int) -> None: + """Handle key press events from pyglet. + + Args: + symbol: The key code of the pressed key. + modifiers: Bitwise combination of any keyboard modifiers currently pressed. + """ self._handle_key_event(symbol, True) - def on_key_release(self, symbol, modifiers): - """Respond to a key release on the keyboard.""" + def on_key_release(self, symbol: int, modifiers: int) -> None: + """Handle key release events from pyglet. + + Args: + symbol: The key code of the released key. + modifiers: Bitwise combination of any keyboard modifiers currently pressed. + """ self._handle_key_event(symbol, False) - def open(self): - """Open the window.""" + def open(self) -> None: + """Create and open the pyglet window. + + Creates a new window with the configured caption, dimensions and vsync settings. + If keyboard monitoring is enabled, sets up the key event handlers. + """ # create a window for this image viewer instance - self._window = self.pyglet.window.Window( + self._window = self._pyglet.window.Window( caption=self.caption, height=self.height, width=self.width, @@ -110,44 +168,58 @@ def open(self): self._window.event(self.on_key_press) self._window.event(self.on_key_release) - def close(self): - """Close the window.""" + def close(self) -> None: + """Close the pyglet window if it's open.""" if self.is_open: self._window.close() self._window = None - def show(self, frame): - """ - Show an array of pixels on the window. + def show(self, frame: np.ndarray) -> None: + """Display an RGB image array in the window. + + Opens the window if it isn't already open, clears it, and displays the new frame + scaled to fit the window dimensions. Args: - frame (numpy.ndarray): the frame to show on the window + frame: RGB image array of shape (height, width, 3). - Returns: - None + Raises: + ValueError: If frame doesn't have exactly 3 dimensions. """ - # check that the frame has the correct dimensions if len(frame.shape) != 3: raise ValueError('frame should have shape with only 3 dimensions') - # open the window if it isn't open already + if not self.is_open: self.open() - # prepare the window for the next frame + self._window.clear() self._window.switch_to() self._window.dispatch_events() - # create an image data object - image = self.pyglet.image.ImageData( + + image = self._pyglet.image.ImageData( frame.shape[1], frame.shape[0], 'RGB', frame.tobytes(), - pitch=frame.shape[1]*-3 + pitch=frame.shape[1] * -3 + ) + + texture = image.get_texture() + self._pyglet.gl.glTexParameteri( + self._pyglet.gl.GL_TEXTURE_2D, + self._pyglet.gl.GL_TEXTURE_MAG_FILTER, + self._pyglet.gl.GL_NEAREST ) - # send the image to the window - image.blit(0, 0, width=self._window.width, height=self._window.height) + self._pyglet.gl.glTexParameteri( + self._pyglet.gl.GL_TEXTURE_2D, + self._pyglet.gl.GL_TEXTURE_MIN_FILTER, + self._pyglet.gl.GL_NEAREST + ) + + image.blit(0, 0, width=self._window.width, height=self._window.height) self._window.flip() + # explicitly define the outward facing API of this module __all__ = [ImageViewer.__name__] diff --git a/nes_py/_rom.py b/nes_py/_rom.py index 0de44689..e5bc0103 100644 --- a/nes_py/_rom.py +++ b/nes_py/_rom.py @@ -4,17 +4,24 @@ - http://wiki.nesdev.com/w/index.php/INES """ import os +from typing import ClassVar +from dataclasses import dataclass + import numpy as np -class ROM(object): +@dataclass(frozen=True) +class ROM: """An abstraction of the NES Read-Only Memory (ROM).""" - # the magic bytes expected at the first four bytes of the header. + # The magic bytes expected at the first four bytes of the header. # It spells "NES" - _MAGIC = np.array([0x4E, 0x45, 0x53, 0x1A]) + MAGIC: ClassVar[np.ndarray] = np.array([0x4E, 0x45, 0x53, 0x1A], dtype=np.uint8) + raw_data: np.ndarray + - def __init__(self, rom_path): + @classmethod + def from_path(cls, rom_path: str) -> 'ROM': """ Initialize a new ROM. @@ -28,22 +35,27 @@ def __init__(self, rom_path): # make sure the rom path is a string if not isinstance(rom_path, str): raise TypeError('rom_path must be of type: str.') + # make sure the rom path exists if not os.path.exists(rom_path): msg = 'rom_path points to non-existent file: {}.'.format(rom_path) raise ValueError(msg) - # read the binary data in the .nes ROM file - self.raw_data = np.fromfile(rom_path, dtype='uint8') + + # Load the binary data in the .nes ROM file + rom = cls(raw_data=np.fromfile(rom_path, dtype='uint8')) + # ensure the first 4 bytes are 0x4E45531A (NES) - if not np.array_equal(self._magic, self._MAGIC): + if not np.array_equal(rom.header[:4], cls.MAGIC): raise ValueError('ROM missing magic number in header.') - if self._zero_fill != 0: + + if rom.header[11:].sum() != 0: raise ValueError("ROM header zero fill bytes are not zero.") + + return rom # # MARK: Header # - @property def header(self): """Return the header of the ROM file as bytes.""" @@ -57,12 +69,12 @@ def _magic(self): @property def prg_rom_size(self): """Return the size of the PRG ROM in KB.""" - return 16 * self.header[4] + return 16 * int(self.header[4]) @property def chr_rom_size(self): """Return the size of the CHR ROM in KB.""" - return 8 * self.header[5] + return 8 * int(self.header[5]) @property def flags_6(self): diff --git a/nes_py/app/cli.py b/nes_py/app/cli.py index 762ea2d2..a1381b58 100644 --- a/nes_py/app/cli.py +++ b/nes_py/app/cli.py @@ -1,8 +1,9 @@ """Command line interface to nes-py NES emulator.""" import argparse -from .play_human import play_human -from .play_random import play_random -from ..nes_env import NESEnv +from nes_py.nes_env import NESEnv +from nes_py.app.play_human import play_human +from nes_py.app.play_random import play_random + def _get_args(): diff --git a/nes_py/app/play_human.py b/nes_py/app/play_human.py index 6ca16571..0a79623f 100644 --- a/nes_py/app/play_human.py +++ b/nes_py/app/play_human.py @@ -1,15 +1,17 @@ """A method to play gym environments using human IO inputs.""" -import gym import time + +import gymnasium as gym from pyglet import clock -from .._image_viewer import ImageViewer + +from nes_py._image_viewer import ImageViewer # the sentinel value for "No Operation" _NOP = 0 -def play_human(env: gym.Env, callback=None): +def play_human(env: gym.Env, callback=None) -> None: """ Play the environment using keyboard as a human. @@ -41,13 +43,14 @@ def play_human(env: gym.Env, callback=None): env.observation_space.shape[0], # height env.observation_space.shape[1], # width monitor_keyboard=True, - relevant_keys=set(sum(map(list, keys_to_action.keys()), [])) + relevant_keys=set(sum(map(list, keys_to_action.keys()), [])+[ord('8'), ord('9')]) ) # create a done flag for the environment done = True # prepare frame rate limiting - target_frame_duration = 1 / env.metadata['video.frames_per_second'] + target_frame_duration = 1 / env.metadata['render_fps'] last_frame_time = 0 + snapshot = None # start the main game loop try: while True: @@ -65,12 +68,24 @@ def play_human(env: gym.Env, callback=None): state = env.reset() viewer.show(env.unwrapped.screen) # unwrap the action based on pressed relevant keys - action = keys_to_action.get(viewer.pressed_keys, _NOP) - next_state, reward, done, _ = env.step(action) + action = _NOP + if ord('8') in viewer.pressed_keys: + snapshot = env.dump_state() + print(f'len(snapshot): {len(snapshot)}') + elif ord('9') in viewer.pressed_keys: + env.load_state(snapshot) + else: + action = keys_to_action.get(viewer.pressed_keys, _NOP) + + next_state, reward, terminated, truncated, _ = env.step(action) + done = terminated or truncated + viewer.show(env.unwrapped.screen) + # pass the observation data through the callback if callback is not None: callback(state, action, reward, done, next_state) + state = next_state # shutdown if the escape key is pressed if viewer.is_escape_pressed: diff --git a/nes_py/app/play_random.py b/nes_py/app/play_random.py index d2fc4d8e..f6bf1b4e 100644 --- a/nes_py/app/play_random.py +++ b/nes_py/app/play_random.py @@ -1,8 +1,9 @@ """Methods for playing the game randomly, or as a human.""" from tqdm import tqdm +import gymnasium as gym -def play_random(env, steps): +def play_random(env: gym.Env, steps: int) -> None: """ Play the environment making uniformly random decisions. @@ -21,7 +22,9 @@ def play_random(env, steps): if done: _ = env.reset() action = env.action_space.sample() - _, reward, done, info = env.step(action) + _, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + progress.set_postfix(reward=reward, info=info) env.render() except KeyboardInterrupt: diff --git a/nes_py/nes/Makefile b/nes_py/nes/Makefile new file mode 100644 index 00000000..2fc7cf7f --- /dev/null +++ b/nes_py/nes/Makefile @@ -0,0 +1,55 @@ +# Get Python-specific paths using shell commands +VENV_SITE_PACKAGES := $(shell python3 -c 'import site; print(site.getsitepackages()[0])') +PYBIND11_PATH := $(VENV_SITE_PACKAGES)/pybind11/include +PYTHON_INCLUDE := $(shell python3 -c 'import sysconfig; print(sysconfig.get_config_var("INCLUDEPY"))') +PYTHON_LIBDIR := $(shell python3 -c 'import sysconfig; print(sysconfig.get_config_var("LIBDIR"))') +PYTHON_VERSION := $(shell python3 -c 'import sysconfig; print(sysconfig.get_config_var("LDVERSION") or sysconfig.get_config_var("VERSION"))') + +# Compiler and flags +CXX := g++ +CXXFLAGS := -std=c++14 -O3 -pipe -fPIC -Wno-unused-value +INCLUDES := -I$(dir $(lastword $(MAKEFILE_LIST)))include -I$(PYBIND11_PATH) -I$(PYTHON_INCLUDE) + +# Platform-specific settings and common LDFLAGS +# Don't link python library directly - use undefined dynamic_lookup +LDFLAGS := + +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + LDFLAGS += -undefined dynamic_lookup + ifeq ($(shell uname -m),arm64) + CXXFLAGS += -arch arm64 + LDFLAGS += -arch arm64 + endif +endif + +# Source files +SRC_DIR := $(dir $(lastword $(MAKEFILE_LIST)))src +BUILD_DIR := $(dir $(lastword $(MAKEFILE_LIST)))build +SRCS := $(shell find $(SRC_DIR) -name '*.cpp') +OBJS := $(SRCS:$(SRC_DIR)/%.cpp=$(BUILD_DIR)/%.o) + +# Target library +TARGET := $(dir $(lastword $(MAKEFILE_LIST)))libemulator.so + +# Default target +all: $(BUILD_DIR) $(TARGET) + +# Create build directory +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +# Build the shared library +$(TARGET): $(OBJS) + $(CXX) $(CXXFLAGS) -shared $(LDFLAGS) -o $@ $^ + +# Compile source files +$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# Clean build files +clean: + rm -rf $(BUILD_DIR) $(TARGET) + +.PHONY: all clean \ No newline at end of file diff --git a/nes_py/nes/SConstruct b/nes_py/nes/SConstruct deleted file mode 100644 index b788c74a..00000000 --- a/nes_py/nes/SConstruct +++ /dev/null @@ -1,31 +0,0 @@ -"""The compilation script for this project using SCons.""" -from os import environ - - -# create a separate build directory -VariantDir('build', 'src', duplicate=0) - - -# the compiler and linker flags for the C++ environment -FLAGS = [ - '-std=c++1y', - '-O3', - '-pipe', -] - - -# Create the C++ environment -ENV = Environment( - ENV=environ, - CXX='g++', - CPPFLAGS=['-Wno-unused-value'], - CXXFLAGS=FLAGS, - LINKFLAGS=FLAGS, - CPPPATH=['#include'], -) - - -# Locate all the C++ source files -SRC = Glob('build/*.cpp') + Glob('build/*/*.cpp') -# Create a shared library (it will add "lib" to the front automatically) -ENV.SharedLibrary('_nes_env.so', SRC) diff --git a/nes_py/nes/include/common.hpp b/nes_py/nes/include/common.hpp index 828a1501..77f744c3 100644 --- a/nes_py/nes/include/common.hpp +++ b/nes_py/nes/include/common.hpp @@ -11,6 +11,9 @@ // resolve an issue with MSVC overflow during compilation (Windows) #define _CRT_DECLARE_NONSTDC_NAMES 0 #include +#include +#include +#include namespace NES { @@ -21,6 +24,55 @@ typedef uint16_t NES_Address; /// A shortcut for a single pixel in memory typedef uint32_t NES_Pixel; +template +class static_vector : public std::array { +private: + std::size_t current_size{0}; + std::size_t reserved_size{N}; + +public: + static_vector() : std::array() {} + static_vector(std::initializer_list list) : std::array(list) {} + + void push_back(const T& value) { + if (current_size >= reserved_size) + throw std::length_error("static_vector: container is full"); + (*this)[current_size++] = value; + } + + void reserve(std::size_t new_capacity) { + if (new_capacity > N) + throw std::length_error("static_vector: cannot reserve beyond max capacity"); + reserved_size = new_capacity; + current_size = std::min(current_size, reserved_size); + } + + void resize(std::size_t new_size) { + if (new_size > reserved_size) + throw std::length_error("static_vector: cannot resize beyond reserved capacity"); + current_size = new_size; + } + + void clear() { current_size = 0; } + + // Iterator support + typename std::array::iterator begin() noexcept { + return this->std::array::begin(); + } + + typename std::array::iterator end() noexcept { + return begin() + current_size; // Only iterate up to current_size + } + + typename std::array::const_iterator begin() const noexcept { + return this->std::array::begin(); + } + + typename std::array::const_iterator end() const noexcept { + return begin() + current_size; // Only iterate up to current_size + } +}; + } // namespace NES -#endif // COMMON_HPP +#endif // COMMON_HPP \ No newline at end of file diff --git a/nes_py/nes/include/emulator.hpp b/nes_py/nes/include/emulator.hpp index ac70a471..8970360c 100644 --- a/nes_py/nes/include/emulator.hpp +++ b/nes_py/nes/include/emulator.hpp @@ -8,7 +8,9 @@ #ifndef EMULATOR_HPP #define EMULATOR_HPP +#include #include +#include #include "common.hpp" #include "cartridge.hpp" #include "controller.hpp" @@ -19,34 +21,40 @@ namespace NES { -/// An NES Emulator and OpenAI Gym interface -class Emulator { - private: - /// The number of cycles in 1 frame - static const int CYCLES_PER_FRAME = 29781; - /// the virtual cartridge with ROM and mapper data - Cartridge cartridge; - /// the 2 controllers on the emulator - Controller controllers[2]; - +struct Core { /// the main data bus of the emulator MainBus bus; - /// the picture bus from the PPU of the emulator - PictureBus picture_bus; /// The emulator's CPU CPU cpu; /// the emulators' PPU PPU ppu; - - /// the main data bus of the emulator - MainBus backup_bus; /// the picture bus from the PPU of the emulator - PictureBus backup_picture_bus; - /// The emulator's CPU - CPU backup_cpu; - /// the emulators' PPU - PPU backup_ppu; + PictureBus picture_bus; + void initialize(Controller* const controllers); + void reset(); + void set_mapper(Mapper *mapper); + Mapper* get_mapper() const { return bus.get_mapper(); } + void ppu_step(NESFrameBufferT* const framebuffer); + void step(NESFrameBufferT* const framebuffer); + + /// Copy state from another Core, preserving local mapper and callbacks. + /// This is safe to use with snapshots from destroyed emulators because + /// we only copy raw data, not pointers or function objects. + inline void copy_state_from(const Core& other) { + // Copy RAM state (preserves mapper and callbacks) + bus.copy_ram_from(other.bus); + // Copy CPU state (all plain data) + cpu = other.cpu; + // Copy PPU state (preserves vblank_callback) + ppu.copy_state_from(other.ppu); + // Copy picture bus RAM (preserves mapper) + picture_bus.copy_ram_from(other.picture_bus); + } +}; + +/// An NES Emulator and OpenAI Gym interface +class Emulator { public: /// The width of the NES screen in pixels static const int WIDTH = SCANLINE_VISIBLE_DOTS; @@ -63,13 +71,13 @@ class Emulator { /// /// @return a 32-bit pointer to the screen buffer's first address /// - inline NES_Pixel* get_screen_buffer() { return ppu.get_screen_buffer(); } + inline NESFrameBufferT* const get_screen_buffer() { return &framebuffer; } /// Return a 8-bit pointer to the RAM buffer's first address. /// /// @return a 8-bit pointer to the RAM buffer's first address /// - inline NES_Byte* get_memory_buffer() { return bus.get_memory_buffer(); } + inline NES_Byte* get_memory_buffer() { return core.bus.get_memory_buffer(); } /// Return a pointer to a controller port /// @@ -81,28 +89,42 @@ class Emulator { } /// Load the ROM into the NES. - inline void reset() { cpu.reset(bus); ppu.reset(); } + inline void reset() { core.cpu.reset(core.bus); core.ppu.reset(); } /// Perform a step on the emulator, i.e., a single frame. void step(); - /// Create a backup state on the emulator. - inline void backup() { - backup_bus = bus; - backup_picture_bus = picture_bus; - backup_cpu = cpu; - backup_ppu = ppu; + /// Perform a step on the PPU, i.e., a single frame. + void ppu_step(); + + /// Create a snapshot state on the emulator. + inline void snapshot(Core* const core) { + *core = this->core; } - /// Restore the backup state on the emulator. - inline void restore() { - bus = backup_bus; - picture_bus = backup_picture_bus; - cpu = backup_cpu; - ppu = backup_ppu; + /// Restore the snapshot state on the emulator. + /// Uses copy_state_from to safely copy only data, preserving local + /// mapper pointer and callbacks which may contain dangling references + /// in snapshots from destroyed emulators. + inline void restore(const Core* const core) { + this->core.copy_state_from(*core); } + + private: + /// The number of cycles in 1 frame + static const int CYCLES_PER_FRAME = 29781; + + /// the core of the emulator + Core core; + /// the virtual cartridge with ROM and mapper data + Cartridge cartridge; + /// the 2 controllers on the emulator + Controller controllers[2]; + + /// the rendering framebuffer of the emulator + NESFrameBufferT framebuffer; }; } // namespace NES -#endif // EMULATOR_HPP +#endif // EMULATOR_HPP \ No newline at end of file diff --git a/nes_py/nes/include/main_bus.hpp b/nes_py/nes/include/main_bus.hpp index dab40615..ff0f39d8 100644 --- a/nes_py/nes/include/main_bus.hpp +++ b/nes_py/nes/include/main_bus.hpp @@ -50,9 +50,9 @@ typedef std::unordered_map IORegisterT class MainBus { private: /// The RAM on the main bus - std::vector ram; + static_vector ram; /// The extended RAM (if the mapper has extended RAM) - std::vector extended_ram; + static_vector extended_ram; /// a pointer to the mapper on the cartridge Mapper* mapper; /// a map of IO registers to callback methods for writes @@ -62,7 +62,7 @@ class MainBus { public: /// Initialize a new main bus. - MainBus() : ram(0x800, 0), mapper(nullptr) { } + MainBus() : mapper(nullptr) { } /// Return a 8-bit pointer to the RAM buffer's first address. /// @@ -91,6 +91,12 @@ class MainBus { /// void set_mapper(Mapper* mapper); + /// Get the mapper pointer. + /// + /// @return the current mapper pointer + /// + inline Mapper* get_mapper() const { return mapper; } + /// Set a callback for when writes occur. inline void set_write_callback(IORegisters reg, WriteCallback callback) { write_callbacks.insert({reg, callback}); @@ -103,6 +109,13 @@ class MainBus { /// Return a pointer to the page in memory. const NES_Byte* get_page_pointer(NES_Byte page); + + /// Copy RAM state from another MainBus (preserves mapper and callbacks) + inline void copy_ram_from(const MainBus& other) { + ram = other.ram; + extended_ram = other.extended_ram; + // Note: mapper and callbacks are NOT copied + } }; } // namespace NES diff --git a/nes_py/nes/include/picture_bus.hpp b/nes_py/nes/include/picture_bus.hpp index b3cee67c..66b611a5 100644 --- a/nes_py/nes/include/picture_bus.hpp +++ b/nes_py/nes/include/picture_bus.hpp @@ -19,17 +19,17 @@ namespace NES { class PictureBus { private: /// the VRAM on the picture bus - std::vector ram; + static_vector ram; /// indexes where they start in RAM vector std::size_t name_tables[4] = {0, 0, 0, 0}; /// the palette for decoding RGB tuples - std::vector palette; + static_vector palette; /// a pointer to the mapper on the cartridge Mapper* mapper; public: /// Initialize a new picture bus. - PictureBus() : ram(0x800), palette(0x20), mapper(nullptr) { } + PictureBus() : mapper(nullptr) { } /// Read a byte from an address on the VRAM. /// @@ -54,6 +54,14 @@ class PictureBus { this->mapper = mapper; update_mirroring(); } + /// Copy RAM/palette state from another PictureBus (preserves mapper) + inline void copy_ram_from(const PictureBus& other) { + ram = other.ram; + for (int i = 0; i < 4; i++) name_tables[i] = other.name_tables[i]; + palette = other.palette; + // Note: mapper is NOT copied + } + /// Read a color index from the palette. /// /// @param address the address of the palette color diff --git a/nes_py/nes/include/ppu.hpp b/nes_py/nes/include/ppu.hpp index 23baeb43..bf5e7f90 100644 --- a/nes_py/nes/include/ppu.hpp +++ b/nes_py/nes/include/ppu.hpp @@ -10,6 +10,7 @@ #include "common.hpp" #include "picture_bus.hpp" +#include namespace NES { @@ -24,15 +25,17 @@ const int SCANLINE_END_CYCLE = 341; /// The last scanline per frame const int FRAME_END_SCANLINE = 261; +typedef NES_Pixel NESFrameBufferT[VISIBLE_SCANLINES][SCANLINE_VISIBLE_DOTS]; + /// The Picture Processing Unit (PPU) for the NES class PPU { private: /// The callback to fire when entering vertical blanking mode std::function vblank_callback; /// The OAM memory (sprites) - std::vector sprite_memory; + static_vector sprite_memory; /// OAM memory (sprites) for the next scanline - std::vector scanline_sprites; + static_vector scanline_sprites; /// The current pipeline state of the PPU enum State { @@ -98,17 +101,12 @@ class PPU { /// The value to increment the data address by NES_Address data_address_increment; - /// The internal screen data structure as a vector representation of a - /// matrix of height matching the visible scans lines and width matching - /// the number of visible scan line dots - NES_Pixel screen[VISIBLE_SCANLINES][SCANLINE_VISIBLE_DOTS]; - public: /// Initialize a new PPU. - PPU() : sprite_memory(64 * 4) { } + PPU() { } /// Perform a single cycle on the PPU. - void cycle(PictureBus& bus); + void cycle(PictureBus& bus, NESFrameBufferT* const screen); /// Reset the PPU. void reset(); @@ -118,6 +116,38 @@ class PPU { vblank_callback = cb; } + /// Copy state from another PPU (preserves vblank_callback) + inline void copy_state_from(const PPU& other) { + // Save local callback + auto local_callback = vblank_callback; + // Copy all data + sprite_memory = other.sprite_memory; + scanline_sprites = other.scanline_sprites; + pipeline_state = other.pipeline_state; + cycles = other.cycles; + scanline = other.scanline; + is_even_frame = other.is_even_frame; + is_vblank = other.is_vblank; + is_sprite_zero_hit = other.is_sprite_zero_hit; + data_address = other.data_address; + temp_address = other.temp_address; + fine_x_scroll = other.fine_x_scroll; + is_first_write = other.is_first_write; + data_buffer = other.data_buffer; + sprite_data_address = other.sprite_data_address; + is_showing_sprites = other.is_showing_sprites; + is_showing_background = other.is_showing_background; + is_hiding_edge_sprites = other.is_hiding_edge_sprites; + is_hiding_edge_background = other.is_hiding_edge_background; + is_long_sprites = other.is_long_sprites; + is_interrupting = other.is_interrupting; + background_page = other.background_page; + sprite_page = other.sprite_page; + data_address_increment = other.data_address_increment; + // Restore local callback (DON'T copy from other) + vblank_callback = local_callback; + } + /// TODO: doc void do_DMA(const NES_Byte* page_ptr); @@ -179,9 +209,6 @@ class PPU { inline void set_OAM_data(NES_Byte value) { sprite_memory[sprite_data_address++] = value; } - - /// Return a pointer to the screen buffer. - inline NES_Pixel* get_screen_buffer() { return *screen; } }; } // namespace NES diff --git a/nes_py/nes/src/cartridge.cpp b/nes_py/nes/src/cartridge.cpp index 00d5218c..fed38707 100644 --- a/nes_py/nes/src/cartridge.cpp +++ b/nes_py/nes/src/cartridge.cpp @@ -6,32 +6,61 @@ // #include +#include #include "cartridge.hpp" -#include "log.hpp" namespace NES { void Cartridge::loadFromFile(std::string path) { - // create a stream to load the ROM file - std::ifstream romFile(path, std::ios_base::binary | std::ios_base::in); - // create a byte vector for the iNES header - std::vector header; - header.resize(0x10); - romFile.read(reinterpret_cast(&header[0]), 0x10); - // read internal data + std::ifstream file(path, std::ios::binary); + if (!file.good()) { + throw std::runtime_error("ROM file not found or not readable: " + path); + } + + // Read and validate iNES header (16 bytes) + std::vector header(0x10); + if (!file.read(reinterpret_cast(header.data()), 0x10)) { + throw std::runtime_error("ROM file too small to contain iNES header: " + path); + } + + // Validate magic bytes "NES\x1A" + if (header[0] != 'N' || header[1] != 'E' || header[2] != 'S' || header[3] != 0x1A) { + throw std::runtime_error("Invalid iNES header (missing NES magic bytes): " + path); + } + + // Parse header + NES_Byte prg_banks = header[4]; + NES_Byte chr_banks = header[5]; + + if (prg_banks == 0) { + throw std::runtime_error("Invalid ROM: PRG ROM size is 0: " + path); + } + name_table_mirroring = header[6] & 0xB; mapper_number = ((header[6] >> 4) & 0xf) | (header[7] & 0xf0); has_extended_ram = header[6] & 0x2; - // read PRG-ROM 16KB banks - NES_Byte banks = header[4]; - prg_rom.resize(0x4000 * banks); - romFile.read(reinterpret_cast(&prg_rom[0]), 0x4000 * banks); - // read CHR-ROM 8KB banks - NES_Byte vbanks = header[5]; - if (!vbanks) - return; - chr_rom.resize(0x2000 * vbanks); - romFile.read(reinterpret_cast(&chr_rom[0]), 0x2000 * vbanks); + bool has_trainer = header[6] & 0x04; + + // Skip trainer if present (512 bytes) + if (has_trainer) { + file.seekg(512, std::ios::cur); + } + + // Read PRG-ROM (16KB per bank) + size_t prg_size = 0x4000 * prg_banks; + prg_rom.resize(prg_size); + if (!file.read(reinterpret_cast(prg_rom.data()), prg_size)) { + throw std::runtime_error("ROM file truncated while reading PRG ROM: " + path); + } + + // Read CHR-ROM (8KB per bank) if present + if (chr_banks > 0) { + size_t chr_size = 0x2000 * chr_banks; + chr_rom.resize(chr_size); + if (!file.read(reinterpret_cast(chr_rom.data()), chr_size)) { + throw std::runtime_error("ROM file truncated while reading CHR ROM: " + path); + } + } } } // namespace NES diff --git a/nes_py/nes/src/cpu.cpp b/nes_py/nes/src/cpu.cpp index e0dd4830..2e2997ac 100644 --- a/nes_py/nes/src/cpu.cpp +++ b/nes_py/nes/src/cpu.cpp @@ -556,8 +556,8 @@ void CPU::cycle(MainBus &bus) { // must be before ExecuteType0 if (implied(bus, op) || branch(bus, op) || type1(bus, op) || type2(bus, op) || type0(bus, op)) skip_cycles += OPERATION_CYCLES[op]; - else - std::cout << "failed to execute opcode: " << std::hex << +op << std::endl; + // Unofficial opcodes - silently ignore (they're common in NES games) } + } // namespace NES diff --git a/nes_py/nes/src/emulator.cpp b/nes_py/nes/src/emulator.cpp index d9655b01..74702680 100644 --- a/nes_py/nes/src/emulator.cpp +++ b/nes_py/nes/src/emulator.cpp @@ -9,15 +9,19 @@ #include "mapper_factory.hpp" #include "log.hpp" +#include + namespace NES { -Emulator::Emulator(std::string rom_path) { +void Core::initialize(Controller* const controllers) +{ // set the read callbacks bus.set_read_callback(PPUSTATUS, [&](void) { return ppu.get_status(); }); bus.set_read_callback(PPUDATA, [&](void) { return ppu.get_data(picture_bus); }); bus.set_read_callback(JOY1, [&](void) { return controllers[0].read(); }); bus.set_read_callback(JOY2, [&](void) { return controllers[1].read(); }); bus.set_read_callback(OAMDATA, [&](void) { return ppu.get_OAM_data(); }); + // set the write callbacks bus.set_write_callback(PPUCTRL, [&](NES_Byte b) { ppu.control(b); }); bus.set_write_callback(PPUMASK, [&](NES_Byte b) { ppu.set_mask(b); }); @@ -28,25 +32,82 @@ Emulator::Emulator(std::string rom_path) { bus.set_write_callback(OAMDMA, [&](NES_Byte b) { cpu.skip_DMA_cycles(); ppu.do_DMA(bus.get_page_pointer(b)); }); bus.set_write_callback(JOY1, [&](NES_Byte b) { controllers[0].strobe(b); controllers[1].strobe(b); }); bus.set_write_callback(OAMDATA, [&](NES_Byte b) { ppu.set_OAM_data(b); }); + // set the interrupt callback for the PPU ppu.set_interrupt_callback([&]() { cpu.interrupt(bus, CPU::NMI_INTERRUPT); }); +} + +void Core::reset() { + cpu.reset(bus); + ppu.reset(); +} + +void Core::set_mapper(Mapper *mapper) { + bus.set_mapper(mapper); + picture_bus.set_mapper(mapper); +} + +void Core::ppu_step(NESFrameBufferT* const framebuffer) { + // 3 PPU steps per CPU step + ppu.cycle(picture_bus, framebuffer); + ppu.cycle(picture_bus, framebuffer); + ppu.cycle(picture_bus, framebuffer); +} + +void Core::step(NESFrameBufferT* const framebuffer) { + // 3 PPU steps per CPU step + ppu_step(framebuffer); + cpu.cycle(bus); +} + +Emulator::Emulator(std::string rom_path) +{ + // set the read callbacks + core.bus.set_read_callback(PPUSTATUS, [&](void) { return core.ppu.get_status(); }); + core.bus.set_read_callback(PPUDATA, [&](void) { return core.ppu.get_data(core.picture_bus); }); + core.bus.set_read_callback(JOY1, [&](void) { return controllers[0].read(); }); + core.bus.set_read_callback(JOY2, [&](void) { return controllers[1].read(); }); + core.bus.set_read_callback(OAMDATA, [&](void) { return core.ppu.get_OAM_data(); }); + + // set the write callbacks + core.bus.set_write_callback(PPUCTRL, [&](NES_Byte b) { core.ppu.control(b); }); + core.bus.set_write_callback(PPUMASK, [&](NES_Byte b) { core.ppu.set_mask(b); }); + core.bus.set_write_callback(OAMADDR, [&](NES_Byte b) { core.ppu.set_OAM_address(b); }); + core.bus.set_write_callback(PPUADDR, [&](NES_Byte b) { core.ppu.set_data_address(b); }); + core.bus.set_write_callback(PPUSCROL, [&](NES_Byte b) { core.ppu.set_scroll(b); }); + core.bus.set_write_callback(PPUDATA, [&](NES_Byte b) { core.ppu.set_data(core.picture_bus, b); }); + core.bus.set_write_callback(OAMDMA, [&](NES_Byte b) { core.cpu.skip_DMA_cycles(); core.ppu.do_DMA(core.bus.get_page_pointer(b)); }); + core.bus.set_write_callback(JOY1, [&](NES_Byte b) { controllers[0].strobe(b); controllers[1].strobe(b); }); + core.bus.set_write_callback(OAMDATA, [&](NES_Byte b) { core.ppu.set_OAM_data(b); }); + + // set the interrupt callback for the PPU + core.ppu.set_interrupt_callback([&]() { core.cpu.interrupt(core.bus, CPU::NMI_INTERRUPT); }); + + // initialize the framebuffer to all black + std::memset(&framebuffer, 0, sizeof(framebuffer)); + // load the ROM from disk, expect that the Python code has validated it cartridge.loadFromFile(rom_path); + // create the mapper based on the mapper ID in the iNES header of the ROM - auto mapper = MapperFactory(&cartridge, [&](){ picture_bus.update_mirroring(); }); + auto mapper = MapperFactory(&cartridge, [&](){ core.picture_bus.update_mirroring(); }); + // give the IO buses a pointer to the mapper - bus.set_mapper(mapper); - picture_bus.set_mapper(mapper); + core.set_mapper(mapper); } + void Emulator::step() { // render a single frame on the emulator for (int i = 0; i < CYCLES_PER_FRAME; i++) { - // 3 PPU steps per CPU step - ppu.cycle(picture_bus); - ppu.cycle(picture_bus); - ppu.cycle(picture_bus); - cpu.cycle(bus); + core.step(&framebuffer); + } +} + +void Emulator::ppu_step() { + // render a single frame on the emulator + for (int i = 0; i < CYCLES_PER_FRAME; i++) { + core.ppu_step(&framebuffer); } } diff --git a/nes_py/nes/src/lib_nes_env.cpp b/nes_py/nes/src/lib_nes_env.cpp index e2d9dd06..addcd753 100644 --- a/nes_py/nes/src/lib_nes_env.cpp +++ b/nes_py/nes/src/lib_nes_env.cpp @@ -2,86 +2,742 @@ // File: lib_nes_env.cpp // Description: file describes the outward facing ctypes API for Python // +// CHANGELOG: - 2024-12-28: Changed from ctypes to pybind11 - Ali Mosavian +// - 2026-01-27: Added VectorEmulator for parallel stepping +// // Copyright (c) 2019 Christian Kauten. All rights reserved. // - -#include #include "common.hpp" #include "emulator.hpp" -// Windows-base systems -#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__) - // setup the module initializer. required to link visual studio C++ ctypes - void PyInit_lib_nes_env() { } - // setup the function modifier to export in the DLL - #define EXP __declspec(dllexport) -// Unix-like systems -#else - // setup the modifier as a dummy - #define EXP +#include +#include +#include +#include +#include +#include +#include + +// Platform-specific headers for CPU affinity +#ifdef __linux__ +#include +#include +#elif defined(__APPLE__) +#include +#include +#include +#elif defined(_WIN32) +#include #endif -// definitions of functions for the Python interface to access -extern "C" { - /// Return the width of the NES. - EXP int Width() { - return NES::Emulator::WIDTH; - } +#include +#include +#include +#include - /// Return the height of the NES. - EXP int Height() { - return NES::Emulator::HEIGHT; - } +namespace py = pybind11; - /// Initialize a new emulator and return a pointer to it - EXP NES::Emulator* Initialize(wchar_t* path) { - // convert the c string to a c++ std string data structure - std::wstring ws_rom_path(path); - std::string rom_path(ws_rom_path.begin(), ws_rom_path.end()); - // create a new emulator with the given ROM path - return new NES::Emulator(rom_path); - } +// ============================================================================= +// RAM Read Specification - for batch reading after step +// ============================================================================= - /// Return a pointer to a controller on the machine - EXP NES::NES_Byte* Controller(NES::Emulator* emu, int port) { - return emu->get_controller(port); - } +enum class RamReadType { + INT = 0, // Single byte as integer + BCD = 1 // Multiple bytes as BCD (Binary Coded Decimal) +}; - /// Return the pointer to the screen buffer - EXP NES::NES_Pixel* Screen(NES::Emulator* emu) { - return emu->get_screen_buffer(); - } +struct RamReadSpec { + uint16_t address; + uint8_t size; // 1-6 bytes + RamReadType type; + + RamReadSpec(uint16_t addr, uint8_t sz, RamReadType t) + : address(addr), size(sz), type(t) {} +}; - /// Return the pointer to the memory buffer - EXP NES::NES_Byte* Memory(NES::Emulator* emu) { - return emu->get_memory_buffer(); - } +// ============================================================================= +// CPU Affinity - Pin threads to specific cores for better cache locality +// ============================================================================= - /// Reset the emulator - EXP void Reset(NES::Emulator* emu) { - emu->reset(); - } +/// Pin the current thread to a specific CPU core. +/// On Linux: Uses pthread_setaffinity_np for hard affinity. +/// On macOS: Uses thread_policy_set (hint only, not guaranteed). +/// On Windows: Uses SetThreadAffinityMask. +inline void pin_thread_to_core(int core_id) { + int num_cores = std::thread::hardware_concurrency(); + if (num_cores == 0 || core_id < 0) return; + + // Wrap around if core_id exceeds available cores + core_id = core_id % num_cores; + +#ifdef __linux__ + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); + pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); +#elif defined(__APPLE__) + // macOS: thread affinity is only a hint, not enforced + thread_affinity_policy_data_t policy = { core_id }; + thread_policy_set(pthread_mach_thread_np(pthread_self()), + THREAD_AFFINITY_POLICY, + (thread_policy_t)&policy, + THREAD_AFFINITY_POLICY_COUNT); +#elif defined(_WIN32) + SetThreadAffinityMask(GetCurrentThread(), 1ULL << core_id); +#endif +} - /// Perform a discrete step in the emulator (i.e., 1 frame) - EXP void Step(NES::Emulator* emu) { - emu->step(); - } +// ============================================================================= +// VectorEmulator - Parallel NES emulation with zero-copy observations +// ============================================================================= +// +// Mirrors the NESEmulator interface but for multiple environments: +// - step(actions) - Step all emulators 1 frame +// - step_frames(actions, n) - Step all emulators n frames +// - screen_buffer() - Get list of zero-copy screen views +// - screen_buffer(idx) - Get single screen view +// - memory_buffer() - Get list of RAM views +// - memory_buffer(idx) - Get single RAM view +// +// Thread model: Persistent worker threads (one per emulator), condition +// variable synchronization. GIL is released during parallel stepping. +// +// ============================================================================= - /// Create a deep copy (i.e., a clone) of the given emulator - EXP void Backup(NES::Emulator* emu) { - emu->backup(); - } +class VectorEmulator { +private: + // Worker states - each worker only touches its own state (cache-line isolated) + static constexpr int STATE_IDLE = 0; + static constexpr int STATE_PENDING = 1; + static constexpr int STATE_DONE = 2; + static constexpr int STATE_EXIT = 3; + + // Cache line size for padding (typically 64 bytes on x86/ARM) + static constexpr size_t CACHE_LINE_SIZE = 64; + + // Cache-line aligned atomic to prevent false sharing + struct alignas(CACHE_LINE_SIZE) AlignedAtomic { + std::atomic state{STATE_IDLE}; + char padding[CACHE_LINE_SIZE - sizeof(std::atomic)]; + + AlignedAtomic() : state(STATE_IDLE) {} + }; - /// Create a deep copy (i.e., a clone) of the given emulator - EXP void Restore(NES::Emulator* emu) { - emu->restore(); +public: + VectorEmulator(const std::string& rom_path, int num_envs) + : num_envs_(num_envs), rom_path_(rom_path), ready_count_(0) { + + emulators_.reserve(num_envs); + worker_states_.reserve(num_envs); + worker_frames_.resize(num_envs, 1); + worker_timings_.resize(num_envs); + + for (int i = 0; i < num_envs; i++) { + emulators_.push_back(std::make_unique(rom_path)); + worker_states_.push_back(std::make_unique()); + } + + workers_.reserve(num_envs); + for (int i = 0; i < num_envs; i++) { + workers_.emplace_back(&VectorEmulator::worker_loop, this, i); + } + + // Wait for all workers to be ready (busy-wait on atomic counter) + while (ready_count_.load(std::memory_order_acquire) < num_envs) { + std::this_thread::yield(); + } } - - /// Close the emulator, i.e., purge it from memory - EXP void Close(NES::Emulator* emu) { - delete emu; + + ~VectorEmulator() { + // Signal all workers to exit + for (int i = 0; i < num_envs_; i++) { + worker_states_[i]->state.store(STATE_EXIT, std::memory_order_release); + } + + for (auto& w : workers_) { + if (w.joinable()) { + w.join(); + } + } } -} + + int num_envs() const { return num_envs_; } + + // Reset all emulators + void reset() { + for (auto& emu : emulators_) { + emu->reset(); + } + } + + // Reset single emulator + void reset_env(int idx) { + check_idx(idx); + while (worker_states_[idx]->state.load(std::memory_order_acquire) != STATE_IDLE) { + std::this_thread::yield(); + } + emulators_[idx]->reset(); + } + + // Step all emulators 1 frame in parallel (like NESEmulator.step()) + void step(py::array_t actions) { + step_impl(actions, 1); + } + + // Step a single emulator (synchronous, no threading) + void step_single(int idx, uint8_t action) { + check_idx(idx); + // Wait for any pending work on this emulator + while (worker_states_[idx]->state.load(std::memory_order_acquire) != STATE_IDLE) { + std::this_thread::yield(); + } + // Set action and step directly (no threading) + *emulators_[idx]->get_controller(0) = action; + emulators_[idx]->step(); + } + + // Get screen buffer for all emulators as list of zero-copy views + py::list screen_buffer() { + py::list result; + for (int i = 0; i < num_envs_; i++) { + result.append(screen_buffer_single(i)); + } + return result; + } + + // Get screen buffer for single emulator as zero-copy view + py::array_t screen_buffer_single(int idx) { + check_idx(idx); + const int HEIGHT = NES::Emulator::HEIGHT; + const int WIDTH = NES::Emulator::WIDTH; + + #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + return py::array_t( + {HEIGHT, WIDTH, 3}, + {WIDTH * 4, 4, -1}, // negative stride: BGR -> RGB + reinterpret_cast(emulators_[idx]->get_screen_buffer()) + 2, + py::capsule(emulators_[idx].get(), [](void*) {}) + ); + #else + return py::array_t( + {HEIGHT, WIDTH, 3}, + {WIDTH * 4, 4, 1}, + reinterpret_cast(emulators_[idx]->get_screen_buffer()) + 1, + py::capsule(emulators_[idx].get(), [](void*) {}) + ); + #endif + } + + // Get memory buffer for all emulators as list of zero-copy views + py::list memory_buffer() { + py::list result; + for (int i = 0; i < num_envs_; i++) { + result.append(memory_buffer_single(i)); + } + return result; + } + + // Get memory buffer for single emulator as zero-copy view + py::array_t memory_buffer_single(int idx) { + check_idx(idx); + return py::array_t( + {0x800}, + {1}, + reinterpret_cast(emulators_[idx]->get_memory_buffer()), + py::capsule(emulators_[idx].get(), [](void*) {}) + ); + } + + // Get controller buffer for single emulator as zero-copy view + py::array_t controller(int idx, int port = 0) { + check_idx(idx); + return py::array_t( + {1}, + {1}, + reinterpret_cast(emulators_[idx]->get_controller(port)), + py::capsule(emulators_[idx].get(), [](void*) {}) + ); + } + + // Dump state for single emulator + py::array_t dump_state(int idx) { + check_idx(idx); + + // Create a copy of the state data + auto* core = new NES::Core; + memset(core, 0, sizeof(NES::Core)); + emulators_[idx]->snapshot(core); + + // Create capsule to own the memory + py::capsule capsule(core, [](void* p) { + delete static_cast(p); + }); + + return py::array_t( + {sizeof(NES::Core)}, + {1}, + reinterpret_cast(core), + capsule + ); + } + + // Load state for single emulator + void load_state(int idx, const py::array_t& state) { + check_idx(idx); + // Wait for worker to be idle before modifying emulator state + while (worker_states_[idx]->state.load(std::memory_order_acquire) != STATE_IDLE) { + std::this_thread::yield(); + } + emulators_[idx]->restore(reinterpret_cast(state.request().ptr)); + emulators_[idx]->ppu_step(); + } + +private: + void check_idx(int idx) const { + if (idx < 0 || idx >= num_envs_) { + throw std::out_of_range( + "Environment index " + std::to_string(idx) + + " out of range [0, " + std::to_string(num_envs_) + ")" + ); + } + } + + void step_impl(py::array_t actions, int num_frames) { + using clock = std::chrono::high_resolution_clock; + auto t0 = clock::now(); + + auto actions_buf = actions.request(); + uint8_t* actions_ptr = static_cast(actions_buf.ptr); + + // Set actions and mark workers as pending (lock-free) + for (int i = 0; i < num_envs_; i++) { + *emulators_[i]->get_controller(0) = actions_ptr[i]; + worker_frames_[i] = num_frames; + worker_states_[i]->state.store(STATE_PENDING, std::memory_order_release); + } + + auto t1 = clock::now(); + + // Busy-wait for all workers to complete (GIL released, lock-free) + { + py::gil_scoped_release release; + + // Spin until all workers are done + while (true) { + bool all_done = true; + for (int i = 0; i < num_envs_; i++) { + if (worker_states_[i]->state.load(std::memory_order_acquire) != STATE_DONE) { + all_done = false; + break; + } + } + if (all_done) break; + + // Brief pause to reduce CPU spinning overhead + std::this_thread::yield(); + } + } + + auto t2 = clock::now(); + + // Mark workers idle (ready for next step) + for (int i = 0; i < num_envs_; i++) { + worker_states_[i]->state.store(STATE_IDLE, std::memory_order_release); + } + + auto t3 = clock::now(); + + // Read configured RAM values (batch read in C++) + read_ram_values(); + + auto t4 = clock::now(); + + // Accumulate timing stats + timing_setup_ns_ += std::chrono::duration_cast(t1 - t0).count(); + timing_wait_ns_ += std::chrono::duration_cast(t2 - t1).count(); + timing_idle_ns_ += std::chrono::duration_cast(t3 - t2).count(); + timing_ram_ns_ += std::chrono::duration_cast(t4 - t3).count(); + timing_calls_++; + } + + void worker_loop(int idx) { + using clock = std::chrono::high_resolution_clock; + + // Pin this worker thread to a core in round-robin fashion + // Worker 0 -> core 0, Worker 1 -> core 1, etc. (wraps around) + int num_cores = std::thread::hardware_concurrency(); + int target_core = (num_cores > 0) ? (idx % num_cores) : -1; + pin_thread_to_core(target_core); + worker_timings_[idx].pinned_core = target_core; + + // Signal that this worker is ready + ready_count_.fetch_add(1, std::memory_order_release); + + while (true) { + auto t0 = clock::now(); + + // Busy-wait for work (lock-free - each worker only checks its own state) + int state; + while (true) { + state = worker_states_[idx]->state.load(std::memory_order_acquire); + if (state == STATE_PENDING || state == STATE_EXIT) break; + std::this_thread::yield(); + } + + auto t1 = clock::now(); + + if (state == STATE_EXIT) { + return; + } + + // Step emulator (the actual work) + for (int f = 0; f < worker_frames_[idx]; f++) { + emulators_[idx]->step(); + } + + auto t2 = clock::now(); + + // Signal completion (lock-free - just update our own state) + worker_states_[idx]->state.store(STATE_DONE, std::memory_order_release); + + // Record timing + worker_timings_[idx].wait_ns += std::chrono::duration_cast(t1 - t0).count(); + worker_timings_[idx].step_ns += std::chrono::duration_cast(t2 - t1).count(); + worker_timings_[idx].calls++; + } + } + + int num_envs_; + std::string rom_path_; + std::vector> emulators_; + std::vector workers_; + std::vector> worker_states_; // Cache-line aligned to prevent false sharing + std::vector worker_frames_; + + // Worker ready synchronization (only used during startup) + std::atomic ready_count_; + + // RAM read configuration (set once, used every step) + std::vector ram_specs_; + std::vector ram_values_; // Shape: [num_envs * num_specs], row-major + int num_ram_specs_ = 0; + + // Timing instrumentation (nanoseconds) - main thread + uint64_t timing_setup_ns_ = 0; // Set actions + signal workers + uint64_t timing_wait_ns_ = 0; // Wait for workers to complete + uint64_t timing_idle_ns_ = 0; // Mark workers idle + uint64_t timing_ram_ns_ = 0; // Read RAM values + uint64_t timing_calls_ = 0; // Number of step calls + + // Per-worker timing (cache-line aligned to prevent false sharing) + struct alignas(64) WorkerTiming { + uint64_t wait_ns = 0; // Time waiting for work + uint64_t step_ns = 0; // Time doing emulator step + uint64_t calls = 0; // Number of step calls + int pinned_core = -1; // Core this worker is pinned to + char padding[64 - 32]; // Pad to cache line + }; + std::vector worker_timings_; + + // Read BCD value from RAM (e.g., score stored as 6 separate digits) + inline int32_t read_bcd(const uint8_t* ram, uint16_t addr, int size) const { + int32_t result = 0; + for (int i = 0; i < size; i++) { + result = result * 10 + ram[addr + i]; + } + return result; + } + + // Read RAM values for all emulators after step (called from main thread) + void read_ram_values() { + if (num_ram_specs_ == 0) return; + + for (int env = 0; env < num_envs_; env++) { + const uint8_t* ram = reinterpret_cast( + emulators_[env]->get_memory_buffer()); + int base = env * num_ram_specs_; + + for (int s = 0; s < num_ram_specs_; s++) { + const auto& spec = ram_specs_[s]; + if (spec.type == RamReadType::BCD) { + ram_values_[base + s] = read_bcd(ram, spec.address, spec.size); + } else { + ram_values_[base + s] = ram[spec.address]; + } + } + } + } + +public: + // Configure RAM addresses to read after each step + // specs: list of (address, size, type) where type is 0=INT, 1=BCD + void configure_ram_reads(const std::vector>& specs) { + ram_specs_.clear(); + ram_specs_.reserve(specs.size()); + + for (const auto& spec : specs) { + uint16_t addr = std::get<0>(spec); + uint8_t size = std::get<1>(spec); + int type = std::get<2>(spec); + ram_specs_.emplace_back(addr, size, + type == 1 ? RamReadType::BCD : RamReadType::INT); + } + + num_ram_specs_ = static_cast(ram_specs_.size()); + ram_values_.resize(num_envs_ * num_ram_specs_); + } + + // Get RAM values as numpy array, shape: (num_envs, num_specs) + py::array_t ram_values() const { + if (num_ram_specs_ == 0) { + // Return empty array with explicit shape + std::vector shape = {static_cast(num_envs_), 0}; + return py::array_t(shape); + } + + return py::array_t( + {num_envs_, num_ram_specs_}, + {num_ram_specs_ * static_cast(sizeof(int32_t)), static_cast(sizeof(int32_t))}, + ram_values_.data(), + py::capsule(ram_values_.data(), [](void*) {}) + ); + } + + // Get timing stats as dict and reset counters + py::dict get_timing_stats() { + py::dict stats; + if (timing_calls_ > 0) { + stats["calls"] = timing_calls_; + stats["setup_ms"] = timing_setup_ns_ / 1e6; + stats["wait_ms"] = timing_wait_ns_ / 1e6; + stats["idle_ms"] = timing_idle_ns_ / 1e6; + stats["ram_ms"] = timing_ram_ns_ / 1e6; + stats["total_ms"] = (timing_setup_ns_ + timing_wait_ns_ + timing_idle_ns_ + timing_ram_ns_) / 1e6; + } + // Reset main thread timing + timing_setup_ns_ = timing_wait_ns_ = timing_idle_ns_ = timing_ram_ns_ = timing_calls_ = 0; + return stats; + } + + // Get per-worker timing stats and reset counters + py::list get_worker_timing_stats() { + py::list workers; + for (int i = 0; i < num_envs_; i++) { + py::dict w; + w["worker"] = i; + w["core"] = worker_timings_[i].pinned_core; + w["calls"] = worker_timings_[i].calls; + w["wait_ms"] = worker_timings_[i].wait_ns / 1e6; + w["step_ms"] = worker_timings_[i].step_ns / 1e6; + if (worker_timings_[i].calls > 0) { + w["step_avg_us"] = (worker_timings_[i].step_ns / worker_timings_[i].calls) / 1e3; + } else { + w["step_avg_us"] = 0.0; + } + workers.append(w); + + // Reset + worker_timings_[i].wait_ns = 0; + worker_timings_[i].step_ns = 0; + worker_timings_[i].calls = 0; + } + return workers; + } +}; + +PYBIND11_MODULE(emulator, m) { + py::class_(m, "NESEmulator") + .def(py::init()) + + .def_property_readonly_static("width", [](py::object) { return NES::Emulator::WIDTH; }) + .def_property_readonly_static("height", [](py::object) { return NES::Emulator::HEIGHT; }) + + .def("reset", &NES::Emulator::reset, "Reset the emulator") + .def("step", &NES::Emulator::step, py::call_guard(), "Perform a step on the emulator (GIL released)") + + .def( + "screen_buffer", + [](NES::Emulator& emu) -> py::array_t { + const int HEIGHT = NES::Emulator::HEIGHT; + const int WIDTH = NES::Emulator::WIDTH; + + #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + // On little-endian systems: BGRx -> RGB + return py::array_t( + {HEIGHT, WIDTH, 3}, // shape (3 channels) + {WIDTH * 4, 4, -1}, // negative stride to reverse BGR->RGB + reinterpret_cast(emu.get_screen_buffer()) + 2, // start at B + py::capsule(emu.get_screen_buffer(), [](void*) {}) // capsule with data pointer + ); + #else + // On big-endian systems: xRGB -> RGB + return py::array_t( + {HEIGHT, WIDTH, 3}, // shape (3 channels) + {WIDTH * 4, 4, 1}, // normal stride + reinterpret_cast(emu.get_screen_buffer()) + 1, // skip x + py::capsule(emu.get_screen_buffer(), [](void*) {}) // capsule with data pointer + ); + #endif + }, + "Get the screen buffer as a HEIGHT x WIDTH x 3 numpy.ndarray in RGB format" + ) + + .def( + "controller", + [](NES::Emulator& emu, int port) -> py::array_t { + // Create a view of the controller buffer + return py::array_t( + {1}, // shape (1 controller) + {1}, // stride (1 byte per controller) + reinterpret_cast(emu.get_controller(port)), // pointer to data + py::capsule(emu.get_controller(port), [](void*) {}) // capsule with data pointer + ); + }, + py::arg("port"), + "Get the controller buffer as numpy.ndarray" + ) + + .def( + "memory_buffer", + [](NES::Emulator& emu) -> py::array_t { + // Create a view of the RAM buffer (0x800 bytes) + return py::array_t( + {0x800}, // shape (2048 bytes) + {1}, // stride (1 byte) + reinterpret_cast(emu.get_memory_buffer()), // pointer to data + py::capsule(emu.get_memory_buffer(), [](void*) {}) // capsule with data pointer + ); + }, + "Get the memory buffer as numpy.ndarray" + ) + + .def( + "dump_state", + [](NES::Emulator& emu) -> py::array_t { + // Create a copy of the state data + auto* core = new NES::Core; + memset(core, 0, sizeof(NES::Core)); + emu.snapshot(core); + + return py::array_t( + {sizeof(NES::Core)}, + {1}, + reinterpret_cast(core), + py::capsule(core, [](void* p) { delete static_cast(p); }) + ); + }, + "Dump the current state to bytes" + ) + + .def( + "load_state", + [](NES::Emulator& emu, const py::array_t& state) { + emu.restore(reinterpret_cast(state.request().ptr)); + emu.ppu_step(); + }, + py::arg("state"), + "Load state from bytes" + ) + ; + + // VectorEmulator - parallel NES emulation mirroring NESEmulator interface + py::class_(m, "VectorEmulator") + .def(py::init(), + py::arg("rom_path"), py::arg("num_envs"), + "Create multiple emulators for parallel stepping") + + .def_property_readonly("num_envs", &VectorEmulator::num_envs) + .def_property_readonly_static("height", [](py::object) { return NES::Emulator::HEIGHT; }) + .def_property_readonly_static("width", [](py::object) { return NES::Emulator::WIDTH; }) + + .def("reset", &VectorEmulator::reset, "Reset all emulators") + .def("reset_env", &VectorEmulator::reset_env, py::arg("idx"), "Reset a single emulator") + + .def("step", &VectorEmulator::step, + py::arg("actions"), + "Step all emulators 1 frame in parallel") + + .def("step_single", &VectorEmulator::step_single, + py::arg("idx"), py::arg("action"), + "Step a single emulator (synchronous, no threading)") + + .def("screen_buffer", + py::overload_cast<>(&VectorEmulator::screen_buffer), + "Get screen buffers for all emulators as list of zero-copy views") + + .def("screen_buffer", + py::overload_cast(&VectorEmulator::screen_buffer_single), + py::arg("idx"), + "Get screen buffer for single emulator as zero-copy view") + + .def("memory_buffer", + py::overload_cast<>(&VectorEmulator::memory_buffer), + "Get memory buffers for all emulators as list of zero-copy views") + + .def("memory_buffer", + py::overload_cast(&VectorEmulator::memory_buffer_single), + py::arg("idx"), + "Get memory buffer for single emulator as zero-copy view") + + .def("controller", &VectorEmulator::controller, + py::arg("idx"), py::arg("port") = 0, + "Get controller buffer for single emulator as zero-copy view") + + .def("dump_state", &VectorEmulator::dump_state, + py::arg("idx"), + "Dump state for single emulator") + + .def("load_state", &VectorEmulator::load_state, + py::arg("idx"), py::arg("state"), + "Load state for single emulator") + + .def("configure_ram_reads", &VectorEmulator::configure_ram_reads, + py::arg("specs"), + R"doc( +Configure RAM addresses to read after each step. + +Args: + specs: List of (address, size, type) tuples where: + - address: RAM address (0x0000-0x07FF) + - size: Number of bytes to read (1-6) + - type: 0=INT (single byte), 1=BCD (multiple digits) + +Example: + emulator.configure_ram_reads([ + (0x07DE, 6, 1), # score (6 BCD digits) + (0x07F8, 3, 1), # time (3 BCD digits) + (0x07ED, 2, 1), # coins (2 BCD digits) + (0x075A, 1, 0), # life (1 byte int) + ]) +)doc") + + .def("ram_values", &VectorEmulator::ram_values, + "Get RAM values from last step as numpy array, shape: (num_envs, num_specs)") + + .def("get_timing_stats", &VectorEmulator::get_timing_stats, + R"doc( +Get timing stats for step() breakdown and reset counters. + +Returns dict with: + - calls: Number of step() calls since last reset + - setup_ms: Time to set actions + signal workers (ms) + - wait_ms: Time waiting for workers to complete (ms) + - idle_ms: Time to mark workers idle (ms) + - ram_ms: Time to read RAM values (ms) + - total_ms: Total C++ time in step() (ms) +)doc") + + .def("get_worker_timing_stats", &VectorEmulator::get_worker_timing_stats, + R"doc( +Get per-worker timing stats and reset counters. -// un-define the macro -#undef EXP +Returns list of dicts, one per worker: + - worker: Worker index + - core: CPU core this worker is pinned to + - calls: Number of step calls + - wait_ms: Time spent waiting for work (ms) + - step_ms: Time spent stepping emulator (ms) + - step_avg_us: Average step time per call (microseconds) +)doc") + ; +}; diff --git a/nes_py/nes/src/ppu.cpp b/nes_py/nes/src/ppu.cpp index d5a96ed9..18772b8f 100644 --- a/nes_py/nes/src/ppu.cpp +++ b/nes_py/nes/src/ppu.cpp @@ -34,7 +34,7 @@ void PPU::reset() { scanline_sprites.resize(0); } -void PPU::cycle(PictureBus& bus) { +void PPU::cycle(PictureBus& bus, NESFrameBufferT* const screen) { switch (pipeline_state) { case PRE_RENDER: { if (cycles == 1) @@ -175,7 +175,7 @@ void PPU::cycle(PictureBus& bus) { else if (!bgOpaque && !sprOpaque) paletteAddr = 0; // lookup the pixel in the palette and write it to the screen - screen[y][x] = PALETTE[bus.read_palette(paletteAddr)]; + (*screen)[y][x] = PALETTE[bus.read_palette(paletteAddr)]; } else if (cycles == SCANLINE_VISIBLE_DOTS + 1 && is_showing_background) { //Shamelessly copied from nesdev wiki diff --git a/nes_py/nes_env.py b/nes_py/nes_env.py index 35333f5a..2fba63d9 100644 --- a/nes_py/nes_env.py +++ b/nes_py/nes_env.py @@ -1,200 +1,134 @@ """A CTypes interface to the C++ NES environment.""" -import ctypes -import glob import itertools -import os -import sys -import gym -from gym.spaces import Box -from gym.spaces import Discrete +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union +from typing import ClassVar +from typing import Optional +from typing import NamedTuple +from typing import SupportsFloat +from dataclasses import dataclass + import numpy as np -from ._rom import ROM -from ._image_viewer import ImageViewer - - -# the path to the directory this file is in -_MODULE_PATH = os.path.dirname(__file__) -# the pattern to find the C++ shared object library -_SO_PATH = 'lib_nes_env*' -# the absolute path to the C++ shared object library -_LIB_PATH = os.path.join(_MODULE_PATH, _SO_PATH) -# load the library from the shared object file -try: - _LIB = ctypes.cdll.LoadLibrary(glob.glob(_LIB_PATH)[0]) -except IndexError: - raise OSError('missing static lib_nes_env*.so library!') - - -# setup the argument and return types for Width -_LIB.Width.argtypes = None -_LIB.Width.restype = ctypes.c_uint -# setup the argument and return types for Height -_LIB.Height.argtypes = None -_LIB.Height.restype = ctypes.c_uint -# setup the argument and return types for Initialize -_LIB.Initialize.argtypes = [ctypes.c_wchar_p] -_LIB.Initialize.restype = ctypes.c_void_p -# setup the argument and return types for Controller -_LIB.Controller.argtypes = [ctypes.c_void_p, ctypes.c_uint] -_LIB.Controller.restype = ctypes.c_void_p -# setup the argument and return types for Screen -_LIB.Screen.argtypes = [ctypes.c_void_p] -_LIB.Screen.restype = ctypes.c_void_p -# setup the argument and return types for GetMemoryBuffer -_LIB.Memory.argtypes = [ctypes.c_void_p] -_LIB.Memory.restype = ctypes.c_void_p -# setup the argument and return types for Reset -_LIB.Reset.argtypes = [ctypes.c_void_p] -_LIB.Reset.restype = None -# setup the argument and return types for Step -_LIB.Step.argtypes = [ctypes.c_void_p] -_LIB.Step.restype = None -# setup the argument and return types for Backup -_LIB.Backup.argtypes = [ctypes.c_void_p] -_LIB.Backup.restype = None -# setup the argument and return types for Restore -_LIB.Restore.argtypes = [ctypes.c_void_p] -_LIB.Restore.restype = None -# setup the argument and return types for Close -_LIB.Close.argtypes = [ctypes.c_void_p] -_LIB.Close.restype = None - - -# height in pixels of the NES screen -SCREEN_HEIGHT = _LIB.Height() -# width in pixels of the NES screen -SCREEN_WIDTH = _LIB.Width() -# shape of the screen as 24-bit RGB (standard for NumPy) -SCREEN_SHAPE_24_BIT = SCREEN_HEIGHT, SCREEN_WIDTH, 3 -# shape of the screen as 32-bit RGB (C++ memory arrangement) -SCREEN_SHAPE_32_BIT = SCREEN_HEIGHT, SCREEN_WIDTH, 4 -# create a type for the screen tensor matrix from C++ -SCREEN_TENSOR = ctypes.c_byte * int(np.prod(SCREEN_SHAPE_32_BIT)) - - -# create a type for the RAM vector from C++ -RAM_VECTOR = ctypes.c_byte * 0x800 - - -# create a type for the controller buffers from C++ -CONTROLLER_VECTOR = ctypes.c_byte * 1 - - -class NESEnv(gym.Env): - """An NES environment based on the LaiNES emulator.""" +import gymnasium as gym +import lz4.block as lz4 +from gymnasium.spaces import Box +from gymnasium.spaces import Discrete - # relevant meta-data about the environment - metadata = { - 'render.modes': ['rgb_array', 'human'], - 'video.frames_per_second': 60 - } +from nes_py._rom import ROM +from nes_py.emulator import NESEmulator +from nes_py._image_viewer import ImageViewer - # the legal range for rewards for this environment - reward_range = (-float('inf'), float('inf')) - # observation space for the environment is static across all instances - observation_space = Box( - low=0, - high=255, - shape=SCREEN_SHAPE_24_BIT, - dtype=np.uint8 - ) - # action space is a bitmap of button press values for the 8 NES buttons - action_space = Discrete(256) +class NESGameCallbacks: + def _will_reset(self): + """Handle any RAM hacking after a reset occurs.""" + pass - def __init__(self, rom_path): - """ - Create a new NES environment. + def _did_reset(self): + """Handle any RAM hacking after a reset occurs.""" + pass - Args: - rom_path (str): the path to the ROM for the environment + def _will_restore(self): + """Handle any RAM hacking after a restore occurs.""" + pass - Returns: - None + def _did_restore(self): + """Handle any RAM hacking after a restore occurs.""" + pass - """ - # create a ROM file from the ROM path - rom = ROM(rom_path) + def _will_step(self): + """Handle any RAM hacking after a step occurs.""" + pass + + def _did_step(self, done: bool): + """Handle any RAM hacking after a step occurs.""" + pass + + def _get_reward(self) -> float: + """Return the reward after a step occurs.""" + return 0 + + def _get_done(self) -> bool: + """Return True if the episode is over, False otherwise.""" + return False + + def _get_info(self) -> Dict[str, Any]: + """Return the info after a step occurs.""" + return {} + +@dataclass(init=False) +class NESEmulatorWrapper(NESGameCallbacks): + _rom_path: str + _emulator: NESEmulator + + height: int = NESEmulator.height + width: int = NESEmulator.width + + def __init__(self, rom_path: str): + NESEmulatorWrapper.check_rom_compatibility(ROM.from_path(rom_path)) + + self._rom_path = rom_path + self._emulator = NESEmulator(rom_path) + + @staticmethod + def check_rom_compatibility(rom: ROM): + """Check that the ROM is compatible with the NES environment.""" # check that there is PRG ROM if rom.prg_rom_size == 0: raise ValueError('ROM has no PRG-ROM banks.') + # ensure that there is no trainer if rom.has_trainer: raise ValueError('ROM has trainer. trainer is not supported.') + # try to read the PRG ROM and raise a value error if it fails _ = rom.prg_rom + # try to read the CHR ROM and raise a value error if it fails - _ = rom.chr_rom + _ = rom.chr_rom + # check the TV system if rom.is_pal: raise ValueError('ROM is PAL. PAL is not supported.') # check that the mapper is implemented elif rom.mapper not in {0, 1, 2, 3}: msg = 'ROM has an unsupported mapper number {}. please see https://github.com/Kautenja/nes-py/issues/28 for more information.' - raise ValueError(msg.format(rom.mapper)) - # create a dedicated random number generator for the environment - self.np_random = np.random.RandomState() - # store the ROM path - self._rom_path = rom_path - # initialize the C++ object for running the environment - self._env = _LIB.Initialize(self._rom_path) - # setup a placeholder for a 'human' render mode viewer - self.viewer = None - # setup a placeholder for a pointer to a backup state - self._has_backup = False - # setup a done flag - self.done = True - # setup the controllers, screen, and RAM buffers - self.controllers = [self._controller_buffer(port) for port in range(2)] - self.screen = self._screen_buffer() - self.ram = self._ram_buffer() - - def _screen_buffer(self): - """Setup the screen buffer from the C++ code.""" - # get the address of the screen - address = _LIB.Screen(self._env) - # create a buffer from the contents of the address location - buffer_ = ctypes.cast(address, ctypes.POINTER(SCREEN_TENSOR)).contents - # create a NumPy array from the buffer - screen = np.frombuffer(buffer_, dtype='uint8') - # reshape the screen from a column vector to a tensor - screen = screen.reshape(SCREEN_SHAPE_32_BIT) - # flip the bytes if the machine is little-endian (which it likely is) - if sys.byteorder == 'little': - # invert the little-endian BGRx channels to big-endian xRGB - screen = screen[:, :, ::-1] - # remove the 0th axis (padding from storing colors in 32 bit) - return screen[:, :, 1:] - - def _ram_buffer(self): - """Setup the RAM buffer from the C++ code.""" - # get the address of the RAM - address = _LIB.Memory(self._env) - # create a buffer from the contents of the address location - buffer_ = ctypes.cast(address, ctypes.POINTER(RAM_VECTOR)).contents - # create a NumPy array from the buffer - return np.frombuffer(buffer_, dtype='uint8') - - def _controller_buffer(self, port): - """ - Find the pointer to a controller and setup a NumPy buffer. - - Args: - port: the port of the controller to setup - - Returns: - a NumPy buffer with the controller's binary data - - """ - # get the address of the controller - address = _LIB.Controller(self._env, port) - # create a memory buffer using the ctypes pointer for this vector - buffer_ = ctypes.cast(address, ctypes.POINTER(CONTROLLER_VECTOR)).contents - # create a NumPy buffer from the binary data and return it - return np.frombuffer(buffer_, dtype='uint8') - - def _frame_advance(self, action): + raise ValueError(msg.format(rom.mapper)) + + + @property + def _screen_buffer(self) -> np.ndarray: + return self._emulator.screen_buffer() + + @property + def _memory_buffer(self) -> np.ndarray: + return self._emulator.memory_buffer() + + @property + def _controller_buffers(self) -> List[np.ndarray]: + return [self._emulator.controller(port) for port in range(2)] + + @property + def ram(self) -> np.ndarray: + return self._memory_buffer + + @property + def screen(self) -> bytes: + return self._screen_buffer + + def dump_state(self) -> np.ndarray: + return self._emulator.dump_state() + + def load_state(self, snapshot: np.ndarray): + self._will_restore() + self._emulator.load_state(snapshot) + self._did_restore() + + def frame_advance(self, action: Union[int, Tuple[int, int]]) -> None: """ Advance a frame in the emulator with an action. @@ -206,77 +140,105 @@ def _frame_advance(self, action): """ # set the action on the controller - self.controllers[0][:] = action + if isinstance(action, (int, np.integer)): + self._controller_buffers[0][:] = action + elif isinstance(action, tuple) and len(action) == 2: + self._controller_buffers[0][:] = action[0] + self._controller_buffers[1][:] = action[1] + else: + raise ValueError(f'Invalid action type or length: {type(action)}') + # perform a step on the emulator - _LIB.Step(self._env) + self._emulator.step() - def _backup(self): - """Backup the NES state in the emulator.""" - _LIB.Backup(self._env) - self._has_backup = True - def _restore(self): - """Restore the backup state into the NES emulator.""" - _LIB.Restore(self._env) - def _will_reset(self): - """Handle any RAM hacking after a reset occurs.""" - pass - def seed(self, seed=None): - """ - Set the seed for this environment's random number generator. - Returns: - list: Returns the list of seeds used in this env's random - number generators. The first value in the list should be the - "main" seed, or the value which a reproducer should pass to - 'seed'. Often, the main seed equals the provided 'seed', but - this won't be true if seed=None, for example. +class StepResult(NamedTuple): + observation: np.ndarray + reward: float + terminated: bool + truncated: bool + info: Dict[str, Any] - """ - # if there is no seed, return an empty list - if seed is None: - return [] - # set the random number seed for the NumPy random number generator - self.np_random.seed(seed) - # return the list of seeds used by RNG(s) in the environment - return [seed] - - def reset(self, seed=None, options=None, return_info=None): + +@dataclass(init=False) +class NESEnv(NESEmulatorWrapper, gym.Env[np.ndarray, int]): + """An NES environment based on the LaiNES emulator.""" + _done: bool + _viewer: Optional[ImageViewer] + _np_random: np.random.RandomState + _snapshot: Optional[np.ndarray] + + # relevant meta-data about the environment + metadata: ClassVar[Dict[str, Any]] = { + 'render_modes': ['rgb_array', 'human'], + 'render_fps': 60 + } + + # the legal range for rewards for this environment + reward_range: ClassVar[Tuple[float, float]] = (-float('inf'), float('inf')) + + # observation space for the environment is static across all instances + observation_space: ClassVar[Box] = Box( + low=0, + high=255, + shape=(NESEmulator.height, NESEmulator.width, 3), + dtype=np.uint8 + ) + + # action space is a bitmap of button press values for the 8 NES buttons + action_space: ClassVar[Discrete] = Discrete(256) + + def __init__(self, rom_path: str): + super().__init__(rom_path) + + self._viewer = None + self._done = True + self._snapshot = None + + + def reset( + self, + *, + seed: Union[int, None] = None, + options: Union[Dict[str, Any], None] = None, + ) -> Tuple[np.ndarray, Dict[str, Any]]: """ Reset the state of the environment and returns an initial observation. Args: seed (int): an optional random number seed for the next episode options (any): unused - return_info (any): unused Returns: state (np.ndarray): next frame as a result of the given action """ # Set the seed. - self.seed(seed) + super().reset(seed=seed) + # call the before reset callback self._will_reset() + # reset the emulator - if self._has_backup: + if self._snapshot is not None: self._restore() else: - _LIB.Reset(self._env) + self._emulator.reset() + # call the after reset callback self._did_reset() + # set the done flag to false - self.done = False - # return the screen from the emulator - return self.screen + self._done = False - def _did_reset(self): - """Handle any RAM hacking after a reset occurs.""" - pass + # return the _screen_buffer from the emulator + return self._screen_buffer, self._get_info() - def step(self, action): + + def step(self, action: int) -> Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]: """ Run one frame of the NES and return the relevant observation data. @@ -288,69 +250,54 @@ def step(self, action): - state (np.ndarray): next frame as a result of the given action - reward (float) : amount of reward returned after given action - done (boolean): whether the episode has ended + - truncated (boolean): whether the episode has been truncated - info (dict): contains auxiliary diagnostic information """ # if the environment is done, raise an error - if self.done: + if self._done: raise ValueError('cannot step in a done environment! call `reset`') - # set the action on the controller - self.controllers[0][:] = action - # pass the action to the emulator as an unsigned byte - _LIB.Step(self._env) + + self.frame_advance(action) + # get the reward for this step reward = float(self._get_reward()) + reward = min(max(reward, self.reward_range[0]), self.reward_range[1]) + # get the done flag for this step - self.done = bool(self._get_done()) + self._done = bool(self._get_done()) + # get the info for this step info = self._get_info() - # call the after step callback - self._did_step(self.done) - # bound the reward in [min, max] - if reward < self.reward_range[0]: - reward = self.reward_range[0] - elif reward > self.reward_range[1]: - reward = self.reward_range[1] - # return the screen from the emulator and other relevant data - return self.screen, reward, self.done, info - - def _get_reward(self): - """Return the reward after a step occurs.""" - return 0 - - def _get_done(self): - """Return True if the episode is over, False otherwise.""" - return False - - def _get_info(self): - """Return the info after a step occurs.""" - return {} - def _did_step(self, done): - """ - Handle any RAM hacking after a step occurs. + # call the after step callback + self._did_step(self._done) - Args: - done (bool): whether the done flag is set to true + # return the _screen_buffer from the emulator and other relevant data + return StepResult( + observation=self._screen_buffer, + reward=reward, + terminated=self._done, + truncated=False, + info=info + ) - Returns: - None - - """ - pass def close(self): """Close the environment.""" # make sure the environment hasn't already been closed - if self._env is None: + if self._emulator is None: raise ValueError('env has already been closed.') + # purge the environment from C++ memory - _LIB.Close(self._env) + del self._emulator + # deallocate the object locally - self._env = None + self._emulator = None + # if there is an image viewer open, delete it - if self.viewer is not None: - self.viewer.close() + if self._viewer is not None: + self._viewer.close() def render(self, mode='human'): """ @@ -368,7 +315,7 @@ def render(self, mode='human'): """ if mode == 'human': # if the viewer isn't setup, import it and create one - if self.viewer is None: + if self._viewer is None: # get the caption for the ImageViewer if self.spec is None: # if there is no spec, just use the .nes filename @@ -377,15 +324,15 @@ def render(self, mode='human'): # set the caption to the OpenAI Gym id caption = self.spec.id # create the ImageViewer to display frames - self.viewer = ImageViewer( + self._viewer = ImageViewer( caption=caption, - height=SCREEN_HEIGHT, - width=SCREEN_WIDTH, + height=self._emulator.height, + width=self._emulator.width, ) - # show the screen on the image viewer - self.viewer.show(self.screen) + # show the _screen_buffer on the image viewer + self._viewer.show(self._screen_buffer) elif mode == 'rgb_array': - return self.screen + return self._screen_buffer else: # unpack the modes as comma delineated strings ('a', 'b', ...) render_modes = [repr(x) for x in self.metadata['render.modes']] @@ -424,6 +371,15 @@ def get_keys_to_action(self): def get_action_meanings(self): """Return a list of actions meanings.""" return ['NOOP'] + + def _backup(self): + self._snapshot = self.dump_state() + + def _restore(self): + if self._snapshot is None: + raise ValueError('no snapshot to restore') + + self.load_state(self._snapshot) # explicitly define the outward facing API of this module diff --git a/nes_py/tests/test_multiple_makes.py b/nes_py/tests/test_multiple_makes.py index 8764e48d..a1429858 100755 --- a/nes_py/tests/test_multiple_makes.py +++ b/nes_py/tests/test_multiple_makes.py @@ -2,9 +2,10 @@ from multiprocessing import Process from threading import Thread from unittest import TestCase -from .rom_file_abs_path import rom_file_abs_path from nes_py.nes_env import NESEnv +from rom_file_abs_path import rom_file_abs_path + def play(steps): """ @@ -26,7 +27,9 @@ def play(steps): if done: _ = env.reset() action = env.action_space.sample() - _, _, done, _ = env.step(action) + _, _, terminated, truncated, _ = env.step(action) + done = terminated or truncated + # close the environment env.close() @@ -84,4 +87,5 @@ def test(self): if dones[idx]: _ = envs[idx].reset() action = envs[idx].action_space.sample() - _, _, dones[idx], _ = envs[idx].step(action) + _, _, terminated, truncated, _ = envs[idx].step(action) + dones[idx] = terminated or truncated diff --git a/nes_py/tests/test_nes_env.py b/nes_py/tests/test_nes_env.py index 3e24c6e9..dae55406 100644 --- a/nes_py/tests/test_nes_env.py +++ b/nes_py/tests/test_nes_env.py @@ -1,9 +1,11 @@ """Test cases for the NESEnv class.""" from unittest import TestCase -import gym + import numpy as np -from .rom_file_abs_path import rom_file_abs_path +import gymnasium as gym + from nes_py.nes_env import NESEnv +from rom_file_abs_path import rom_file_abs_path class ShouldRaiseTypeErrorOnInvalidROMPathType(TestCase): @@ -57,9 +59,9 @@ def test(self): for _ in range(90): env.step(8) env.step(0) - self.assertEqual(129, env.ram[0x0776]) - env.ram[0x0776] = 0 - self.assertEqual(0, env.ram[0x0776]) + self.assertEqual(129, env._memory_buffer[0x0776]) + env._memory_buffer[0x0776] = 0 + self.assertEqual(0, env._memory_buffer[0x0776]) env.close() @@ -79,21 +81,26 @@ def test(self): for _ in range(500): if done: # reset the environment and check the output value - state = env.reset() - self.assertIsInstance(state, np.ndarray) + observation, _ = env.reset() + self.assertIsInstance(observation, np.ndarray) + # sample a random action and check it action = env.action_space.sample() - self.assertIsInstance(action, int) + self.assertIsInstance(action, (int, np.integer)) + # take a step and check the outputs output = env.step(action) self.assertIsInstance(output, tuple) - self.assertEqual(4, len(output)) + self.assertEqual(5, len(output)) + # check each output - state, reward, done, info = output - self.assertIsInstance(state, np.ndarray) + observation, reward, terminated, truncated, info = output + self.assertIsInstance(observation, np.ndarray) self.assertIsInstance(reward, float) - self.assertIsInstance(done, bool) + self.assertIsInstance(terminated, bool) + self.assertIsInstance(truncated, bool) self.assertIsInstance(info, dict) + # check the render output render = env.render('rgb_array') self.assertIsInstance(render, np.ndarray) @@ -110,7 +117,7 @@ def test(self): if done: state = env.reset() done = False - state, _, done, _ = env.step(0) + state, _, done, _, _ = env.step(0) backup = state.copy() @@ -120,9 +127,10 @@ def test(self): if done: state = env.reset() done = False - state, _, done, _ = env.step(0) + _, _, terminated, truncated, _ = env.step(0) + done = terminated or truncated self.assertFalse(np.array_equal(backup, state)) env._restore() - self.assertTrue(np.array_equal(backup, env.screen)) + self.assertTrue(np.array_equal(backup, env._screen_buffer)) env.close() diff --git a/nes_py/tests/test_rom.py b/nes_py/tests/test_rom.py index 6015fca5..1ae81397 100644 --- a/nes_py/tests/test_rom.py +++ b/nes_py/tests/test_rom.py @@ -5,8 +5,9 @@ """ from unittest import TestCase -from .rom_file_abs_path import rom_file_abs_path + from nes_py._rom import ROM +from rom_file_abs_path import rom_file_abs_path class ShouldNotCreateInstanceOfROMWithoutPath(TestCase): @@ -16,14 +17,14 @@ def test(self): class ShouldNotCreateInstanceOfROMWithInvaldPath(TestCase): def test(self): - self.assertRaises(TypeError, lambda: ROM(5)) - self.assertRaises(ValueError, lambda: ROM('not a path')) + self.assertRaises(TypeError, lambda: ROM.from_path(5)) + self.assertRaises(ValueError, lambda: ROM.from_path('not a path')) class ShouldNotCreateInstanceOfROMWithInvaldROMFile(TestCase): def test(self): empty = rom_file_abs_path('empty.nes') - self.assertRaises(ValueError, lambda: ROM(empty)) + self.assertRaises(ValueError, lambda: ROM.from_path(empty)) # @@ -91,7 +92,7 @@ class ShouldReadROMHeaderTestCase(object): def setUp(self): """Perform setup before each test.""" rom_path = rom_file_abs_path(self.rom_name) - self.rom = ROM(rom_path) + self.rom = ROM.from_path(rom_path) def test_header_length(self): """Check the length of the header.""" diff --git a/nes_py/wrappers/__init__.py b/nes_py/wrappers/__init__.py index 896ed203..3a866ddb 100755 --- a/nes_py/wrappers/__init__.py +++ b/nes_py/wrappers/__init__.py @@ -1,5 +1,5 @@ """Wrappers for altering the functionality of the game.""" -from .joypad_space import JoypadSpace +from nes_py.wrappers.joypad_space import JoypadSpace # explicitly define the outward facing API of this package diff --git a/nes_py/wrappers/joypad_space.py b/nes_py/wrappers/joypad_space.py index 893ea6fd..44e62149 100644 --- a/nes_py/wrappers/joypad_space.py +++ b/nes_py/wrappers/joypad_space.py @@ -1,14 +1,24 @@ """An environment wrapper to convert binary to discrete action space.""" -import gym -from gym import Env -from gym import Wrapper - - -class JoypadSpace(Wrapper): +from typing import Any +from typing import Dict +from typing import ClassVar +from typing import Optional +from typing import SupportsFloat + +import gymnasium as gym +from gymnasium import Env +from gymnasium import Wrapper +from gymnasium.core import ObsType +from gymnasium.core import ActType +from gymnasium.core import WrapperObsType +from gymnasium.core import WrapperActType + + +class JoypadSpace(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]): """An environment wrapper to convert binary to discrete action space.""" # a mapping of buttons to binary values - _button_map = { + _button_map: ClassVar[Dict[str, int]] = { 'right': 0b10000000, 'left': 0b01000000, 'down': 0b00100000, @@ -55,7 +65,10 @@ def __init__(self, env: Env, actions: list): self._action_map[action] = byte_action self._action_meanings[action] = ' '.join(button_list) - def step(self, action): + def step( + self, + action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: """ Take a step using the given action. @@ -73,9 +86,14 @@ def step(self, action): # take the step and record the output return self.env.step(self._action_map[action]) - def reset(self): + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None + ) -> tuple[WrapperObsType, dict[str, Any]]: """Reset the environment and return the initial observation.""" - return self.env.reset() + return self.env.reset(seed=seed, options=options) def get_keys_to_action(self): """Return the dictionary of keyboard keys to actions.""" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c3abf28b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[build-system] +requires = [ + "wheel", + "setuptools>=45.0.0", + "pybind11>=2.10.0", + "numpy<3" +] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 371f5e80..bc29807b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ -gym>=0.17.2 -numpy>=1.18.5 -pyglet<=1.5.21,>=1.4.0 +build +wheel +pybind11>=2.10.0 +setuptools>=45.0.0 +gymnasium==1.0.0 +numpy<3 +pyglet<=2.0.1,>=1.4.0 tqdm>=4.48.2 twine>=1.11.0 +lz4==4.3.3 \ No newline at end of file diff --git a/backup_restore.py b/scripts/backup_restore.py similarity index 81% rename from backup_restore.py rename to scripts/backup_restore.py index 839dceee..d58d5fe1 100644 --- a/backup_restore.py +++ b/scripts/backup_restore.py @@ -10,10 +10,10 @@ state = env.reset() done = False else: - state, reward, done, info = env.step(env.action_space.sample()) + state, reward, done, _, info = env.step(env.action_space.sample()) if (i + 1) % 12: env._backup() if (i + 1) % 27: env._restore() except KeyboardInterrupt: - pass + pass \ No newline at end of file diff --git a/scripts/run.py b/scripts/run.py index 90abfae9..cbca2e0c 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -10,6 +10,6 @@ state = env.reset() done = False else: - state, reward, done, info = env.step(env.action_space.sample()) + state, reward, done, _, info = env.step(env.action_space.sample()) except KeyboardInterrupt: pass diff --git a/speedtest.py b/scripts/speedtest.py similarity index 77% rename from speedtest.py rename to scripts/speedtest.py index 90abfae9..cbca2e0c 100644 --- a/speedtest.py +++ b/scripts/speedtest.py @@ -10,6 +10,6 @@ state = env.reset() done = False else: - state, reward, done, info = env.step(env.action_space.sample()) + state, reward, done, _, info = env.step(env.action_space.sample()) except KeyboardInterrupt: pass diff --git a/setup.py b/setup.py index d50a5f74..9d099ff8 100644 --- a/setup.py +++ b/setup.py @@ -1,82 +1,138 @@ """The setup script for installing and distributing the nes-py package.""" import os +import subprocess from glob import glob -from setuptools import setup, find_packages, Extension - - -# set the compiler for the C++ framework -os.environ['CC'] = 'g++' -os.environ['CCX'] = 'g++' - - -# read the contents from the README file -with open('README.md') as README_file: - README = README_file.read() - - -# The prefix name for the .so library to build. It will follow the format -# lib_nes_env.*.so where the * changes depending on the build system -LIB_NAME = 'nes_py.lib_nes_env' -# The source files for building the extension. Globs locate all the cpp files -# used by the LaiNES subproject. MANIFEST.in has to include the blanket -# "cpp" directory to ensure that the .inc file gets included too -SOURCES = glob('nes_py/nes/src/*.cpp') + glob('nes_py/nes/src/mappers/*.cpp') -# The directory pointing to header files used by the LaiNES cpp files. -# This directory has to be included using MANIFEST.in too to include the -# headers with sdist -INCLUDE_DIRS = ['nes_py/nes/include'] -# Build arguments to pass to the compiler -EXTRA_COMPILE_ARGS = ['-std=c++1y', '-pipe', '-O3'] -# The official extension using the name, source, headers, and build args -LIB_NES_ENV = Extension(LIB_NAME, - sources=SOURCES, - include_dirs=INCLUDE_DIRS, - extra_compile_args=EXTRA_COMPILE_ARGS, -) - - -setup( - name='nes_py', - version='8.2.1', - description='An NES Emulator and OpenAI Gym interface', - long_description=README, - long_description_content_type='text/markdown', - keywords='NES Emulator OpenAI-Gym', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX :: Linux', - 'Operating System :: Microsoft :: Windows', - 'Programming Language :: C++', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Topic :: Games/Entertainment', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: System :: Emulators', - ], - url='https://github.com/Kautenja/nes-py', - author='Christian Kauten', - author_email='kautencreations@gmail.com', - license='MIT', - packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']), - ext_modules=[LIB_NES_ENV], - zip_safe=False, - install_requires=[ - 'gym>=0.17.2', - 'numpy>=1.18.5', - 'pyglet<=1.5.21,>=1.4.0', - 'tqdm>=4.48.2', - ], - entry_points={ - 'console_scripts': [ - 'nes_py = nes_py.app.cli:main', +from typing import List +from pathlib import Path + +import pybind11 +from setuptools import find_packages +from setuptools import setup +from setuptools.command.build_ext import build_ext +from pybind11.setup_helpers import Pybind11Extension + + +class MakeBuilder(build_ext): + """Custom builder that uses Make.""" + + def build_extension(self, ext: Pybind11Extension) -> None: + """Build the extension using Make.""" + extension_path: str = os.path.dirname(ext.sources[0]) + make_path: str = os.path.join(extension_path, '..') + + try: + n_jobs = min(16, os.cpu_count()) + subprocess.check_call(['make', '-C', make_path, f'-j{n_jobs}']) + + built_lib: str = os.path.join(make_path, 'libemulator.so') + target_lib: str = self.get_ext_fullpath(ext.name) + + if os.path.exists(built_lib): + os.makedirs(os.path.dirname(target_lib), exist_ok=True) + import shutil + shutil.copy2(built_lib, target_lib) + else: + raise RuntimeError(f'Built library not found at {built_lib}') + + except subprocess.CalledProcessError as e: + raise RuntimeError(f'Error building extension: {e}') + + +def read_readme() -> str: + """Read the README.md file and return its content.""" + with open('README.md') as f: + return f.read() + + +def get_source_files() -> List[str]: + """Get all C++ source files needed for compilation.""" + return glob('nes_py/nes/src/*.cpp') + glob('nes_py/nes/src/mappers/*.cpp') + + +def configure_compiler() -> None: + """Configure the C++ compiler settings.""" + os.environ['CC'] = 'g++' + os.environ['CCX'] = 'g++' + + +def get_extension_modules() -> List[Pybind11Extension]: + """Create and return the extension modules configuration.""" + sources: List[str] = get_source_files() + + return [ + Pybind11Extension( + name='nes_py.emulator', + sources=sources, + include_dirs=[ + 'nes_py/nes/include', + 'nes_py/nes/src', + pybind11.get_include(), + pybind11.get_include(user=True) + ], + cxx_std=14, + extra_compile_args=['-O3', '-Wall', '-Wextra', '-pedantic'], + ), + ] + + +def get_requirements() -> List[str]: + """Get the requirements for the package.""" + with open('requirements.txt') as f: + return [line.strip() for line in f] + + +def main() -> None: + """Main setup configuration.""" + configure_compiler() + + setup( + name='nes_py', + version='9.3.1', + description='An NES Emulator with Gymnasium interface', + long_description=read_readme(), + long_description_content_type='text/markdown', + keywords='NES Emulator, Gymnasium', + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: POSIX :: Linux', + 'Operating System :: Microsoft :: Windows', + 'Programming Language :: C++', + 'Programming Language :: Python :: 3 :: Only', + *[f'Programming Language :: Python :: 3.{v}' for v in range(8, 13)], + 'Topic :: Games/Entertainment', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: System :: Emulators', ], - }, -) + url='https://github.com/Kautenja/nes-py', + author='Christian Kauten', + author_email='kautencreations@gmail.com', + license='MIT', + packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']), + package_data={'nes_py': ['../requirements.txt']}, + ext_modules=get_extension_modules(), + zip_safe=False, + install_requires=get_requirements(), + extras_require={ + 'dev': [ + 'pytest>=7.0.0', + 'pytest-cov>=4.0.0', + 'black>=23.0.0', + 'mypy>=1.0.0', + ], + }, + entry_points={ + 'console_scripts': [ + 'nes_py=nes_py.app.cli:main', + ], + }, + python_requires='>=3.8', + cmdclass={'build_ext': MakeBuilder}, + ) + + +if __name__ == '__main__': + main()