diff --git a/pyproject.toml b/pyproject.toml index de9b2b7..6f58497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,11 +41,14 @@ dev = [ "pyright>=1.1.406", "ruff>=0.14.1", "pytest>=8.0.0", + "pytest-asyncio>=0.24.0", ] examples = [ "transformers>=4.57.3", ] [tool.pytest.ini_options] -testpaths = ["src"] +testpaths = ["tests"] python_files = ["test_*.py", "*_test.py"] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/src/ares/contrib/__init__.py b/src/ares/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ares/contrib/rl/__init__.py b/src/ares/contrib/rl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ares/contrib/rl/replay_buffer.py b/src/ares/contrib/rl/replay_buffer.py new file mode 100644 index 0000000..a63dd8c --- /dev/null +++ b/src/ares/contrib/rl/replay_buffer.py @@ -0,0 +1,613 @@ +"""Episode Replay Buffer for Multi-Agent Reinforcement Learning. + +This module provides an asyncio-safe replay buffer that supports: +- Concurrent experience collection from multiple agents +- Episode-based storage with explicit start/end control +- N-step return sampling with configurable gamma +- Capacity management with automatic eviction of oldest episodes + +Storage Design: + Episodes store per-timestep arrays: observations[t], actions[t], rewards[t]. 
+ +Usage Example: + ```python + import asyncio + from ares.contrib.rl import replay_buffer + + # Create buffer with capacity limits + buffer = replay_buffer.EpisodeReplayBuffer(max_episodes=1000, max_steps=100000) + + # Start an episode (episode_id is auto-generated) + episode_id = await buffer.start_episode(agent_id="agent_0") + + # Collect experience (initial observation before first action) + obs_0 = {"obs": [1, 2, 3]} + + # Take action, observe reward and next obs + action_0 = 0 + reward_0 = 1.0 + obs_1 = {"obs": [4, 5, 6]} + + await buffer.append_observation_action_reward( + episode_id, obs_0, action_0, reward_0 + ) + + # Continue... + action_1 = 1 + reward_1 = 2.0 + obs_2 = {"obs": [7, 8, 9]} + + await buffer.append_observation_action_reward( + episode_id, obs_1, action_1, reward_1 + ) + + # End episode with final observation + await buffer.end_episode(episode_id, status="COMPLETED", final_observation=obs_2) + + # Sample n-step batches + samples = await buffer.sample_n_step(batch_size=32, n=3, gamma=0.99) + for sample in samples: + # sample.obs_t, sample.action_t, sample.rewards_seq + # sample.next_obs, sample.terminal, etc. + discounted_return = replay_buffer.compute_discounted_return( + sample.rewards_seq, gamma=0.99 + ) + ``` + +Thread Safety and Async Usage: + All public methods are async. This buffer is designed for single-threaded + asyncio usage and does not provide internal synchronization. If you need + concurrent access from multiple asyncio tasks, you should manage + synchronization externally. +""" + +import collections +from collections import deque +import dataclasses +import random +import time +from typing import Any, Literal +import uuid + +# Type of episode status +EpisodeStatus = Literal["IN_PROGRESS", "COMPLETED"] + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Episode[ObservationType, ActionType]: + """An episode containing sequences of observations, actions, and rewards. 
+ + Storage format: + - observations: [obs_0, obs_1, ..., obs_T] (length T+1) + - actions: [a_0, a_1, ..., a_{T-1}] (length T) + - rewards: [r_0, r_1, ..., r_{T-1}] (length T) + + A full transition at timestep t consists of (obs_t, action_t, reward_t, obs_{t+1}). + The next observation obs_{t+1} is stored at observations[t+1]. + + Attributes: + episode_id: Unique identifier for this episode + agent_id: Identifier of the agent that generated this episode + observations: List of observations in temporal order + actions: List of actions taken + rewards: List of rewards received + status: Current status of the episode + start_time: Timestamp when episode started (for eviction policy) + """ + + episode_id: str + agent_id: str + observations: list[ObservationType] = dataclasses.field(default_factory=list) + actions: list[ActionType] = dataclasses.field(default_factory=list) + rewards: list[float] = dataclasses.field(default_factory=list) + status: EpisodeStatus = "IN_PROGRESS" + start_time: float = dataclasses.field(default_factory=time.time) + + def __len__(self) -> int: + """Return the number of complete transitions (with both obs_t and obs_{t+1} available). + + A complete transition requires observations[t] and observations[t+1]. + Returns max(len(observations) - 1, 0). + """ + return max(len(self.observations) - 1, 0) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ReplaySample[ObservationType, ActionType]: + """A sampled n-step experience for training. + + The sample captures a trajectory segment starting at time t: + obs_t, action_t, [r_t, r_{t+1}, ..., r_{t+m-1}], obs_{t+m} + + where m <= n is the actual number of steps (truncated at episode boundary). 
+ + Attributes: + episode_id: ID of the source episode + agent_id: ID of the agent that generated this experience + obs_t: The observation at time t + action_t: The action taken at time t + rewards_seq: Sequence of rewards [r_t, r_{t+1}, ..., r_{t+m-1}] (length m) + next_obs: The observation at time t+m (obs_{t+m}) + terminal: True if episode terminated within the n-step window + next_discount: Discount factor to apply to bootstrap value at next_obs. + This is gamma^m where m is actual_n. When terminal=True, this + should be 0 (no bootstrap). When terminal=False, this is the + discount to apply to the value estimate at next_obs. + discount_powers: [gamma^0, gamma^1, ..., gamma^{m-1}] for computing returns + start_step: The starting step index t + actual_n: The actual number of steps m (may be < n if episode ends) + gamma: The discount factor used + """ + + episode_id: str + agent_id: str + obs_t: ObservationType + action_t: ActionType + rewards_seq: list[float] + next_obs: ObservationType + terminal: bool + next_discount: float + discount_powers: list[float] + start_step: int + actual_n: int + gamma: float + + @property + def reward(self) -> float: + """Return the computed discounted return for this sample. + + This is a convenience property that computes the discounted return + from the rewards sequence using the stored gamma value. + + Returns: + The discounted return: sum_{k=0}^{n-1} gamma^k * r_k + """ + return compute_discounted_return(self.rewards_seq, self.gamma) + + +# Backward compatibility alias +NStepSample = ReplaySample + + +def compute_discounted_return(rewards: list[float], gamma: float) -> float: + """Compute the discounted return from a sequence of rewards. 
+ + G = sum_{k=0}^{n-1} gamma^k * r_k + + Args: + rewards: Sequence of rewards [r_0, r_1, ..., r_{n-1}] + gamma: Discount factor in (0, 1] + + Returns: + The discounted return + """ + return sum(gamma**k * r for k, r in enumerate(rewards)) + + +class EpisodeReplayBuffer: + """Replay buffer for episodic reinforcement learning. + + This buffer stores complete episodes and supports n-step sampling with + proper handling of episode boundaries. It manages capacity by evicting + oldest finished episodes first, then oldest in-progress episodes if needed. + + Sampling: + Uniform sampling over all valid time steps (experiences) across episodes. + Each valid step (obs_t, action_t, reward_t) has equal probability. + The implementation uses O(num_episodes) scan to build episode weights, + then O(num_episodes) weighted sampling for each sample, avoiding the + O(num_episodes * steps_per_episode) cost of enumerating all positions. + + Concurrency: + This buffer is designed for single-threaded usage. If you need + concurrent access from multiple asyncio tasks, you should manage + synchronization externally. + + Capacity Management: + - max_episodes: Maximum number of episodes to store + - max_steps: Maximum total number of transitions across all episodes + - Eviction policy: oldest finished episodes first, then oldest in-progress + - Eviction updates sampling counts to maintain uniform distribution + """ + + def __init__( + self, + max_episodes: int | None = None, + max_steps: int | None = None, + ): + """Initialize the replay buffer. 
+ + Args: + max_episodes: Maximum number of episodes to store (None = unlimited) + max_steps: Maximum total transitions to store (None = unlimited) + """ + self._episodes: dict[str, Episode[Any, Any]] = {} + self._max_episodes = max_episodes + self._max_steps = max_steps + self._total_steps = 0 + + # Track episodes by agent for potential future use + self._episodes_by_agent: dict[str, list[str]] = collections.defaultdict(list) + + # Track episode IDs in insertion order for efficient sampling and eviction + # This deque parallels self._episodes and enables O(1) oldest episode access + self._episode_order: deque[str] = deque() + + # Track valid step counts for each episode in parallel with _episode_order + # This avoids O(num_episodes) scan during sampling by maintaining counts + # as episodes are added/extended/evicted + self._episode_valid_counts: deque[int] = deque() + + async def start_episode( + self, + agent_id: str, + ) -> str: + """Start a new episode. + + Args: + agent_id: Identifier for the agent + + Returns: + The episode_id for this episode + """ + episode_id = str(uuid.uuid4()) + + episode = Episode(episode_id=episode_id, agent_id=agent_id) + self._episodes[episode_id] = episode + self._episodes_by_agent[agent_id].append(episode_id) + self._episode_order.append(episode_id) + self._episode_valid_counts.append(0) # New episode starts with 0 valid steps + + # Check capacity and evict if needed + await self._evict_if_needed() + + return episode_id + + async def append_observation_action_reward( + self, + episode_id: str, + observation: Any, + action: Any, + reward: float, + ) -> None: + """Append an observation, action, and reward to an episode. + + At time step t, call this with (obs_t, action_t, reward_t). + The observation obs_t should be the state in which action_t was taken, + and reward_t is the immediate reward received. 
+ + Note: You should also store the final observation after the last action + by calling this method one more time or handling it specially when + ending the episode. The typical pattern is: + + 1. Observe obs_0 (initial state) + 2. Take action_0, receive reward_0, observe obs_1 + -> append(obs_0, action_0, reward_0) + 3. Take action_1, receive reward_1, observe obs_2 + -> append(obs_1, action_1, reward_1) + ... + T. Episode ends at obs_T + -> Store obs_T in observations but no action/reward + + Args: + episode_id: The episode to append to + observation: The observation at time t + action: The action taken at time t + reward: The reward received at time t + + Raises: + ValueError: If episode doesn't exist or is already finished + """ + if episode_id not in self._episodes: + raise ValueError(f"Episode {episode_id} not found") + + episode = self._episodes[episode_id] + + if episode.status != "IN_PROGRESS": + raise ValueError(f"Cannot append to finished episode {episode_id} (status: {episode.status})") + + # Store observation (if this is the first call, obs_0) + # For subsequent calls, we're storing obs_t where action_t was taken + if len(episode.observations) == len(episode.actions): + # We need to add the observation for this timestep + episode.observations.append(observation) + + episode.actions.append(action) + episode.rewards.append(reward) + self._total_steps += 1 + + # Update valid count for this episode + # After adding an action/reward, we need to check if a valid step was created + # A valid step requires both obs_t and obs_{t+1} to be available + new_valid_count = self._get_valid_step_count(episode) + episode_idx = self._episode_order.index(episode_id) + self._episode_valid_counts[episode_idx] = new_valid_count + + # Check step capacity + await self._evict_if_needed() + + async def end_episode( + self, + episode_id: str, + status: EpisodeStatus = "COMPLETED", + final_observation: Any = None, + ) -> None: + """Mark an episode as finished. 
+ + Args: + episode_id: The episode to end + status: Episode status (should be "COMPLETED") + final_observation: Final observation obs_T after last action. Required if not + already appended via append_observation_action_reward. + + Raises: + ValueError: If episode doesn't exist, is already finished, + status is IN_PROGRESS, or final_observation is required but not provided + """ + if episode_id not in self._episodes: + raise ValueError(f"Episode {episode_id} not found") + + episode = self._episodes[episode_id] + + if episode.status != "IN_PROGRESS": + raise ValueError(f"Episode {episode_id} is already finished") + + if status == "IN_PROGRESS": + raise ValueError("Cannot end episode with status IN_PROGRESS") + + # Validation: If observations length equals actions length, + # the final observation hasn't been added yet, so it must be provided + if len(episode.observations) == len(episode.actions): + if final_observation is None: + raise ValueError( + f"Episode {episode_id} requires final_observation: " + f"observations length ({len(episode.observations)}) equals " + f"actions length ({len(episode.actions)})" + ) + episode.observations.append(final_observation) + elif final_observation is not None: + # If final_observation is provided but not needed, append it anyway + episode.observations.append(final_observation) + + # Update status using object.__setattr__ since Episode dataclass is frozen + object.__setattr__(episode, "status", status) + + # Update valid count for this episode (status change affects valid count) + new_valid_count = self._get_valid_step_count(episode) + episode_idx = self._episode_order.index(episode_id) + self._episode_valid_counts[episode_idx] = new_valid_count + + def _get_valid_step_count(self, episode: Episode) -> int: + """Get the number of valid starting positions for sampling in an episode. 
+ + Args: + episode: The episode to check + + Returns: + Number of valid starting positions (0 if none) + """ + num_steps = len(episode.actions) + if num_steps == 0: + return 0 + + # For COMPLETED episodes, all steps are valid + if episode.status == "COMPLETED": + return num_steps + + # For IN_PROGRESS episodes, only steps with next observation available + # A step t is valid if observations[t+1] exists + # Since len(observations) can be at most len(actions) + 1, + # valid steps are those where t+1 < len(observations) + valid_count = max(0, len(episode.observations) - 1) + return min(valid_count, num_steps) + + async def sample_n_step( + self, + batch_size: int, + n: int, + gamma: float, + ) -> list[ReplaySample]: + """Sample n-step experiences uniformly from the buffer. + + Sampling is uniform over all valid time steps across all episodes. + Each step (obs_t, action_t, reward_t) has equal probability. + + N-step windows never cross episode boundaries. If fewer than n steps + remain in the episode, the sample is truncated to the available steps. 
+ + Args: + batch_size: Number of samples to return + n: Number of steps for n-step returns + gamma: Discount factor for computing returns + + Returns: + List of n-step samples (may be less than batch_size if insufficient data) + + Raises: + ValueError: If n < 1 or gamma not in (0, 1] + """ + if n < 1: + raise ValueError(f"n must be >= 1, got {n}") + if not 0 < gamma <= 1: + raise ValueError(f"gamma must be in (0, 1], got {gamma}") + + # Build episode ranges using pre-computed valid counts from the deque + # This avoids O(num_episodes) scan to compute valid counts + episode_ranges: list[tuple[int, int, str]] = [] # (start_pos, end_pos, episode_id) + cumulative_pos = 0 + + for episode_id, valid_count in zip(self._episode_order, self._episode_valid_counts, strict=True): + if valid_count > 0: + episode_ranges.append((cumulative_pos, cumulative_pos + valid_count, episode_id)) + cumulative_pos += valid_count + + if not episode_ranges: + return [] + + total_valid_positions = cumulative_pos + + # Sample uniformly without replacement + num_samples = min(batch_size, total_valid_positions) + + # Generate unique random positions + sampled_global_positions = random.sample(range(total_valid_positions), num_samples) + + # Build n-step samples by mapping global positions back to (episode_id, step_idx) + samples: list[ReplaySample] = [] + for global_pos in sampled_global_positions: + # Find the episode containing this position + for start_pos, end_pos, episode_id in episode_ranges: + if start_pos <= global_pos < end_pos: + # Convert global position to local step index within episode + start_idx = global_pos - start_pos + + sample = self._build_n_step_sample( + episode_id=episode_id, + start_idx=start_idx, + n=n, + gamma=gamma, + ) + samples.append(sample) + break + + return samples + + def _build_n_step_sample( + self, + episode_id: str, + start_idx: int, + n: int, + gamma: float, + ) -> ReplaySample: + """Build an n-step sample starting from a given position. 
+ + Never crosses episode boundary; truncates if fewer than n steps remain. + """ + episode = self._episodes[episode_id] + + num_steps = len(episode.actions) + + # Determine actual window size (truncate at episode boundary) + end_idx = min(start_idx + n, num_steps) + actual_n = end_idx - start_idx + + # Extract data for the window [start_idx, end_idx) + obs_t = episode.observations[start_idx] + action_t = episode.actions[start_idx] + rewards_seq = episode.rewards[start_idx:end_idx] + + # next_obs is observation at end_idx + # We can safely access this because _get_valid_step_count ensures + # that only positions with next_obs available are sampled + next_obs = episode.observations[end_idx] + + # Check if episode ended within the window + terminal = (end_idx == num_steps) and (episode.status != "IN_PROGRESS") + + # Compute next_discount: gamma^m when not terminal, 0 when terminal + next_discount = 0.0 if terminal else gamma**actual_n + + # Compute discount powers + discount_powers = [gamma**k for k in range(actual_n)] + + return ReplaySample( + episode_id=episode_id, + agent_id=episode.agent_id, + obs_t=obs_t, + action_t=action_t, + rewards_seq=rewards_seq, + next_obs=next_obs, + terminal=terminal, + next_discount=next_discount, + discount_powers=discount_powers, + start_step=start_idx, + actual_n=actual_n, + gamma=gamma, + ) + + async def _evict_if_needed(self) -> None: + """Evict oldest episodes if capacity limits are exceeded. + + Eviction policy: + 1. First evict oldest finished episodes (terminal or truncated) + 2. If still over capacity, evict oldest in-progress episodes + + Eviction updates the total step count to maintain correct uniform + sampling statistics. 
+ """ + # Check episode capacity + if self._max_episodes is not None: + while len(self._episodes) > self._max_episodes: + self._evict_oldest_episode() + + # Check step capacity + if self._max_steps is not None: + while self._total_steps > self._max_steps: + if not self._episodes: + break + self._evict_oldest_episode() + + def _evict_oldest_episode(self) -> None: + """Evict the oldest episode from the buffer.""" + if not self._episodes: + return + + # Separate finished and in-progress episodes + finished_episodes: list[tuple[str, Episode]] = [] + in_progress_episodes: list[tuple[str, Episode]] = [] + + for episode_id, episode in self._episodes.items(): + if episode.status == "IN_PROGRESS": + in_progress_episodes.append((episode_id, episode)) + else: + finished_episodes.append((episode_id, episode)) + + # Evict oldest finished episode first + if finished_episodes: + oldest = min(finished_episodes, key=lambda x: x[1].start_time) + episode_id = oldest[0] + else: + # No finished episodes, evict oldest in-progress + oldest = min(in_progress_episodes, key=lambda x: x[1].start_time) + episode_id = oldest[0] + + # Remove the episode + episode = self._episodes.pop(episode_id) + self._total_steps -= len(episode) + + # Update agent tracking + agent_id = episode.agent_id + if agent_id in self._episodes_by_agent: + self._episodes_by_agent[agent_id].remove(episode_id) + if not self._episodes_by_agent[agent_id]: + del self._episodes_by_agent[agent_id] + + # Remove from episode order deque and corresponding valid count + episode_idx = self._episode_order.index(episode_id) + self._episode_order.remove(episode_id) + del self._episode_valid_counts[episode_idx] + + async def get_stats(self) -> dict[str, Any]: + """Get statistics about the replay buffer. 
+ + Returns: + Dictionary with buffer statistics + """ + num_in_progress = sum(1 for ep in self._episodes.values() if ep.status == "IN_PROGRESS") + num_completed = sum(1 for ep in self._episodes.values() if ep.status == "COMPLETED") + + return { + "total_episodes": len(self._episodes), + "in_progress": num_in_progress, + "completed": num_completed, + "total_steps": self._total_steps, + "num_agents": len(self._episodes_by_agent), + } + + async def clear(self) -> None: + """Clear all episodes from the buffer.""" + self._episodes.clear() + self._episodes_by_agent.clear() + self._episode_order.clear() + self._episode_valid_counts.clear() + self._total_steps = 0 diff --git a/src/ares/contrib/rl/replay_buffer_test.py b/src/ares/contrib/rl/replay_buffer_test.py new file mode 100644 index 0000000..31f39a1 --- /dev/null +++ b/src/ares/contrib/rl/replay_buffer_test.py @@ -0,0 +1,642 @@ +"""Unit tests for the Episode Replay Buffer.""" + +import asyncio +import random + +import pytest + +import ares.contrib.rl.replay_buffer + + +class TestComputeDiscountedReturn: + """Test the helper function for computing discounted returns.""" + + def test_single_reward(self): + """Test with a single reward.""" + result = ares.contrib.rl.replay_buffer.compute_discounted_return([5.0], gamma=0.99) + assert result == 5.0 + + def test_multiple_rewards(self): + """Test with multiple rewards.""" + rewards = [1.0, 2.0, 3.0] + gamma = 0.9 + expected = 1.0 + 0.9 * 2.0 + 0.81 * 3.0 + result = ares.contrib.rl.replay_buffer.compute_discounted_return(rewards, gamma) + assert abs(result - expected) < 1e-6 + + def test_gamma_one(self): + """Test with gamma=1 (undiscounted).""" + rewards = [1.0, 2.0, 3.0] + result = ares.contrib.rl.replay_buffer.compute_discounted_return(rewards, gamma=1.0) + assert result == 6.0 + + def test_empty_rewards(self): + """Test with empty reward sequence.""" + result = ares.contrib.rl.replay_buffer.compute_discounted_return([], gamma=0.99) + assert result == 0.0 + + +class 
class TestEpisodeLifecycle:
    """Exercise the start / append / end transitions of a single episode."""

    @pytest.mark.asyncio
    async def test_start_episode(self):
        """A new episode gets a UUID id and is counted as in progress."""
        import uuid

        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        # uuid.UUID raises ValueError when the id is not a well-formed UUID.
        uuid.UUID(episode_id)
        stats = await buffer.get_stats()
        assert stats["total_episodes"] == 1
        assert stats["in_progress"] == 1

    @pytest.mark.asyncio
    async def test_append_observation_action_reward(self):
        """Each append adds exactly one step to the total."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        await buffer.append_observation_action_reward(episode_id, observation=[1, 2, 3], action=0, reward=1.0)
        stats = await buffer.get_stats()
        assert stats["total_steps"] == 1

        await buffer.append_observation_action_reward(episode_id, observation=[4, 5, 6], action=1, reward=2.0)
        stats = await buffer.get_stats()
        assert stats["total_steps"] == 2

    @pytest.mark.asyncio
    async def test_append_to_nonexistent_episode(self):
        """Appending to an unknown episode id raises ValueError."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()

        with pytest.raises(ValueError, match="not found"):
            await buffer.append_observation_action_reward("nonexistent", observation=[1], action=0, reward=0.0)

    @pytest.mark.asyncio
    async def test_end_episode_terminal(self):
        """Ending an episode moves it from in-progress to completed."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        await buffer.append_observation_action_reward(episode_id, observation=[1], action=0, reward=1.0)
        await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[2])

        stats = await buffer.get_stats()
        assert stats["completed"] == 1
        assert stats["in_progress"] == 0

    @pytest.mark.asyncio
    async def test_end_episode_requires_final_observation(self):
        """Ending without the final observation raises and leaves the episode open."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        await buffer.append_observation_action_reward(episode_id, observation=[1], action=0, reward=1.0)

        with pytest.raises(ValueError, match="requires final_observation"):
            await buffer.end_episode(episode_id, status="COMPLETED")

        stats = await buffer.get_stats()
        assert stats["completed"] == 0
        assert stats["in_progress"] == 1

    @pytest.mark.asyncio
    async def test_end_episode_prevents_further_appends(self):
        """A finished episode rejects further appends."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        await buffer.append_observation_action_reward(episode_id, observation=[1], action=0, reward=1.0)
        await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[2])

        with pytest.raises(ValueError, match="Cannot append to finished episode"):
            await buffer.append_observation_action_reward(episode_id, observation=[3], action=1, reward=2.0)

    @pytest.mark.asyncio
    async def test_end_episode_already_finished(self):
        """Ending an episode twice raises ValueError."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        await buffer.append_observation_action_reward(episode_id, observation=[1], action=0, reward=1.0)
        await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[2])

        with pytest.raises(ValueError, match="already finished"):
            await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[3])

    @pytest.mark.asyncio
    async def test_end_episode_with_in_progress_status(self):
        """IN_PROGRESS is not accepted as a final status."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()
        episode_id = await buffer.start_episode(agent_id="agent_0")

        with pytest.raises(ValueError, match="Cannot end episode with status IN_PROGRESS"):
            await buffer.end_episode(episode_id, status="IN_PROGRESS")
