Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions curriculum/nhl94.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@
"rf": "DefenseZone",
"num_timesteps": 20000000
},
{
"name": "Self-Play Offense Finetune",
"description": "Team 1 (learner) attacks from the attack zone against a frozen defensive opponent seeded from the previous curriculum checkpoint.",
"state": "PenguinsVsSenators.AttackZone",
"rf": "SelfPlayOffenseFinetune",
"selfplay": true,
"num_players": 2,
"opponent_snapshot_freq": 50000,
"opponent_pool_size": 5,
"num_timesteps": 10000000
},
{
"name": "Self-Play Defense Finetune",
"description": "Team 1 (learner) defends in the defensive zone against a frozen offensive opponent seeded from the previous finetune checkpoint.",
"state": "PenguinsVsSenators.DefenseZone",
"rf": "SelfPlayDefenseFinetune",
"selfplay": true,
"num_players": 2,
"opponent_snapshot_freq": 50000,
"opponent_pool_size": 5,
"num_timesteps": 10000000
},
{
"name": "Start State Test",
"description": "Evaluate the chained model from the normal game start state without additional training.",
Expand Down
249 changes: 151 additions & 98 deletions scripts/game_wrappers/nhl94_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
NHL94 Observation wrapper
"""

import datetime
import random
import copy
from collections import deque
from datetime import datetime
from typing import Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces
Expand All @@ -16,6 +15,11 @@
from game_wrappers.nhl94_gamestate import NHL94GameState


def _make_action_side_state():
    """Build a new, independent action-processing state for one side.

    The mapping holds the B and C button press flags plus the count of
    frames the slapshot has been held. A distinct dict is created on every
    call, so callers can mutate one side without affecting another.
    """
    side_state = dict.fromkeys(("b_pressed", "c_pressed"), False)
    side_state["slapshot_frames"] = 0
    return side_state


class NHL94Observation2PEnv(gym.Wrapper):
def __init__(self, env, args, num_players, rf_name):
gym.Wrapper.__init__(self, env)
Expand Down Expand Up @@ -59,10 +63,6 @@ def __init__(self, env, args, num_players, rf_name):
else:
self.observation_space = spaces.Box(low, high, dtype=np.float32)

#self.action_space = 12 * [0]

self.prev_state = None

self.target_xy = [-1, -1]
random.seed(datetime.now().timestamp())

Expand All @@ -73,89 +73,100 @@ def __init__(self, env, args, num_players, rf_name):
self.ai_sys = NHL94AISystem(args, env, None)

self.ram_inited = False
self.b_button_pressed = False
self.c_button_pressed = False

self.slapshot_frames_held = 0 # 0 means not in slapshot mode
self.SLAPSHOT_HOLD_FRAMES = 60 # Number of frames to hold C for slapshot

def _get_scalar_state_array(self):
    """Return ``self.state`` converted to a float32 NumPy array."""
    scalar_values = self.state
    return np.asarray(scalar_values, dtype=np.float32)

def _reset_frame_buffer(self):
    """Refill the frame buffer with copies of the current scalar state.

    Does nothing unless sequence observations are enabled. Otherwise the
    buffer is emptied and then receives ``frame_stack_size`` independent
    copies of the array returned by ``_get_scalar_state_array``.
    """
    if self.uses_sequence_obs:
        self.frame_buffer.clear()
        snapshot = self._get_scalar_state_array()
        # Each slot gets its own copy so later writes cannot alias frames.
        self.frame_buffer.extend(
            snapshot.copy() for _ in range(self.frame_stack_size)
        )

def _get_obs(self, image_obs=None):
    """Assemble the observation for the configured network type.

    ``CombinedPolicy`` gets a dict pairing ``image_obs`` with the scalar
    state; sequence observations get a float32 array copy of the frame
    buffer; in every other case the scalar state is returned unchanged.
    """
    if self.nn != 'CombinedPolicy':
        if not self.uses_sequence_obs:
            return self.state
        return np.array(self.frame_buffer, dtype=np.float32, copy=True)
    return {'image': image_obs, 'scalar': self.state}

def reset(self, **kwargs):
state, info = self.env.reset(**kwargs)
# Per-side action-processing state (learner = team 1, opponent = team 2)
self.action_state = {
"learner": _make_action_side_state(),
"opponent": _make_action_side_state(),
}

self.state = tuple([0] * self.NUM_PARAMS)
self._reset_frame_buffer()

self.game_state = NHL94GameState(self.num_players_per_team)
self.ram_inited = False
self.b_button_pressed = False
self.c_button_pressed = False
self.SLAPSHOT_HOLD_FRAMES = 60 # Number of frames to hold C for slapshot

return self._get_obs(state), info
# Self-play fields
self.opponent_model_path: str = ""
self.opponent_model = None

def step(self, ac):
p2_ac = [0,0,0,0,0,0,0,0,0,0,0,0]
p1_zero = [0,0,0,0,0,0,0,0,0,0,0,0]
# ------------------------------------------------------------------
# Self-play public API
# ------------------------------------------------------------------

if self.prev_state != None and self.num_players == 2:
self.prev_state.Flip()
def set_opponent_model(self, path: str) -> None:
    """Load or swap the frozen opponent model from ``path``.

    Pass an empty string to disable the frozen opponent.

    The model is loaded before any attribute is updated, so a failed load
    leaves the previously configured opponent and its path untouched.
    stable_baselines3 is imported only when a model is actually loaded, so
    disabling the opponent does not require the package.

    Raises:
        Whatever the stable_baselines3 import or ``PPO.load`` raises for a
        non-empty ``path``.
    """
    if not path:
        self.opponent_model_path = path
        self.opponent_model = None
        return

    # Deferred import: only needed when a checkpoint is really loaded.
    from stable_baselines3 import PPO  # pylint: disable=import-outside-toplevel

    # Load first, assign second, so path and model never get out of sync.
    loaded_model = PPO.load(path)
    self.opponent_model_path = path
    self.opponent_model = loaded_model

def compute_opponent_action(self, obs: np.ndarray) -> np.ndarray:
    """Query the frozen opponent policy for an action.

    The opponent receives the current (team-1-perspective) observation.

    Returns a zero-filled action array when no opponent model is loaded:
    a single int64 zero for action spaces that expose ``n`` and have an
    empty shape, otherwise zeros matching the action space's shape and
    dtype. With a model loaded, the action from ``predict`` (called with
    ``deterministic=True``) is returned.
    """
    if self.opponent_model is not None:
        action, _ = self.opponent_model.predict(obs, deterministic=True)
        return action

    space = self.action_space
    # Checking ``n`` alone is not enough: multi-button spaces can expose
    # ``n`` too and still need one zero per button, so the single-zero
    # fallback is limited to spaces whose shape is empty.
    if hasattr(space, 'n') and not getattr(space, 'shape', None):
        return np.zeros(1, dtype=np.int64)
    return np.zeros(space.shape, dtype=space.dtype)

def combine_selfplay_actions(self, learner_ac, opponent_ac) -> np.ndarray:
    """Combine learner and opponent actions into a two-player action array.

    Each action is promoted to at least one dimension first, so scalar
    actions are joined instead of making ``np.concatenate`` raise
    ``ValueError`` on zero-dimensional input. The learner's entries come
    first, followed by the opponent's.
    """
    return np.concatenate(
        [np.atleast_1d(learner_ac), np.atleast_1d(opponent_ac)]
    )

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

@property
def selfplay_enabled(self) -> bool:
    """Whether a frozen opponent model is currently loaded."""
    opponent = self.opponent_model
    return opponent is not None

def _process_action(self, ac, side_state):
"""Apply button debounce and slapshot handling in-place for one side.

Returns ``(updated_ac, gamestate_ac)`` where ``gamestate_ac`` is the
6-element boolean action array consumed by the game-state tracker.
"""
gamestate_ac = [0] * 6

# Handle different action space types
if isinstance(ac, (list, np.ndarray)) and len(ac) == 12:
# FILTERED action space (12-button array)
# B button handling
if self.b_button_pressed and ac[GameConsts.INPUT_B] == 1:
# B button debounce
if side_state["b_pressed"] and ac[GameConsts.INPUT_B] == 1:
ac[GameConsts.INPUT_B] = 0
self.b_button_pressed = False
elif not self.b_button_pressed and ac[GameConsts.INPUT_B] == 1:
self.b_button_pressed = True
side_state["b_pressed"] = False
elif not side_state["b_pressed"] and ac[GameConsts.INPUT_B] == 1:
side_state["b_pressed"] = True
else:
self.b_button_pressed = False
side_state["b_pressed"] = False

# C button handling (slapshot)
# C button / slapshot
if ac[GameConsts.INPUT_MODE] == 1:
if self.slapshot_frames_held == 0:
self.slapshot_frames_held = 1
if side_state["slapshot_frames"] == 0:
side_state["slapshot_frames"] = 1
ac[GameConsts.INPUT_C] = 1
else:
self.slapshot_frames_held += 1
side_state["slapshot_frames"] += 1
ac[GameConsts.INPUT_C] = 1
if self.slapshot_frames_held >= self.SLAPSHOT_HOLD_FRAMES:
self.slapshot_frames_held = 0
if side_state["slapshot_frames"] >= self.SLAPSHOT_HOLD_FRAMES:
side_state["slapshot_frames"] = 0
ac[GameConsts.INPUT_C] = 0
else:
if self.c_button_pressed and ac[GameConsts.INPUT_C] == 1:
if side_state["c_pressed"] and ac[GameConsts.INPUT_C] == 1:
ac[GameConsts.INPUT_C] = 0
self.c_button_pressed = False
elif not self.c_button_pressed and ac[GameConsts.INPUT_C] == 1:
self.c_button_pressed = True
side_state["c_pressed"] = False
elif not side_state["c_pressed"] and ac[GameConsts.INPUT_C] == 1:
side_state["c_pressed"] = True
else:
self.c_button_pressed = False
self.slapshot_frames_held = 0
side_state["c_pressed"] = False
side_state["slapshot_frames"] = 0

gamestate_ac[0] = ac[GameConsts.INPUT_UP] == 1
gamestate_ac[1] = ac[GameConsts.INPUT_DOWN] == 1
Expand All @@ -165,37 +176,29 @@ def step(self, ac):
gamestate_ac[5] = ac[GameConsts.INPUT_C] == 1

elif isinstance(ac, (list, np.ndarray)) and len(ac) == 3:
# MULTI_DISCRETE action space (3-button array)
# Process MULTI_DISCRETE actions for NHL94 format:
# ac = [vertical, horizontal, action] where:
# 0 = no input
# 1 = first option (e.g., "UP")
# 2 = second option (e.g., "DOWN")

# Make a copy to avoid modifying original
processed_ac = list(ac).copy()

# B button handling (action button 1)
if processed_ac[2] == 1: # B pressed
if self.b_button_pressed:
processed_ac[2] = 0 # Release B if already pressed
self.b_button_pressed = False
# B button debounce
if processed_ac[2] == 1:
if side_state["b_pressed"]:
processed_ac[2] = 0
side_state["b_pressed"] = False
else:
self.b_button_pressed = True
side_state["b_pressed"] = True
else:
self.b_button_pressed = False
side_state["b_pressed"] = False

# C button handling (action button 2 - slapshot)
if processed_ac[2] == 2: # C pressed
if self.slapshot_frames_held == 0:
self.slapshot_frames_held = 1
# C slapshot
if processed_ac[2] == 2:
if side_state["slapshot_frames"] == 0:
side_state["slapshot_frames"] = 1
else:
self.slapshot_frames_held += 1
if self.slapshot_frames_held >= self.SLAPSHOT_HOLD_FRAMES:
processed_ac[2] = 0 # Release C after hold duration
self.slapshot_frames_held = 0
side_state["slapshot_frames"] += 1
if side_state["slapshot_frames"] >= self.SLAPSHOT_HOLD_FRAMES:
processed_ac[2] = 0
side_state["slapshot_frames"] = 0
else:
self.slapshot_frames_held = 0
side_state["slapshot_frames"] = 0

ac = processed_ac

Expand All @@ -207,28 +210,78 @@ def step(self, ac):
gamestate_ac[5] = ac[2] == 2

else:
# Handle other cases or raise error
raise ValueError(f"Unsupported action format: {ac}")

return ac, gamestate_ac

def _get_scalar_state_array(self):
    """Return ``self.state`` converted to a float32 NumPy array."""
    scalar_values = self.state
    return np.asarray(scalar_values, dtype=np.float32)

def _reset_frame_buffer(self):
    """Refill the frame buffer with copies of the current scalar state.

    Does nothing unless sequence observations are enabled. Otherwise the
    buffer is emptied and then receives ``frame_stack_size`` independent
    copies of the array returned by ``_get_scalar_state_array``.
    """
    if self.uses_sequence_obs:
        self.frame_buffer.clear()
        snapshot = self._get_scalar_state_array()
        # Each slot gets its own copy so later writes cannot alias frames.
        self.frame_buffer.extend(
            snapshot.copy() for _ in range(self.frame_stack_size)
        )

def _get_obs(self, image_obs=None):
    """Assemble the observation for the configured network type.

    ``CombinedPolicy`` gets a dict pairing ``image_obs`` with the scalar
    state; sequence observations get a float32 array copy of the frame
    buffer; in every other case the scalar state is returned unchanged.
    """
    if self.nn != 'CombinedPolicy':
        if not self.uses_sequence_obs:
            return self.state
        return np.array(self.frame_buffer, dtype=np.float32, copy=True)
    return {'image': image_obs, 'scalar': self.state}

def reset(self, **kwargs):
    """Reset the wrapped env and the wrapper's per-episode state.

    Keyword arguments are forwarded to the wrapped env's ``reset``.
    Returns the observation built by ``_get_obs`` together with the
    ``info`` value reported by the wrapped env.
    """
    raw_obs, info = self.env.reset(**kwargs)

    # The zeroed scalar state is assigned first because the frame buffer
    # refill reads it.
    self.state = (0,) * self.NUM_PARAMS
    self._reset_frame_buffer()

    self.game_state = NHL94GameState(self.num_players_per_team)
    self.ram_inited = False
    # Fresh button/slapshot tracking for both sides.
    self.action_state = {
        side: _make_action_side_state() for side in ("learner", "opponent")
    }

    return self._get_obs(raw_obs), info

def step(self, ac):
p2_ac = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Process learner action with side-aware state
ac = list(ac) if isinstance(ac, np.ndarray) else ac
ac, gamestate_ac = self._process_action(ac, self.action_state["learner"])

# Reward functions might need to override input
#print(ac)
self.input_overide(ac)

ac2 = ac
if self.num_players == 2:
ac2 = np.concatenate([ac, np.array(p2_ac)])
if self.selfplay_enabled:
# Query frozen opponent and process its action independently
learner_obs = np.array(self.state, dtype=np.float32)
opp_raw = self.compute_opponent_action(learner_obs)
opp_raw = list(opp_raw) if isinstance(opp_raw, np.ndarray) else list(opp_raw)
# Ensure the opponent action has the same length as the learner action
if len(opp_raw) != len(ac):
opp_raw = (opp_raw + [0] * len(ac))[:len(ac)]
opp_ac, _ = self._process_action(opp_raw, self.action_state["opponent"])
p2_ac = opp_ac
ac2 = np.concatenate([np.array(ac), np.array(p2_ac)])
else:
ac2 = ac

ob, rew, terminated, truncated, info = self.env.step(ac2)

if not self.ram_inited:
self.init_function(self.env, self.env_name)
self.ram_inited = True

self.prev_state = copy.deepcopy(self.game_state)

self.game_state.BeginFrame(info, gamestate_ac)

# Calculate Reward and check if episode is done
Expand Down
Loading