From b521eadd97645ef21f39ea7b3a74974ef8eb5fcf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:57:25 +0000 Subject: [PATCH 1/2] Initial plan From eb8ee154bdf0efd9cef1f1723a43a21908c59316 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:14:04 +0000 Subject: [PATCH 2/2] Implement NHL94 self-play: finetune reward modes, side-aware actions, opponent snapshot rotation, curriculum phases Agent-Logs-Url: https://github.com/MatPoliquin/stable-retro-scripts/sessions/c2ccf5b5-8dfc-4051-879b-b79ad0786973 Co-authored-by: MatPoliquin <7024551+MatPoliquin@users.noreply.github.com> --- curriculum/nhl94.json | 22 +++ scripts/game_wrappers/nhl94_obs.py | 249 +++++++++++++++++------------ scripts/game_wrappers/nhl94_rf.py | 134 +++++++++++++++- scripts/train_curriculum.py | 7 +- scripts/train_live.py | 142 +++++++++++++++- 5 files changed, 452 insertions(+), 102 deletions(-) diff --git a/curriculum/nhl94.json b/curriculum/nhl94.json index 7d320ae..5200009 100644 --- a/curriculum/nhl94.json +++ b/curriculum/nhl94.json @@ -58,6 +58,28 @@ "rf": "DefenseZone", "num_timesteps": 20000000 }, + { + "name": "Self-Play Offense Finetune", + "description": "Team 1 (learner) attacks from the attack zone against a frozen defensive opponent seeded from the previous curriculum checkpoint.", + "state": "PenguinsVsSenators.AttackZone", + "rf": "SelfPlayOffenseFinetune", + "selfplay": true, + "num_players": 2, + "opponent_snapshot_freq": 50000, + "opponent_pool_size": 5, + "num_timesteps": 10000000 + }, + { + "name": "Self-Play Defense Finetune", + "description": "Team 1 (learner) defends in the defensive zone against a frozen offensive opponent seeded from the previous finetune checkpoint.", + "state": "PenguinsVsSenators.DefenseZone", + "rf": "SelfPlayDefenseFinetune", + "selfplay": true, + "num_players": 2, + "opponent_snapshot_freq": 50000, + "opponent_pool_size": 5, + "num_timesteps": 10000000 + }, { "name": "Start State Test", "description": "Evaluate the chained model from the normal game start state without additional training.", diff --git a/scripts/game_wrappers/nhl94_obs.py b/scripts/game_wrappers/nhl94_obs.py index f769009..c67a818 100644 --- a/scripts/game_wrappers/nhl94_obs.py +++ b/scripts/game_wrappers/nhl94_obs.py @@ -2,11 +2,10 @@ NHL94 Observation wrapper """ -import datetime import random -import copy from collections import deque from datetime import datetime +from typing import Optional import numpy as np import gymnasium as gym from gymnasium import spaces @@ -16,6 +15,11 @@ from game_wrappers.nhl94_gamestate import NHL94GameState +def _make_action_side_state(): + """Return a fresh per-side action-processing state dict.""" + return {"b_pressed": False, "c_pressed": False, "slapshot_frames": 0} + + class NHL94Observation2PEnv(gym.Wrapper): def __init__(self, env, args, num_players, rf_name): gym.Wrapper.__init__(self, env) @@ -59,10 +63,6 @@ def __init__(self, env, args, num_players, rf_name): else: self.observation_space = spaces.Box(low, high, dtype=np.float32) - #self.action_space = 12 * [0] - - self.prev_state = None - self.target_xy = [-1, -1] random.seed(datetime.now().timestamp()) @@ -73,89 +73,100 @@ def __init__(self, env, args, num_players, rf_name): self.ai_sys = NHL94AISystem(args, env, None) self.ram_inited = False - self.b_button_pressed = False - self.c_button_pressed = False - - self.slapshot_frames_held = 0 # 0 means not in slapshot mode - self.SLAPSHOT_HOLD_FRAMES = 60 # Number of frames to hold C for slapshot - - def _get_scalar_state_array(self): - return np.asarray(self.state, dtype=np.float32) - - def _reset_frame_buffer(self): - if not self.uses_sequence_obs: - return - - self.frame_buffer.clear() - current_state = self._get_scalar_state_array() - for _ in range(self.frame_stack_size): - self.frame_buffer.append(current_state.copy()) - - def _get_obs(self, image_obs=None): - if self.nn == 'CombinedPolicy': - return { - 'image': image_obs, - 'scalar': self.state - } - if self.uses_sequence_obs: - return np.array(self.frame_buffer, dtype=np.float32, copy=True) - return self.state - def reset(self, **kwargs): - state, info = self.env.reset(**kwargs) + # Per-side action-processing state (learner = team 1, opponent = team 2) + self.action_state = { + "learner": _make_action_side_state(), + "opponent": _make_action_side_state(), + } - self.state = tuple([0] * self.NUM_PARAMS) - self._reset_frame_buffer() - - self.game_state = NHL94GameState(self.num_players_per_team) - self.ram_inited = False - self.b_button_pressed = False - self.c_button_pressed = False + self.SLAPSHOT_HOLD_FRAMES = 60 # Number of frames to hold C for slapshot - return self._get_obs(state), info + # Self-play fields + self.opponent_model_path: str = "" + self.opponent_model = None - def step(self, ac): - p2_ac = [0,0,0,0,0,0,0,0,0,0,0,0] - p1_zero = [0,0,0,0,0,0,0,0,0,0,0,0] + # ------------------------------------------------------------------ + # Self-play public API + # ------------------------------------------------------------------ - if self.prev_state != None and self.num_players == 2: - self.prev_state.Flip() + def set_opponent_model(self, path: str) -> None: + """Load or swap the frozen opponent model from ``path``. + Pass an empty string to disable the frozen opponent. + """ + from stable_baselines3 import PPO # pylint: disable=import-outside-toplevel + self.opponent_model_path = path + if path: + self.opponent_model = PPO.load(path) + else: + self.opponent_model = None + + def compute_opponent_action(self, obs: np.ndarray) -> np.ndarray: + """Query the frozen opponent policy for an action. + + The opponent receives the current (team-1-perspective) observation. + Returns a zero-filled action array when no opponent model is loaded. + """ + if self.opponent_model is None: + if hasattr(self.action_space, 'n'): + return np.zeros(1, dtype=np.int64) + return np.zeros(self.action_space.shape, dtype=self.action_space.dtype) + action, _ = self.opponent_model.predict(obs, deterministic=True) + return action + + def combine_selfplay_actions(self, learner_ac, opponent_ac) -> np.ndarray: + """Combine learner and opponent actions into a two-player action array.""" + return np.concatenate([np.array(learner_ac), np.array(opponent_ac)]) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @property + def selfplay_enabled(self) -> bool: + """True when a frozen opponent model has been loaded.""" + return self.opponent_model is not None + + def _process_action(self, ac, side_state): + """Apply button debounce and slapshot handling in-place for one side. + + Returns ``(updated_ac, gamestate_ac)`` where ``gamestate_ac`` is the + 6-element boolean action array consumed by the game-state tracker. + """ gamestate_ac = [0] * 6 - # Handle different action space types if isinstance(ac, (list, np.ndarray)) and len(ac) == 12: - # FILTERED action space (12-button array) - # B button handling - if self.b_button_pressed and ac[GameConsts.INPUT_B] == 1: + # B button debounce + if side_state["b_pressed"] and ac[GameConsts.INPUT_B] == 1: ac[GameConsts.INPUT_B] = 0 - self.b_button_pressed = False - elif not self.b_button_pressed and ac[GameConsts.INPUT_B] == 1: - self.b_button_pressed = True + side_state["b_pressed"] = False + elif not side_state["b_pressed"] and ac[GameConsts.INPUT_B] == 1: + side_state["b_pressed"] = True else: - self.b_button_pressed = False + side_state["b_pressed"] = False - # C button handling (slapshot) + # C button / slapshot if ac[GameConsts.INPUT_MODE] == 1: - if self.slapshot_frames_held == 0: - self.slapshot_frames_held = 1 + if side_state["slapshot_frames"] == 0: + side_state["slapshot_frames"] = 1 ac[GameConsts.INPUT_C] = 1 else: - self.slapshot_frames_held += 1 + side_state["slapshot_frames"] += 1 ac[GameConsts.INPUT_C] = 1 - if self.slapshot_frames_held >= self.SLAPSHOT_HOLD_FRAMES: - self.slapshot_frames_held = 0 + if side_state["slapshot_frames"] >= self.SLAPSHOT_HOLD_FRAMES: + side_state["slapshot_frames"] = 0 ac[GameConsts.INPUT_C] = 0 else: - if self.c_button_pressed and ac[GameConsts.INPUT_C] == 1: + if side_state["c_pressed"] and ac[GameConsts.INPUT_C] == 1: ac[GameConsts.INPUT_C] = 0 - self.c_button_pressed = False - elif not self.c_button_pressed and ac[GameConsts.INPUT_C] == 1: - self.c_button_pressed = True + side_state["c_pressed"] = False + elif not side_state["c_pressed"] and ac[GameConsts.INPUT_C] == 1: + side_state["c_pressed"] = True else: - self.c_button_pressed = False - self.slapshot_frames_held = 0 + side_state["c_pressed"] = False + side_state["slapshot_frames"] = 0 gamestate_ac[0] = ac[GameConsts.INPUT_UP] == 1 gamestate_ac[1] = ac[GameConsts.INPUT_DOWN] == 1 @@ -165,37 +176,29 @@ def step(self, ac): gamestate_ac[5] = ac[GameConsts.INPUT_C] == 1 elif isinstance(ac, (list, np.ndarray)) and len(ac) == 3: - # MULTI_DISCRETE action space (3-button array) - # Process MULTI_DISCRETE actions for NHL94 format: - # ac = [vertical, horizontal, action] where: - # 0 = no input - # 1 = first option (e.g., "UP") - # 2 = second option (e.g., "DOWN") - - # Make a copy to avoid modifying original processed_ac = list(ac).copy() - # B button handling (action button 1) - if processed_ac[2] == 1: # B pressed - if self.b_button_pressed: - processed_ac[2] = 0 # Release B if already pressed - self.b_button_pressed = False + # B button debounce + if processed_ac[2] == 1: + if side_state["b_pressed"]: + processed_ac[2] = 0 + side_state["b_pressed"] = False else: - self.b_button_pressed = True + side_state["b_pressed"] = True else: - self.b_button_pressed = False + side_state["b_pressed"] = False - # C button handling (action button 2 - slapshot) - if processed_ac[2] == 2: # C pressed - if self.slapshot_frames_held == 0: - self.slapshot_frames_held = 1 + # C slapshot + if processed_ac[2] == 2: + if side_state["slapshot_frames"] == 0: + side_state["slapshot_frames"] = 1 else: - self.slapshot_frames_held += 1 - if self.slapshot_frames_held >= self.SLAPSHOT_HOLD_FRAMES: - processed_ac[2] = 0 # Release C after hold duration - self.slapshot_frames_held = 0 + side_state["slapshot_frames"] += 1 + if side_state["slapshot_frames"] >= self.SLAPSHOT_HOLD_FRAMES: + processed_ac[2] = 0 + side_state["slapshot_frames"] = 0 else: - self.slapshot_frames_held = 0 + side_state["slapshot_frames"] = 0 ac = processed_ac @@ -207,19 +210,71 @@ def step(self, ac): gamestate_ac[5] = ac[2] == 2 else: - # Handle other cases or raise error raise ValueError(f"Unsupported action format: {ac}") + return ac, gamestate_ac + def _get_scalar_state_array(self): + return np.asarray(self.state, dtype=np.float32) + def _reset_frame_buffer(self): + if not self.uses_sequence_obs: + return + + self.frame_buffer.clear() + current_state = self._get_scalar_state_array() + for _ in range(self.frame_stack_size): + self.frame_buffer.append(current_state.copy()) + + def _get_obs(self, image_obs=None): + if self.nn == 'CombinedPolicy': + return { + 'image': image_obs, + 'scalar': self.state + } + if self.uses_sequence_obs: + return np.array(self.frame_buffer, dtype=np.float32, copy=True) + return self.state + + def reset(self, **kwargs): + state, info = self.env.reset(**kwargs) + + self.state = tuple([0] * self.NUM_PARAMS) + self._reset_frame_buffer() + + self.game_state = NHL94GameState(self.num_players_per_team) + self.ram_inited = False + self.action_state = { + "learner": _make_action_side_state(), + "opponent": _make_action_side_state(), + } + + return self._get_obs(state), info + + def step(self, ac): + p2_ac = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + # Process learner action with side-aware state + ac = list(ac) if isinstance(ac, np.ndarray) else ac + ac, gamestate_ac = self._process_action(ac, self.action_state["learner"]) # Reward functions might need to override input - #print(ac) self.input_overide(ac) - ac2 = ac if self.num_players == 2: - ac2 = np.concatenate([ac, np.array(p2_ac)]) + if self.selfplay_enabled: + # Query frozen opponent and process its action independently + learner_obs = np.array(self.state, dtype=np.float32) + opp_raw = self.compute_opponent_action(learner_obs) + opp_raw = list(opp_raw) if isinstance(opp_raw, np.ndarray) else list(opp_raw) + # Ensure the opponent action has the same length as the learner action + if len(opp_raw) != len(ac): + opp_raw = (opp_raw + [0] * len(ac))[:len(ac)] + opp_ac, _ = self._process_action(opp_raw, self.action_state["opponent"]) + p2_ac = opp_ac + ac2 = np.concatenate([np.array(ac), np.array(p2_ac)]) + else: + ac2 = ac ob, rew, terminated, truncated, info = self.env.step(ac2) @@ -227,8 +282,6 @@ def step(self, ac): self.init_function(self.env, self.env_name) self.ram_inited = True - self.prev_state = copy.deepcopy(self.game_state) - self.game_state.BeginFrame(info, gamestate_ac) # Calculate Reward and check if episode is done diff --git a/scripts/game_wrappers/nhl94_rf.py b/scripts/game_wrappers/nhl94_rf.py index 7623531..cae2961 100644 --- a/scripts/game_wrappers/nhl94_rf.py +++ b/scripts/game_wrappers/nhl94_rf.py @@ -772,7 +772,7 @@ def rf_passing(state): return rew # ===================================================================== -# Self Play +# Self Play (legacy stub – kept for backward compatibility) # ===================================================================== def isdone_selfplay(state): t1 = state.team1 @@ -800,6 +800,136 @@ def rf_selfplay(state): # wrapper will negate if training Team-2 return base +# ===================================================================== +# Self Play – Offense Finetune +# Team 1 (learner) attacks; team 2 (frozen opponent) defends. +# Zero-sum terminal rewards from team-1 perspective. +# ===================================================================== + +# Consecutive frames team-2 must control the puck to count as a defensive clear +SELFPLAY_CONTROL_FRAMES = 30 + + +def init_selfplay_offense(env, env_name): + """Initialize attack-zone start positions for the offense finetune drill.""" + init_attackzone(env, env_name) + + +def isdone_selfplay_offense(state): + t1 = state.team1 + t2 = state.team2 + + # Success: team 1 scores + if t1.stats.score > t1.last_stats.score: + return True + + # Failure: puck leaves the attack zone + if state.puck.y < GameConsts.ATACKZONE_POS_Y: + return True + + # Failure: team 2 holds controlled possession for N consecutive frames + ctx = getattr(state, "_sp_offense_ctx", None) + if ctx is not None and ctx.get("t2_control_frames", 0) >= SELFPLAY_CONTROL_FRAMES: + return True + + # Timeout + if state.time < 100: + return True + + return False + + +def rf_selfplay_offense(state): + t1 = state.team1 + t2 = state.team2 + + ctx = getattr(state, "_sp_offense_ctx", None) + if ctx is None: + ctx = {"t2_control_frames": 0} + setattr(state, "_sp_offense_ctx", ctx) + + # Track consecutive frames of team-2 controlled possession + if t2.player_haspuck or t2.goalie_haspuck: + ctx["t2_control_frames"] = ctx.get("t2_control_frames", 0) + 1 + else: + ctx["t2_control_frames"] = 0 + + # Terminal outcomes only + if t1.stats.score > t1.last_stats.score: + return 1.0 + + if state.puck.y < GameConsts.ATACKZONE_POS_Y: + return -1.0 + + if ctx.get("t2_control_frames", 0) >= SELFPLAY_CONTROL_FRAMES: + return -1.0 + + return 0.0 + + +# ===================================================================== +# Self Play – Defense Finetune +# Team 1 (learner) defends; team 2 (frozen opponent) attacks. +# Zero-sum terminal rewards from team-1 perspective. +# ===================================================================== + +def init_selfplay_defense(env, env_name): + """Initialize defense-zone start positions for the defense finetune drill.""" + init_defensezone(env, env_name) + + +def isdone_selfplay_defense(state): + t1 = state.team1 + t2 = state.team2 + + # Failure: team 2 scores + if t2.stats.score > t2.last_stats.score: + return True + + # Success: puck cleared past defense zone into neutral/attack territory + if state.puck.y >= GameConsts.ATACKZONE_POS_Y: + return True + + # Success: team 1 holds controlled possession for N consecutive frames + ctx = getattr(state, "_sp_defense_ctx", None) + if ctx is not None and ctx.get("t1_control_frames", 0) >= SELFPLAY_CONTROL_FRAMES: + return True + + # Timeout + if state.time < 100: + return True + + return False + + +def rf_selfplay_defense(state): + t1 = state.team1 + t2 = state.team2 + + ctx = getattr(state, "_sp_defense_ctx", None) + if ctx is None: + ctx = {"t1_control_frames": 0} + setattr(state, "_sp_defense_ctx", ctx) + + # Track consecutive frames of team-1 controlled possession + if t1.player_haspuck or t1.goalie_haspuck: + ctx["t1_control_frames"] = ctx.get("t1_control_frames", 0) + 1 + else: + ctx["t1_control_frames"] = 0 + + # Terminal outcomes only + if t2.stats.score > t2.last_stats.score: + return -1.0 + + if state.puck.y >= GameConsts.ATACKZONE_POS_Y: + return 1.0 + + if ctx.get("t1_control_frames", 0) >= SELFPLAY_CONTROL_FRAMES: + return 1.0 + + return 0.0 + + # ===================================================================== # Register Functions # ===================================================================== @@ -813,6 +943,8 @@ def rf_selfplay(state): "Passing": (init_attackzone, rf_passing, isdone_passing, init_model_rel_dist_buttons, set_model_input_rel_dist_buttons, input_overide_no_shoot), "General": (init_general, rf_general, isdone_general, init_model_rel_dist_buttons, set_model_input_rel_dist_buttons, input_overide_empty), "SelfPlay": (init_selfplay, rf_selfplay, isdone_selfplay, init_model_invariant, set_model_input_invariant, input_overide_empty), + "SelfPlayOffenseFinetune": (init_selfplay_offense, rf_selfplay_offense, isdone_selfplay_offense, init_model_rel_dist_buttons, set_model_input_rel_dist_buttons, input_overide_empty), + "SelfPlayDefenseFinetune": (init_selfplay_defense, rf_selfplay_defense, isdone_selfplay_defense, init_model_rel_dist_buttons, set_model_input_rel_dist_buttons, input_overide_empty), } def register_functions(name: str) -> Tuple[Callable, Callable, Callable]: diff --git a/scripts/train_curriculum.py b/scripts/train_curriculum.py index 946a7b1..5f50316 100644 --- a/scripts/train_curriculum.py +++ b/scripts/train_curriculum.py @@ -21,7 +21,7 @@ "phase_type", "eval_episodes", } -PATH_KEYS = ("hyperparams", "load_p1_model", "output_basedir") +PATH_KEYS = ("hyperparams", "load_p1_model", "load_opponent_model", "output_basedir") def build_parser() -> argparse.ArgumentParser: @@ -82,6 +82,11 @@ def merge_phase_config( if previous_model_path: merged["load_p1_model"] = previous_model_path + # For self-play phases, seed the initial frozen opponent from the previous + # curriculum checkpoint when no explicit opponent path has been provided. + if merged.get("selfplay") and previous_model_path and not merged.get("load_opponent_model"): + merged["load_opponent_model"] = previous_model_path + return merged diff --git a/scripts/train_live.py b/scripts/train_live.py index 174ed45..2fa165c 100644 --- a/scripts/train_live.py +++ b/scripts/train_live.py @@ -13,6 +13,7 @@ import argparse import os +import random import sys import threading from collections import deque @@ -26,7 +27,7 @@ import pygame import pygame.freetype from stable_baselines3 import PPO -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, CallbackList from stable_baselines3.common.evaluation import evaluate_policy from common import com_print, create_output_dir, get_model_file_name, init_logger @@ -155,6 +156,100 @@ def _on_rollout_end(self) -> bool: return True +class OpponentSnapshotCallback(BaseCallback): + """Periodically snapshots the learner and rotates frozen opponents in all worker envs. + + Maintains a bounded checkpoint pool and samples opponents with a 40/40/20 mixture: + 40% latest snapshot, 40% random historical, 20% best-known checkpoint. + """ + + def __init__( + self, + snapshot_dir: str, + snapshot_freq: int, + pool_size: int, + best_model_path: str, + verbose: int = 0, + ) -> None: + super().__init__(verbose=verbose) + self.snapshot_dir = snapshot_dir + self.snapshot_freq = max(1, snapshot_freq) + self.pool_size = max(1, pool_size) + self.best_model_path = best_model_path + self._last_snapshot_step = 0 + self._checkpoint_pool: List[str] = [] + self._snapshot_counter = 0 + os.makedirs(snapshot_dir, exist_ok=True) + + def _save_snapshot(self) -> str: + """Save current model as a snapshot and return the .zip path.""" + self._snapshot_counter += 1 + base = os.path.join(self.snapshot_dir, f"opponent_snapshot_{self._snapshot_counter}") + self.model.save(base) + return ensure_zip_path(base) + + def _add_to_pool(self, zip_path: str) -> None: + """Add a snapshot to the pool, evicting the oldest entry if needed.""" + self._checkpoint_pool.append(zip_path) + if len(self._checkpoint_pool) > self.pool_size: + self._checkpoint_pool.pop(0) + + def _sample_opponent(self) -> Optional[str]: + """Sample an opponent checkpoint using a 40/40/20 weighted mixture. + + * 40 % – latest snapshot + * 40 % – random historical snapshot + * 20 % – best-known checkpoint (falls back to random if unavailable) + """ + if not self._checkpoint_pool: + return None + + best_zip = ensure_zip_path(self.best_model_path) + best_exists = os.path.exists(best_zip) + + roll = random.random() + if roll < 0.4: + return self._checkpoint_pool[-1] + elif roll < 0.8 or not best_exists: + return random.choice(self._checkpoint_pool) + else: + return best_zip + + def _update_env_opponents(self, zip_path: str) -> None: + """Push a new opponent checkpoint to all worker environments.""" + if not (zip_path and os.path.exists(zip_path)): + return + try: + self.training_env.env_method("set_opponent_model", zip_path) + except Exception as exc: # pylint: disable=broad-except + com_print(f"[OpponentSnapshotCallback] Failed to update opponents: {exc}") + + def _on_rollout_end(self) -> bool: + if self.num_timesteps - self._last_snapshot_step < self.snapshot_freq: + return True + + self._last_snapshot_step = self.num_timesteps + + snapshot_path = self._save_snapshot() + if os.path.exists(snapshot_path): + self._add_to_pool(snapshot_path) + if self.verbose: + com_print( + f"[OpponentSnapshotCallback] Saved snapshot at step " + f"{self.num_timesteps}: {snapshot_path}" + ) + + opponent_path = self._sample_opponent() + if opponent_path: + self._update_env_opponents(opponent_path) + if self.verbose: + com_print( + f"[OpponentSnapshotCallback] Rotated opponent to: {opponent_path}" + ) + + return True + + class LiveTrainingDisplay(threading.Thread): """Pygame window that replays the current best model and plots rewards.""" @@ -776,6 +871,16 @@ def __init__(self, args: argparse.Namespace, logger) -> None: args.hyperparams_dict, ) + # Seed the frozen opponent in all worker envs when self-play is enabled + load_opponent = getattr(args, "load_opponent_model", "") + if load_opponent and getattr(args, "selfplay", False): + opponent_path = ensure_zip_path(load_opponent) + if os.path.exists(opponent_path): + com_print(f"[SelfPlay] Seeding initial opponent model: {opponent_path}") + self.env.env_method("set_opponent_model", opponent_path) + else: + com_print(f"[SelfPlay] Warning: opponent model not found at {opponent_path}") + self.model = init_model( self.output_fullpath, args.load_p1_model, @@ -796,7 +901,7 @@ def build_callback(self, shared_state: LiveTrainingState) -> BaseCallback: self.args.hyperparams_dict, use_sticky_action=False, ) - return LiveTrainingCallback( + live_cb = LiveTrainingCallback( shared_state=shared_state, eval_env=eval_env, eval_freq=self.args.live_eval_freq, @@ -804,6 +909,19 @@ def build_callback(self, shared_state: LiveTrainingState) -> BaseCallback: verbose=1 if self.args.alg_verbose else 0, ) + if getattr(self.args, "selfplay", False): + snapshot_dir = os.path.join(self.output_fullpath, "opponent_snapshots") + snapshot_cb = OpponentSnapshotCallback( + snapshot_dir=snapshot_dir, + snapshot_freq=getattr(self.args, "opponent_snapshot_freq", 50_000), + pool_size=getattr(self.args, "opponent_pool_size", 5), + best_model_path=self.best_model_savepath, + verbose=1 if self.args.alg_verbose else 0, + ) + return CallbackList([live_cb, snapshot_cb]) + + return live_cb + def train(self, callback: BaseCallback) -> None: if self.args.alg == "es": raise NotImplementedError("Live training does not currently support Evolution Strategies.") @@ -882,6 +1000,26 @@ def build_parser() -> argparse.ArgumentParser: help="Extra seconds to wait after each display step (default 0 for ~60 FPS playback).", ) + # Self-play arguments + parser.add_argument( + "--load_opponent_model", + type=str, + default="", + help="Path to an initial frozen-opponent checkpoint for self-play training.", + ) + parser.add_argument( + "--opponent_snapshot_freq", + type=int, + default=50_000, + help="Timestep interval between opponent checkpoint snapshots during self-play.", + ) + parser.add_argument( + "--opponent_pool_size", + type=int, + default=5, + help="Maximum number of opponent checkpoints to keep in the rotation pool.", + ) + return parser