class TestUniformSampling:
    """Check that sampling is uniform over steps rather than over episodes."""

    @pytest.mark.asyncio
    async def test_uniform_over_steps_not_episodes(self):
        """Longer episodes contribute proportionally more samples.

        One batch cannot show this: sampling is without replacement, so any
        batch_size at or above the number of stored steps returns every step
        exactly once. Single-sample draws are repeated instead.
        """
        random.seed(0)
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()

        # Episode 1: 10 steps
        ep1 = await buffer.start_episode(agent_id="agent_0")
        for t in range(10):
            await buffer.append_observation_action_reward(ep1, observation={"ep": 1, "t": t}, action=t, reward=1.0)
        await buffer.end_episode(ep1, status="COMPLETED", final_observation={"ep": 1, "t": 10})

        # Episode 2: 30 steps (3x longer)
        ep2 = await buffer.start_episode(agent_id="agent_1")
        for t in range(30):
            await buffer.append_observation_action_reward(ep2, observation={"ep": 2, "t": t}, action=t, reward=1.0)
        await buffer.end_episode(ep2, status="COMPLETED", final_observation={"ep": 2, "t": 30})

        num_draws = 2000
        ep1_count = 0
        ep2_count = 0
        for _ in range(num_draws):
            (sample,) = await buffer.sample_n_step(batch_size=1, n=1, gamma=0.99)
            if sample.episode_id == ep1:
                ep1_count += 1
            elif sample.episode_id == ep2:
                ep2_count += 1

        assert ep1_count + ep2_count == num_draws

        # Expect a ratio close to 3 (episode 2 holds 3x the steps).
        ratio = ep2_count / ep1_count if ep1_count > 0 else 0
        assert 2.0 < ratio < 4.0, f"Expected ratio ~3, got {ratio}"

    @pytest.mark.asyncio
    async def test_all_steps_have_equal_probability(self):
        """An exhaustive batch returns every stored step exactly once."""
        buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer()

        # Create 3 episodes with 10 steps each
        episode_ids = []
        for i in range(3):
            ep = await buffer.start_episode(agent_id=f"agent_{i}")
            episode_ids.append(ep)
            for t in range(10):
                await buffer.append_observation_action_reward(ep, observation=[i, t], action=t, reward=1.0)
            await buffer.end_episode(ep, status="COMPLETED", final_observation=[i, 10])

        # Sample exhaustively (all 30 steps)
        samples = await buffer.sample_n_step(batch_size=30, n=1, gamma=0.99)

        assert len(samples) == 30
        # No (episode, step) pair may appear twice.
        assert len({(s.episode_id, s.start_step) for s in samples}) == 30

        # Each episode contributes its 10 steps.
        for ep_id in episode_ids:
            count = sum(1 for s in samples if s.episode_id == ep_id)
            assert count == 10
samples if s.start_step == 0) + assert sample_0.actual_n == 3 + assert sample_0.rewards_seq == [1.0, 2.0, 3.0] + assert sample_0.terminal # Episode ended + + @pytest.mark.asyncio + async def test_n_step_never_crosses_episode_boundary(self): + """Test that n-step sampling never crosses episode boundaries.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + + # Create two short episodes + ep1 = await buffer.start_episode(agent_id="agent_0") + for t in range(3): + await buffer.append_observation_action_reward(ep1, observation={"ep": 1, "t": t}, action=t, reward=1.0) + await buffer.end_episode(ep1, status="COMPLETED", final_observation={"ep": 1, "t": 3}) + + ep2 = await buffer.start_episode(agent_id="agent_1") + for t in range(3): + await buffer.append_observation_action_reward(ep2, observation={"ep": 2, "t": t}, action=t, reward=2.0) + await buffer.end_episode(ep2, status="COMPLETED", final_observation={"ep": 2, "t": 3}) + + # Sample with large n + samples = await buffer.sample_n_step(batch_size=10, n=10, gamma=0.9) + + # Verify no sample crosses episodes + for sample in samples: + # Check that all rewards come from the same episode + if sample.episode_id == ep1: + # All observations should have ep=1 + assert sample.obs_t["ep"] == 1 + assert sample.next_obs["ep"] == 1 + else: + assert sample.obs_t["ep"] == 2 + assert sample.next_obs["ep"] == 2 + + # No sample should exceed 3 steps (episode length) + assert sample.actual_n <= 3 + + @pytest.mark.asyncio + async def test_n_step_near_end_truncates(self): + """Test n-step sampling near episode end truncates properly.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + episode_id = await buffer.start_episode(agent_id="agent_0") + + # Create episode with 5 steps + for t in range(5): + await buffer.append_observation_action_reward(episode_id, observation=[t], action=t, reward=float(t + 1)) + await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[5]) + + # Sample starting 
from t=3 with n=3 + # Should only get 2 steps (t=3, t=4) because episode has only 5 steps total + samples = await buffer.sample_n_step(batch_size=10, n=3, gamma=0.9) + + # Find sample starting at t=3 + sample_3 = next(s for s in samples if s.start_step == 3) + assert sample_3.actual_n == 2 + assert sample_3.rewards_seq == [4.0, 5.0] + assert sample_3.terminal + + @pytest.mark.asyncio + async def test_n_step_discount_powers(self): + """Test that discount powers are correctly computed.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + episode_id = await buffer.start_episode(agent_id="agent_0") + + for t in range(5): + await buffer.append_observation_action_reward(episode_id, observation=[t], action=t, reward=1.0) + await buffer.end_episode(episode_id, status="COMPLETED", final_observation=[5]) + + gamma = 0.9 + samples = await buffer.sample_n_step(batch_size=1, n=4, gamma=gamma) + sample = samples[0] + + expected_powers = [gamma**k for k in range(sample.actual_n)] + assert sample.discount_powers == expected_powers + + @pytest.mark.asyncio + async def test_n_step_terminal_and_next_discount(self): + """Test that terminal flag and next_discount are set correctly.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + + # Completed episode with 5 steps + ep1 = await buffer.start_episode(agent_id="agent_0") + for t in range(5): + await buffer.append_observation_action_reward(ep1, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep1, status="COMPLETED", final_observation=[5]) + + # Sample with n=2 + gamma = 0.9 + samples = await buffer.sample_n_step(batch_size=10, n=2, gamma=gamma) + + # Find sample at t=0 (not terminal, should have next_discount=gamma^2) + sample_0 = next(s for s in samples if s.start_step == 0) + assert not sample_0.terminal + assert abs(sample_0.next_discount - gamma**2) < 1e-6 + + # Find sample at t=2 (not terminal, should have next_discount=gamma^2) + sample_2 = next(s for s in samples if s.start_step == 
2) + assert not sample_2.terminal + assert abs(sample_2.next_discount - gamma**2) < 1e-6 + + # Find sample at t=4 (terminal with actual_n=1, should have next_discount=0) + sample_4 = next(s for s in samples if s.start_step == 4) + assert sample_4.terminal + assert sample_4.actual_n == 1 + assert sample_4.next_discount == 0.0 + + +class TestCapacityAndEviction: + """Test capacity management and eviction behavior.""" + + @pytest.mark.asyncio + async def test_max_episodes_eviction(self): + """Test that max_episodes limit triggers eviction.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer(max_episodes=3) + + # Add 3 episodes (at capacity) + for i in range(3): + ep = await buffer.start_episode(agent_id=f"agent_{i}") + await buffer.append_observation_action_reward(ep, observation=[i], action=i, reward=1.0) + await buffer.end_episode(ep, status="COMPLETED", final_observation=[i + 1]) + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 3 + + # Add 4th episode, should evict oldest + ep4 = await buffer.start_episode(agent_id="agent_3") + await buffer.append_observation_action_reward(ep4, observation=[3], action=3, reward=1.0) + await buffer.end_episode(ep4, status="COMPLETED", final_observation=[4]) + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 3 # Still at max + + @pytest.mark.asyncio + async def test_max_steps_eviction(self): + """Test that max_steps limit triggers eviction.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer(max_steps=10) + + # Add episodes totaling 10 steps + ep1 = await buffer.start_episode(agent_id="agent_0") + for t in range(5): + await buffer.append_observation_action_reward(ep1, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep1, status="COMPLETED", final_observation=[5]) + + ep2 = await buffer.start_episode(agent_id="agent_1") + for t in range(5): + await buffer.append_observation_action_reward(ep2, observation=[t], action=t, reward=1.0) + await 
buffer.end_episode(ep2, status="COMPLETED", final_observation=[5]) + + stats = await buffer.get_stats() + assert stats["total_steps"] == 10 + + # Add more steps, should trigger eviction + ep3 = await buffer.start_episode(agent_id="agent_2") + for t in range(3): + await buffer.append_observation_action_reward(ep3, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep3, status="COMPLETED", final_observation=[3]) + + stats = await buffer.get_stats() + # Should have evicted ep1, keeping ep2 and ep3 + assert stats["total_steps"] <= 10 + + @pytest.mark.asyncio + async def test_eviction_prefers_finished_episodes(self): + """Test that eviction prefers finished episodes over in-progress.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer(max_episodes=3) + + # Add 2 finished episodes + for i in range(2): + ep = await buffer.start_episode(agent_id=f"agent_{i}") + await buffer.append_observation_action_reward(ep, observation=[i], action=i, reward=1.0) + await buffer.end_episode(ep, status="COMPLETED", final_observation=[i + 1]) + + # Add 1 in-progress episode + ep_in_progress = await buffer.start_episode(agent_id="agent_in_progress") + await buffer.append_observation_action_reward(ep_in_progress, observation=[99], action=99, reward=1.0) + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 3 + assert stats["in_progress"] == 1 + assert stats["completed"] == 2 + + # Add another episode, should evict oldest finished, not in-progress + ep_new = await buffer.start_episode(agent_id="agent_new") + await buffer.append_observation_action_reward(ep_new, observation=[100], action=100, reward=1.0) + await buffer.end_episode(ep_new, status="COMPLETED", final_observation=[101]) + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 3 + assert stats["in_progress"] == 1 # In-progress episode should still be there + + @pytest.mark.asyncio + async def test_eviction_updates_sampling_counts(self): + """Test that eviction correctly 
updates total_steps for sampling.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer(max_episodes=2) + + # Add 2 episodes + ep1 = await buffer.start_episode(agent_id="agent_0") + for t in range(10): + await buffer.append_observation_action_reward(ep1, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep1, status="COMPLETED", final_observation=[10]) + + ep2 = await buffer.start_episode(agent_id="agent_1") + for t in range(5): + await buffer.append_observation_action_reward(ep2, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep2, status="COMPLETED", final_observation=[5]) + + stats = await buffer.get_stats() + assert stats["total_steps"] == 15 + + # Add 3rd episode, should evict ep1 + ep3 = await buffer.start_episode(agent_id="agent_2") + for t in range(7): + await buffer.append_observation_action_reward(ep3, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep3, status="COMPLETED", final_observation=[7]) + + stats = await buffer.get_stats() + # Should have ep2 (5 steps) + ep3 (7 steps) = 12 steps + assert stats["total_steps"] == 12 + + # Sampling should still work correctly + samples = await buffer.sample_n_step(batch_size=12, n=1, gamma=0.9) + assert len(samples) == 12 + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + @pytest.mark.asyncio + async def test_sample_empty_buffer(self): + """Test sampling from an empty buffer returns empty list.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + samples = await buffer.sample_n_step(batch_size=10, n=3, gamma=0.9) + assert samples == [] + + @pytest.mark.asyncio + async def test_sample_with_only_empty_episodes(self): + """Test sampling when episodes have no steps.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + await buffer.start_episode(agent_id="agent_0") + + samples = await buffer.sample_n_step(batch_size=10, n=3, gamma=0.9) + assert samples == [] + + @pytest.mark.asyncio + async def 
test_sample_n_less_than_one(self): + """Test that n < 1 raises ValueError.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + with pytest.raises(ValueError, match="n must be >= 1"): + await buffer.sample_n_step(batch_size=10, n=0, gamma=0.9) + + @pytest.mark.asyncio + async def test_sample_invalid_gamma(self): + """Test that invalid gamma raises ValueError.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + with pytest.raises(ValueError, match="gamma must be in"): + await buffer.sample_n_step(batch_size=10, n=3, gamma=0.0) + with pytest.raises(ValueError, match="gamma must be in"): + await buffer.sample_n_step(batch_size=10, n=3, gamma=1.5) + + @pytest.mark.asyncio + async def test_clear_buffer(self): + """Test clearing the buffer.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + + # Add some episodes + for i in range(3): + ep = await buffer.start_episode(agent_id=f"agent_{i}") + await buffer.append_observation_action_reward(ep, observation=[i], action=i, reward=1.0) + await buffer.end_episode(ep, status="COMPLETED", final_observation=[i + 1]) + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 3 + + # Clear + await buffer.clear() + + stats = await buffer.get_stats() + assert stats["total_episodes"] == 0 + assert stats["total_steps"] == 0 + + @pytest.mark.asyncio + async def test_sample_batch_size_larger_than_available(self): + """Test that sampling returns fewer samples if not enough data.""" + buffer = ares.contrib.rl.replay_buffer.EpisodeReplayBuffer() + + ep = await buffer.start_episode(agent_id="agent_0") + for t in range(3): + await buffer.append_observation_action_reward(ep, observation=[t], action=t, reward=1.0) + await buffer.end_episode(ep, status="COMPLETED", final_observation=[3]) + + # Request 100 samples but only 3 available + samples = await buffer.sample_n_step(batch_size=100, n=1, gamma=0.9) + assert len(samples) == 3 diff --git a/tests/contrib/__init__.py 
b/tests/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/contrib/rl/__init__.py b/tests/contrib/rl/__init__.py new file mode 100644 index 0000000..e69de29