diff --git a/CHANGELOG.md b/CHANGELOG.md index 017a7bf..850c4e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [0.3.0] - 2026-03-05 + +### Added +- **`IterationState.learner_id`**: New field (`int | str | None`, default `None`) on `IterationState` + identifying which parallel learner produced a given state. Integer index for + `ParallelActiveLearner` and `ParallelReinforcementLearner`; learner name string for + `ParallelUQLearner`. + +### Changed +- **Unified async-iterator API for parallel learners**: `ParallelActiveLearner.start()`, + `ParallelReinforcementLearner.start()`, and `ParallelUQLearner.start()` now return + `AsyncIterator[IterationState]` instead of blocking until all learners finish and returning + `list[Any]`. States stream in real time as each parallel learner completes an iteration, + using the same `async for state in learner.start():` interface as `SequentialActiveLearner`. +- **Shared `_stream_parallel` helper**: The internal `asyncio.Queue`-based fan-in pattern + is extracted into a single module-level async generator in `rose/learner.py`, eliminating + identical code that was previously duplicated across all three parallel learner classes. + +### Deprecated +- **`ParallelActiveLearner.teach()`**, **`ParallelReinforcementLearner.learn()`**, and + **`ParallelUQLearner.teach()`** still work but now internally iterate `start()` and + collect final states into a list. Migrate to `async for state in learner.start():`. + +--- + ## [0.2.0] - 2026-02-27 ### Added diff --git a/docs/user-guide/parallel_learners_docs.md b/docs/user-guide/parallel_learners_docs.md index 10f9ebe..b5cfe5f 100644 --- a/docs/user-guide/parallel_learners_docs.md +++ b/docs/user-guide/parallel_learners_docs.md @@ -1,11 +1,11 @@ # Learners with Parameterization Tutorial -This tutorial demonstrates how to configure and run multiple learning pipelines concurrently using `ParallelActiveLearner`. You’ll learn how to: +This tutorial demonstrates how to configure and run multiple learning pipelines concurrently using `ParallelActiveLearner`. You'll learn how to: - Set up parallel workflows - Configure each learner independently - Use per-iteration and adaptive configurations -- Run learners concurrently with individual stop criteria +- Stream per-learner states in real time as each iteration completes --- @@ -21,12 +21,31 @@ This approach can be applied for both Active and Reinforcement learners (Sequent - **Learner 1**: Per-iteration config — specific checkpoints for tuning - **Learner 2**: Static config — constant settings throughout - All learners run **concurrently and independently** +- States from all learners are **streamed in real time** via `async for` + +--- + +## How the API Works + +`ParallelActiveLearner.start()` returns an **async iterator** that yields an `IterationState` +each time any parallel learner completes an iteration. States arrive in completion order — not +grouped by learner — so you react to results as they happen. + +Each `IterationState` carries a `learner_id` (integer index) identifying which learner produced it: + +```python +async for state in acl.start(parallel_learners=3, max_iter=10): + print(f"Learner {state.learner_id} | iter {state.iteration} | MSE {state.metric_value:.4f}") +``` + +This is the same interface used by `SequentialActiveLearner`, so code that consumes +`IterationState` works identically for both sequential and parallel learners. 
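+
+Because states stream as they arrive, the consuming loop can also end the whole run
+early. The sketch below is a minimal example, assuming the `stop()` method inherited
+from the `Learner` base class (the same mechanism this change's stop test uses to
+halt a parallel run from inside the loop); the `0.01` target is illustrative:
+
+```python
+async for state in acl.start(parallel_learners=3, max_iter=10):
+    # Halt every remaining learner once any one reaches the target metric
+    if state.metric_value is not None and state.metric_value < 0.01:
+        print(f"Learner {state.learner_id} hit the target at iter {state.iteration}")
+        acl.stop()
+```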
--- ## Configuration Modes -### 🧠 Adaptive Configuration +### Adaptive Configuration - Receives iteration number `i` - Labeled data: `100 + i*50` @@ -34,7 +53,7 @@ This approach can be applied for both Active and Reinforcement learners (Sequent - Learning rate: `0.01 * (0.9^i)` - Batch size increases gradually, capped at 64 -### 🔁 Per-Iteration Configuration +### Per-Iteration Configuration - Iteration keys (e.g., `0`, `5`, `10`) set exact checkpoints - `-1` is the fallback/default config @@ -73,7 +92,7 @@ acl = ParallelActiveLearner(asyncflow) code_path = f'{sys.executable} {os.getcwd()}' ``` -### 1. Define Workflow Tasks +### 2. Define Workflow Tasks ```python @acl.simulation_task async def simulation(*args, **kwargs): @@ -100,7 +119,7 @@ async def check_mse(*args, **kwargs): ### Approach 1: Static Configuration ```python -results = await acl.start( +async for state in acl.start( parallel_learners=2, max_iter=10, learner_configs=[ @@ -113,12 +132,13 @@ results = await acl.start( training=TaskConfig(kwargs={"--learning_rate": "0.005"}) ) ] -) +): + print(f"[Learner {state.learner_id}] iter={state.iteration} | MSE={state.metric_value}") ``` ### Approach 2: Per-Iteration Configuration ```python -results = await acl.start( +async for state in acl.start( parallel_learners=3, max_iter=15, learner_configs=[ @@ -140,7 +160,8 @@ results = await acl.start( ), None # Default to base task behavior ] -) +): + print(f"[Learner {state.learner_id}] iter={state.iteration} | MSE={state.metric_value}") ``` !!! tip "Per-Iteration Config Keys" @@ -166,7 +187,7 @@ adaptive_train = acl.create_adaptive_schedule('training', } }) -results = await acl.start( +async for state in acl.start( parallel_learners=2, max_iter=20, learner_configs=[ @@ -176,7 +197,8 @@ results = await acl.start( training=TaskConfig(kwargs={"--learning_rate": "0.005"}) ) ] -) +): + print(f"[Learner {state.learner_id}] iter={state.iteration} | MSE={state.metric_value}") ``` ### Full Example: All Approaches Combined @@ -190,7 +212,10 @@ adaptive_sim = acl.create_adaptive_schedule('simulation', } }) -results = await acl.start( +# Collect the final state per learner if needed +final_states = {} + +async for state in acl.start( parallel_learners=3, max_iter=15, learner_configs=[ @@ -206,7 +231,9 @@ results = await acl.start( simulation=TaskConfig(kwargs={"--n_labeled": "300", "--n_features": 4}) ) ] -) +): + print(f"[Learner {state.learner_id}] iter={state.iteration} | MSE={state.metric_value}") + final_states[state.learner_id] = state # keep last state per learner await acl.shutdown() ``` @@ -214,7 +241,11 @@ await acl.shutdown() ### Execution Details !!! note "Concurrent Execution" -All learners run in parallel and independently. The workflow completes when all learners either reach max_iter or meet their stop criterion. +All learners run in parallel and independently. States are yielded in arrival order — whichever learner finishes an iteration first yields next. The loop completes when all learners either reach `max_iter` or meet their stop criterion. + +!!! note "Identifying the Source Learner" +Each `IterationState` has a `learner_id` field (integer index, 0-based) so you can distinguish +which learner produced each state inside the loop. !!! warning "Stop Criteria" Each learner evaluates its own stop condition. One learner stopping does not affect others. 
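+
+Since states from different learners interleave, downstream analysis usually wants
+them re-grouped. The sketch below rebuilds a per-learner metric history from the
+stream using only the `learner_id` and `metric_value` fields described above (`acl`
+is assumed to be configured as in the earlier examples):
+
+```python
+from collections import defaultdict
+
+metric_history: dict[int, list[float]] = defaultdict(list)
+
+async for state in acl.start(parallel_learners=2, max_iter=10):
+    if state.metric_value is not None:
+        metric_history[state.learner_id].append(state.metric_value)
+
+for learner_id, history in sorted(metric_history.items()):
+    print(f"Learner {learner_id}: {len(history)} iterations, best MSE {min(history):.4f}")
+```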
@@ -248,10 +279,10 @@ adaptive_config = acl.create_adaptive_schedule('training', lr_decay) ## Next Steps -- 🧪 Try different active learning algorithms per learner +- Try different active learning algorithms per learner -- 🎯 Use per-iteration configs to design curriculum learning +- Use per-iteration configs to design curriculum learning -- 📊 Run parameter sweeps +- Run parameter sweeps across acquisition functions or model architectures -- 🚀 Scale learners to match compute resources +- Scale learners to match compute resources diff --git a/examples/active_learn/parallel/run_me_per_learner_config.py b/examples/active_learn/parallel/run_me_per_learner_config.py index ea448b2..a91ceea 100644 --- a/examples/active_learn/parallel/run_me_per_learner_config.py +++ b/examples/active_learn/parallel/run_me_per_learner_config.py @@ -36,7 +36,7 @@ async def active_learn(*args, task_description={"shell": True}, **kwargs): return f"{code_path}/active.py" # Defining the stop criterion with a metric (MSE in this case) - @al.as_stop_criterion(metric_name=MEAN_SQUARED_ERROR_MSE, threshold=0.1) + @al.as_stop_criterion(metric_name=MEAN_SQUARED_ERROR_MSE, threshold=0.01) async def check_mse(*args, task_description={"shell": True}, **kwargs): return f"{code_path}/check_mse.py" @@ -51,14 +51,15 @@ async def check_mse(*args, task_description={"shell": True}, **kwargs): ) # Start the parallel active learning process - results = await al.start( + async for state in al.start( + max_iter=5, parallel_learners=2, learner_configs=[ LearnerConfig(simulation=adaptive_sim), LearnerConfig(simulation=TaskConfig(kwargs={"--n_labeled": "300", "--n_features": 4})), ], - ) - print(f"Parallel learning completed. Results: {results}") + ): + print(f"Learner {state.learner_id}, iteration {state.iteration}: {state.metric_value}") await al.shutdown() diff --git a/examples/active_learn/parallel/run_me_per_learner_per_iter_config.py b/examples/active_learn/parallel/run_me_per_learner_per_iter_config.py index bbfa887..a99e6b5 100644 --- a/examples/active_learn/parallel/run_me_per_learner_per_iter_config.py +++ b/examples/active_learn/parallel/run_me_per_learner_per_iter_config.py @@ -3,7 +3,7 @@ import sys from radical.asyncflow import WorkflowEngine -from rhapsody.backends import RadicalExecutionBackend +from rhapsody.backends import ConcurrentExecutionBackend from rose import LearnerConfig, TaskConfig from rose.al import ParallelActiveLearner @@ -11,7 +11,7 @@ async def run_al_parallel(): - engine = await RadicalExecutionBackend({"resource": "local.localhost"}) + engine = await ConcurrentExecutionBackend() asyncflow = await WorkflowEngine.create(engine) al = ParallelActiveLearner(asyncflow) @@ -41,7 +41,7 @@ async def check_mse(*args, task_description={"shell": True}, **kwargs): return f"{code_path}/check_mse.py" # Start the parallel active learning process with custom configs - results = await al.start( + async for state in al.start( parallel_learners=3, learner_configs=[ # Learner 0: Same config for all iterations (your current pattern) @@ -57,8 +57,8 @@ async def check_mse(*args, task_description={"shell": True}, **kwargs): ), None, ], - ) - print(f"Parallel learning completed. 
Results: {results}") + ): + print(f"Learner {state.learner_id}, iteration {state.iteration}: {state.metric_value}") await engine.shutdown() diff --git a/examples/active_learn/parallel/run_me_with_dynamic_config.py b/examples/active_learn/parallel/run_me_with_dynamic_config.py index a647c8f..3392d87 100644 --- a/examples/active_learn/parallel/run_me_with_dynamic_config.py +++ b/examples/active_learn/parallel/run_me_with_dynamic_config.py @@ -52,15 +52,15 @@ async def check_mse(*args, **kwargs): ) # Start the parallel active learning process - results = await al.start( + async for state in al.start( max_iter=1, parallel_learners=2, learner_configs=[ LearnerConfig(simulation=adaptive_sim), LearnerConfig(simulation=TaskConfig(kwargs={"--n_labeled": "300", "--n_features": 2})), ], - ) - print(f"Parallel learning completed. Results: {results}") + ): + print(f"Learner {state.learner_id}, iteration {state.iteration}: {state.metric_value}") await al.shutdown() diff --git a/examples/active_learn/uq/run_me.py b/examples/active_learn/uq/run_me.py index 2061ab2..ad3025f 100644 --- a/examples/active_learn/uq/run_me.py +++ b/examples/active_learn/uq/run_me.py @@ -172,17 +172,20 @@ async def check_uq(*args, **kwargs): ) # Start the UQ active learning process - results = await learner.start( + final_states = {} + async for state in learner.start( learner_names=PIPELINES, model_names=MODELS, learner_configs=learner_configs, max_iter=ITERATIONS, num_predictions=NUM_PREDICTION, - ) + ): + print(f"Learner {state.learner_id}, iteration {state.iteration}: {state.metric_value}") + final_states[state.learner_id] = state print("Learning process is done.") - print(f"Results: {results}") + results = {lid: s.to_dict() for lid, s in final_states.items()} with open(Path(os.getcwd(), "UQ_training_results.json"), "w") as f: json.dump(results, f, indent=4) diff --git a/pyproject.toml b/pyproject.toml index 0e8fd52..81bdc1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ROSE" -version = "0.2.0" +version = "0.3.0" description = "Toolkit to express and execute ML surrogate building workflows on HPC" authors = [ diff --git a/rose/al/active_learner.py b/rose/al/active_learner.py index 288d64f..ed35eb8 100644 --- a/rose/al/active_learner.py +++ b/rose/al/active_learner.py @@ -1,12 +1,13 @@ import asyncio +import dataclasses import itertools import warnings -from collections.abc import AsyncIterator, Coroutine, Iterator +from collections.abc import AsyncIterator, Iterator from typing import Any from radical.asyncflow import WorkflowEngine -from ..learner import IterationState, Learner, LearnerConfig +from ..learner import IterationState, Learner, LearnerConfig, _stream_parallel class SequentialActiveLearner(Learner): @@ -420,12 +421,15 @@ async def start( skip_pre_loop: bool = False, skip_simulation_step: bool = False, learner_configs: list[LearnerConfig | None] | None = None, - ) -> list[Any]: + ) -> AsyncIterator[IterationState]: """Run parallel active learning by launching multiple SequentialActiveLearners. Orchestrates multiple SequentialActiveLearner instances to run concurrently, - each with potentially different configurations. All learners run - independently and their results are collected when all have completed. + each with potentially different configurations. States are streamed in real + time as each learner completes an iteration — use ``async for`` to consume them. 
+ + Each yielded ``IterationState`` includes a ``learner_id`` field (int) indicating + which parallel learner produced it. Args: parallel_learners: Number of parallel learners to run concurrently. @@ -440,13 +444,19 @@ async def start( skip_simulation_step: if True, all learners will skip the simulation step and the learner will consider a simulation pool already exist. - Returns: - list containing the final IterationState from each learner, in the - same order as the learners were launched. + Yields: + IterationState for each iteration of each learner, in arrival order. + Each state has ``learner_id`` set to the integer index of the learner. Raises: ValueError: If parallel_learners < 2 (use SequentialActiveLearner instead). ValueError: If learner_configs length doesn't match parallel_learners. + Exception: Re-raises any exception from a learner after all learners finish. + + Example:: + + async for state in learner.start(parallel_learners=3, max_iter=10): + print(f"Learner {state.learner_id}, iter {state.iteration}: {state.metric_value}") """ if parallel_learners < 2: raise ValueError("For single learner, use SequentialActiveLearner") @@ -456,61 +466,39 @@ async def start( if len(learner_configs) != parallel_learners: raise ValueError("learner_configs length must match parallel_learners") - async def active_learner_workflow(learner_id: int) -> Any: - """Run a single SequentialActiveLearner. - - Internal async function that manages the lifecycle of a single - SequentialActiveLearner within the parallel learning context. - - Args: - learner_id: Unique identifier for this learner instance. - Returns: - The final IterationState from the sequential learner. - - Raises: - Exception: Re-raises any exception from the sequential learner - with additional context about which learner failed. - """ - try: - # Create and configure the sequential learner - sequential_learner: SequentialActiveLearner = self._create_sequential_learner( - learner_id, learner_configs[learner_id] - ) - - # Convert parallel config to sequential config - sequential_config: LearnerConfig | None = self._convert_to_sequential_config( - learner_configs[learner_id] - ) - - # Run the sequential learner by iterating through start() - final_state = None - async for state in sequential_learner.start( - max_iter=max_iter, - skip_pre_loop=skip_pre_loop, - skip_simulation_step=skip_simulation_step, - initial_config=sequential_config, - ): - final_state = state - if self.is_stopped: - sequential_learner.stop() - - # book keep the iteration value from each learner - self.metric_values_per_iteration[f"learner-{learner_id}"] = ( - sequential_learner.metric_values_per_iteration - ) - - return final_state - except Exception as e: - print(f"ActiveLearner-{learner_id}] failed with error: {e}") - raise - print(f"Starting Parallel Active Learning with {parallel_learners} learners") - # Submit all learners asynchronously - learners: list[Coroutine] = [active_learner_workflow(i) for i in range(parallel_learners)] - - # Wait for all learners to complete and collect results - return await asyncio.gather(*learners) + # Factory required: plain closure in a loop would capture the same variable reference. 
+ def make_run_fn(learner_id: int): + async def run_learner(queue: asyncio.Queue) -> None: + try: + sequential_learner: SequentialActiveLearner = self._create_sequential_learner( + learner_id, learner_configs[learner_id] + ) + async for state in sequential_learner.start( + max_iter=max_iter, + skip_pre_loop=skip_pre_loop, + skip_simulation_step=skip_simulation_step, + initial_config=learner_configs[learner_id], + ): + if self.is_stopped: + sequential_learner.stop() + await queue.put( + ("state", dataclasses.replace(state, learner_id=learner_id)) + ) + self.metric_values_per_iteration[f"learner-{learner_id}"] = ( + sequential_learner.metric_values_per_iteration + ) + except Exception as e: + print(f"[ActiveLearner-{learner_id}] failed with error: {e}") + await queue.put(("error", e)) + finally: + await queue.put(("done", None)) + + return run_learner + + async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]): + yield state async def teach( self, @@ -541,10 +529,13 @@ async def teach( DeprecationWarning, stacklevel=2, ) - return await self.start( + final_states: dict[int, IterationState | None] = {} + async for state in self.start( parallel_learners=parallel_learners, max_iter=max_iter, skip_pre_loop=skip_pre_loop, skip_simulation_step=skip_simulation_step, learner_configs=learner_configs, - ) + ): + final_states[state.learner_id] = state + return [final_states.get(i) for i in range(parallel_learners)] diff --git a/rose/learner.py b/rose/learner.py index 6fff146..963b8c5 100644 --- a/rose/learner.py +++ b/rose/learner.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from dataclasses import dataclass, field from functools import wraps from typing import Any, Optional @@ -51,6 +51,7 @@ class IterationState: metric_history: list[float] = field(default_factory=list) should_stop: bool = False current_config: Optional["LearnerConfig"] = None + learner_id: int | str | None = None # All domain-specific state goes here state: dict[str, Any] = field(default_factory=dict) @@ -97,12 +98,60 @@ def to_dict(self) -> dict[str, Any]: "metric_threshold": self.metric_threshold, "metric_history": self.metric_history, "should_stop": self.should_stop, + "learner_id": self.learner_id, } # Merge in all state values result.update(self.state) return result +async def _stream_parallel( + run_fns: list[Callable[[asyncio.Queue], Any]], +) -> AsyncIterator[IterationState]: + """Run multiple learner coroutines in parallel and stream their IterationStates. + + Each callable in ``run_fns`` must accept an ``asyncio.Queue`` and put exactly + three kinds of tuples into it during its lifetime: + + * ``('state', IterationState)`` — for each iteration state to stream + * ``('error', Exception)`` — if the learner raises (before ``'done'``) + * ``('done', None)`` — exactly once, in a ``finally`` block, to signal completion + + This function manages queue creation, task scheduling, result streaming, and + exception propagation so that parallel learner implementations only need to + provide the learner-specific ``run_fn`` logic. + + Args: + run_fns: List of callables, one per parallel learner. Each callable takes + a shared ``asyncio.Queue`` and returns an awaitable coroutine. + + Yields: + IterationState objects in arrival order across all parallel learners. + + Raises: + Exception: The first exception raised by any learner, after all learners + have finished. 
+ """ + queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue() + tasks = [asyncio.create_task(fn(queue)) for fn in run_fns] + + completed = 0 + first_error: Exception | None = None + while completed < len(run_fns): + kind, value = await queue.get() + if kind == "done": + completed += 1 + elif kind == "state": + yield value + elif kind == "error": + if first_error is None: + first_error = value + + await asyncio.gather(*tasks, return_exceptions=True) + if first_error is not None: + raise first_error + + class TaskConfig(BaseModel): """Configuration for a single task. diff --git a/rose/rl/reinforcement_learner.py b/rose/rl/reinforcement_learner.py index e66cab9..e315dbb 100644 --- a/rose/rl/reinforcement_learner.py +++ b/rose/rl/reinforcement_learner.py @@ -1,14 +1,15 @@ import asyncio +import dataclasses import itertools import warnings -from collections.abc import AsyncIterator, Callable, Coroutine, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from functools import wraps from typing import Any import typeguard from radical.asyncflow import WorkflowEngine -from rose.learner import IterationState, Learner, LearnerConfig +from rose.learner import IterationState, Learner, LearnerConfig, _stream_parallel class ReinforcementLearner(Learner): @@ -836,14 +837,17 @@ async def start( max_iter: int = 0, skip_pre_loop: bool = False, learner_configs: list[LearnerConfig | None] | None = None, - ) -> list[Any]: + ) -> AsyncIterator[IterationState]: """Run parallel reinforcement learning by launching multiple SequentialReinforcementLearners. Orchestrates multiple SequentialReinforcementLearner instances to - run concurrently, each with potentially different configurations. All - learners run independently and their results are collected when all - have completed. + run concurrently, each with potentially different configurations. States + are streamed in real time as each learner completes an iteration — use + ``async for`` to consume them. + + Each yielded ``IterationState`` includes a ``learner_id`` field (int) + indicating which parallel learner produced it. Args: parallel_learners: Number of parallel learners to run concurrently. @@ -857,15 +861,21 @@ async def start( If None, all learners use default configuration. Length must match parallel_learners if provided. - Returns: - list containing the final IterationState from each learner, - in the same order as the learners were launched. + Yields: + IterationState for each iteration of each learner, in arrival order. + Each state has ``learner_id`` set to the integer index of the learner. Raises: ValueError: If parallel_learners < 2. ValueError: If required base functions are not set. ValueError: If neither max_iter nor criterion_function is provided. ValueError: If learner_configs length doesn't match parallel_learners. + Exception: Re-raises any exception from a learner after all learners finish. + + Example:: + + async for state in rl.start(parallel_learners=3, max_iter=100): + print(f"Learner {state.learner_id}, iter {state.iteration}: {state.metric_value}") """ if parallel_learners < 2: raise ValueError("For single learner, use SequentialReinforcementLearner") @@ -884,60 +894,36 @@ async def start( print(f"Starting Parallel Reinforcement Learning with {parallel_learners} learners") - async def rl_learner_workflow(learner_id: int) -> Any: - """Run a single SequentialReinforcementLearner. 
- - Internal async function that manages the lifecycle of a single - SequentialReinforcementLearner within the parallel learning context. - - Args: - learner_id: Unique identifier for this learner instance. - - Returns: - The final IterationState from the sequential learner. - - Raises: - Exception: Re-raises any exception from the sequential learner - with additional context about which learner failed. - """ - try: - # Create and configure the sequential learner - sequential_learner: SequentialReinforcementLearner = ( - self._create_sequential_learner(learner_id, learner_configs[learner_id]) - ) - - # Convert parallel config to sequential config - sequential_config: LearnerConfig | None = self._convert_to_sequential_config( - learner_configs[learner_id] - ) - - # Run the sequential learner by iterating through start() - final_state = None - async for state in sequential_learner.start( - max_iter=max_iter, - skip_pre_loop=skip_pre_loop, - initial_config=sequential_config, - ): - final_state = state - # Let the learner run to completion - if self.is_stopped: - sequential_learner.stop() - - # Store metrics per learner - self.metric_values_per_iteration[f"learner-{learner_id}"] = ( - sequential_learner.metric_values_per_iteration - ) - - return final_state - except Exception as e: - print(f"RLLearner-{learner_id}] failed with error: {e}") - raise - - # Submit all learners asynchronously - futures: list[Coroutine] = [rl_learner_workflow(i) for i in range(parallel_learners)] - - # Wait for all learners to complete and collect results - return await asyncio.gather(*futures) + # Factory required: plain closure in a loop would capture the same variable reference. + def make_run_fn(learner_id: int): + async def run_learner(queue: asyncio.Queue) -> None: + try: + sequential_learner: SequentialReinforcementLearner = ( + self._create_sequential_learner(learner_id, learner_configs[learner_id]) + ) + async for state in sequential_learner.start( + max_iter=max_iter, + skip_pre_loop=skip_pre_loop, + initial_config=learner_configs[learner_id], + ): + if self.is_stopped: + sequential_learner.stop() + await queue.put( + ("state", dataclasses.replace(state, learner_id=learner_id)) + ) + self.metric_values_per_iteration[f"learner-{learner_id}"] = ( + sequential_learner.metric_values_per_iteration + ) + except Exception as e: + print(f"[RLLearner-{learner_id}] failed with error: {e}") + await queue.put(("error", e)) + finally: + await queue.put(("done", None)) + + return run_learner + + async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]): + yield state async def learn( self, @@ -966,9 +952,12 @@ async def learn( DeprecationWarning, stacklevel=2, ) - return await self.start( + final_states: dict[int, IterationState | None] = {} + async for state in self.start( parallel_learners=parallel_learners, max_iter=max_iter, skip_pre_loop=skip_pre_loop, learner_configs=learner_configs, - ) + ): + final_states[state.learner_id] = state + return [final_states.get(i) for i in range(parallel_learners)] diff --git a/rose/uq/uq_active_learner.py b/rose/uq/uq_active_learner.py index 4e357a9..41141ba 100644 --- a/rose/uq/uq_active_learner.py +++ b/rose/uq/uq_active_learner.py @@ -1,5 +1,6 @@ import asyncio import copy +import dataclasses import itertools import warnings from collections.abc import AsyncIterator, Iterator @@ -9,7 +10,7 @@ from rose.uq.uq_learner import UQLearner, UQLearnerConfig -from ..learner import IterationState, TaskConfig +from ..learner import IterationState, TaskConfig, 
_stream_parallel class SeqUQLearner(UQLearner): @@ -467,12 +468,16 @@ async def start( max_iter: int = 0, skip_pre_loop: bool = False, learner_configs: dict[str, UQLearnerConfig | None] | None = None, - ) -> list[Any]: + ) -> AsyncIterator[IterationState]: """Run parallel UQ active learning by launching multiple SeqUQLearners. Orchestrates multiple SeqUQLearner instances to run concurrently, - each with potentially different configurations. All learners run - independently and their results are collected when all have completed. + each with potentially different configurations. States are streamed in + real time as each learner completes an iteration — use ``async for`` to + consume them. + + Each yielded ``IterationState`` includes a ``learner_id`` field (str) + set to the learner's name. Args: learner_names: list of learner names to run concurrently. @@ -487,14 +492,20 @@ async def start( If provided, the length must match the number of elements in learner_names. - Returns: - list containing the final IterationState from each learner, in the - same order as the learners were launched. + Yields: + IterationState for each iteration of each learner, in arrival order. + Each state has ``learner_id`` set to the learner's name (str). Raises: ValueError: If required base functions are not set. ValueError: If neither max_iter nor criterion_function is provided. ValueError: If learner_configs length doesn't match learner_names. + Exception: Re-raises any exception from a learner after all learners finish. + + Example:: + + async for state in learner.start(learner_names=["a", "b"], model_names=[...]): + print(f"Learner {state.learner_id}, iter {state.iteration}: {state.metric_value}") """ # Validate base functions are set if ( @@ -514,65 +525,40 @@ async def start( print(f"Starting Parallel UQ Active Learning with {len(learner_names)} learners") - async def _run_sequential_learner(learner_name: str) -> Any: - """Run a single SeqUQLearner. - - Internal async function that manages the lifecycle of a single - SeqUQLearner within the parallel learning context. - - Args: - learner_name: Unique identifier for this learner instance. - - Returns: - The final IterationState from the sequential learner. - - Raises: - Exception: Re-raises any exception from the sequential learner - with additional context about which learner failed. 
- """ - try: - # Create and configure the sequential learner - sequential_learner: SeqUQLearner = self._create_sequential_learner(learner_name) - - # Convert parallel config to sequential config - sequential_config: UQLearnerConfig | None = self._convert_to_sequential_config( - learner_configs[learner_name] - ) - print(f"[Parallel-Learner-{learner_name}] Starting sequential learning") - - # Run the sequential learner by iterating through start() - final_state = None - async for state in sequential_learner.start( - model_names=model_names, - num_predictions=num_predictions, - max_iter=max_iter, - skip_pre_loop=skip_pre_loop, - learning_config=sequential_config, - ): - final_state = state - if self.is_stopped: - sequential_learner.stop() - - # Book keep the iteration value from each learner - self.metric_values_per_iteration[f"learner-{learner_name}"] = ( - sequential_learner.metric_values_per_iteration - ) - self.uncertainty_values_per_iteration[f"learner-{learner_name}"] = ( - sequential_learner.uncertainty_values_per_iteration - ) - return final_state - - except Exception as e: - print(f"[Parallel-Learner-{learner_name}] Failed with error: {e}") - raise + # Factory required: plain closure in a loop would capture the same variable reference. + def make_run_fn(learner_name: str): + async def run_learner(queue: asyncio.Queue) -> None: + try: + sequential_learner: SeqUQLearner = self._create_sequential_learner(learner_name) + print(f"[Parallel-Learner-{learner_name}] Starting sequential learning") + async for state in sequential_learner.start( + model_names=model_names, + num_predictions=num_predictions, + max_iter=max_iter, + skip_pre_loop=skip_pre_loop, + learning_config=learner_configs[learner_name], + ): + if self.is_stopped: + sequential_learner.stop() + await queue.put( + ("state", dataclasses.replace(state, learner_id=learner_name)) + ) + self.metric_values_per_iteration[f"learner-{learner_name}"] = ( + sequential_learner.metric_values_per_iteration + ) + self.uncertainty_values_per_iteration[f"learner-{learner_name}"] = ( + sequential_learner.uncertainty_values_per_iteration + ) + except Exception as e: + print(f"[Parallel-Learner-{learner_name}] Failed with error: {e}") + await queue.put(("error", e)) + finally: + await queue.put(("done", None)) - # Submit all learners asynchronously - futures: list[Any] = [ - _run_sequential_learner(learner_name) for learner_name in learner_names - ] + return run_learner - # Wait for all learners to complete and collect results - return await asyncio.gather(*futures) + async for state in _stream_parallel([make_run_fn(name) for name in learner_names]): + yield state async def teach( self, @@ -598,22 +584,21 @@ async def teach( learner_configs: Configuration for each learner. Returns: - List of results from each learner (in old format for backward - compatibility). + List of final IterationState from each learner, in learner_names order. """ warnings.warn( "teach() is deprecated and will be removed in a future version. 
Use start() instead.", DeprecationWarning, stacklevel=2, ) - - # Call start() and return the final states directly - # The old teach() returned the final states from each learner - return await self.start( + final_states: dict[str, IterationState | None] = {} + async for state in self.start( learner_names=learner_names, model_names=model_names, num_predictions=num_predictions, max_iter=max_iter, skip_pre_loop=skip_pre_loop, learner_configs=learner_configs, - ) + ): + final_states[state.learner_id] = state + return [final_states.get(name) for name in learner_names] diff --git a/tests/integration/test_run_parallel_learner.py b/tests/integration/test_run_parallel_learner.py index 970fa0a..c180ccc 100644 --- a/tests/integration/test_run_parallel_learner.py +++ b/tests/integration/test_run_parallel_learner.py @@ -32,10 +32,19 @@ async def active_learn(sim, trained_model): async def check_mse(*args): return 0.05 # Return a metric value below threshold - await learner.start(parallel_learners=5, max_iter=2) + states = [] + async for state in learner.start(parallel_learners=5, max_iter=2): + states.append(state) - scores = learner.get_metric_results() + # Each learner stops after 1 iteration because criterion (0.05) < threshold (0.1) + assert len(states) > 0 + assert all(state.learner_id is not None for state in states) + assert {state.learner_id for state in states} == {0, 1, 2, 3, 4} + scores = learner.get_metric_results() assert scores != {} + # Verify per-learner metric keys are present + for i in range(5): + assert f"learner-{i}" in scores await learner.shutdown() diff --git a/tests/integration/test_run_rl_par_learner.py b/tests/integration/test_run_rl_par_learner.py index 941b0a5..ad24401 100644 --- a/tests/integration/test_run_rl_par_learner.py +++ b/tests/integration/test_run_rl_par_learner.py @@ -31,10 +31,17 @@ async def update(data, *args): async def check_reward(val, *args): return val > 2 - await rl.learn(parallel_learners=5, max_iter=2) + states = [] + async for state in rl.start(parallel_learners=5, max_iter=2): + states.append(state) - scores = rl.get_metric_results() + assert len(states) > 0 + assert all(state.learner_id is not None for state in states) + assert {state.learner_id for state in states} == {0, 1, 2, 3, 4} + scores = rl.get_metric_results() assert scores != {} + for i in range(5): + assert f"learner-{i}" in scores await rl.shutdown() diff --git a/tests/integration/test_run_rl_seq_learner.py b/tests/integration/test_run_rl_seq_learner.py index 31c03de..f61f449 100644 --- a/tests/integration/test_run_rl_seq_learner.py +++ b/tests/integration/test_run_rl_seq_learner.py @@ -31,10 +31,14 @@ async def update(data, *args): async def check_reward(val, *args): return val > 2 - await rl.learn(max_iter=1) + states = [] + async for state in rl.start(max_iter=1): + states.append(state) - scores = rl.get_metric_results() + assert len(states) == 1 + assert states[0].iteration == 0 + scores = rl.get_metric_results() assert scores != {} await rl.shutdown() diff --git a/tests/integration/test_run_uq_learner.py b/tests/integration/test_run_uq_learner.py index 74b3104..86533e3 100644 --- a/tests/integration/test_run_uq_learner.py +++ b/tests/integration/test_run_uq_learner.py @@ -51,16 +51,19 @@ async def check_uq(*args, **kwargs): # Calculate mean or just return a simple value for testing return 0.5 - results = await learner.start( + states = [] + async for state in learner.start( learner_names=["l1", "l2"], learner_configs={"l1": None, "l2": None}, model_names=["m1"], max_iter=2, - ) + 
): + states.append(state) - # Verify we got results from both learners - assert len(results) == 2 - assert all(state is not None for state in results) + # criterion returns 0.05 < threshold 0.1, so each learner stops after 1 iteration → 2 states + assert len(states) == 2 + assert all(state is not None for state in states) + assert {state.learner_id for state in states} == {"l1", "l2"} scores = learner.get_metric_results() uq_scores = learner.get_uncertainty_results() diff --git a/tests/unit/test_learner_core.py b/tests/unit/test_learner_core.py new file mode 100644 index 0000000..1c21dc9 --- /dev/null +++ b/tests/unit/test_learner_core.py @@ -0,0 +1,459 @@ +"""Unit tests for core learner primitives: IterationState, LearnerConfig.get_task_config, +Learner base class state/callback machinery, and the _stream_parallel fan-in helper.""" + +import asyncio +import dataclasses +from unittest.mock import MagicMock + +import pytest +from radical.asyncflow import WorkflowEngine + +from rose.learner import ( + IterationState, + Learner, + LearnerConfig, + TaskConfig, + _stream_parallel, +) + +# --------------------------------------------------------------------------- +# IterationState +# --------------------------------------------------------------------------- + + +class TestIterationState: + def test_default_values(self): + state = IterationState(iteration=3) + assert state.iteration == 3 + assert state.metric_name is None + assert state.metric_value is None + assert state.metric_threshold is None + assert state.metric_history == [] + assert state.should_stop is False + assert state.current_config is None + assert state.learner_id is None + assert state.state == {} + + def test_attribute_access_to_state_dict(self): + state = IterationState(iteration=0, state={"loss": 0.5, "accuracy": 0.95}) + assert state.loss == 0.5 + assert state.accuracy == 0.95 + + def test_missing_state_key_returns_none(self): + state = IterationState(iteration=0) + assert state.nonexistent_key is None + + def test_get_with_existing_key(self): + state = IterationState(iteration=0, state={"loss": 0.5}) + assert state.get("loss") == 0.5 + + def test_get_with_missing_key_returns_default(self): + state = IterationState(iteration=0) + assert state.get("missing", "default_val") == "default_val" + + def test_to_dict_contains_all_top_level_fields(self): + state = IterationState( + iteration=2, + metric_name="mse", + metric_value=0.05, + metric_threshold=0.01, + should_stop=True, + learner_id=1, + ) + d = state.to_dict() + assert d["iteration"] == 2 + assert d["metric_name"] == "mse" + assert d["metric_value"] == 0.05 + assert d["metric_threshold"] == 0.01 + assert d["should_stop"] is True + assert d["learner_id"] == 1 + + def test_to_dict_merges_state_dict(self): + state = IterationState( + iteration=0, + state={"labeled_count": 100, "uncertainty": 0.3}, + ) + d = state.to_dict() + assert d["labeled_count"] == 100 + assert d["uncertainty"] == 0.3 + + def test_dataclasses_replace_sets_learner_id(self): + state = IterationState(iteration=5, metric_value=0.1, metric_name="mse") + replaced = dataclasses.replace(state, learner_id=3) + assert replaced.learner_id == 3 + + def test_dataclasses_replace_preserves_all_other_fields(self): + original_state = {"x": 42} + state = IterationState( + iteration=5, + metric_value=0.1, + metric_name="mse", + should_stop=False, + learner_id=None, + state=original_state, + ) + replaced = dataclasses.replace(state, learner_id=7) + assert replaced.iteration == 5 + assert replaced.metric_value == 0.1 + 
assert replaced.metric_name == "mse" + assert replaced.should_stop is False + assert replaced.state is original_state + + def test_learner_id_accepts_int(self): + state = IterationState(iteration=0, learner_id=0) + assert state.learner_id == 0 + + def test_learner_id_accepts_str(self): + state = IterationState(iteration=0, learner_id="learner-A") + assert state.learner_id == "learner-A" + + def test_learner_id_accepts_none(self): + state = IterationState(iteration=0, learner_id=None) + assert state.learner_id is None + + +# --------------------------------------------------------------------------- +# LearnerConfig.get_task_config +# --------------------------------------------------------------------------- + + +class TestLearnerConfigGetTaskConfig: + def test_returns_none_when_field_is_none(self): + config = LearnerConfig() + assert config.get_task_config("simulation", 0) is None + assert config.get_task_config("training", 5) is None + + def test_returns_taskconfig_directly_for_all_iterations(self): + tc = TaskConfig(kwargs={"--lr": "0.01"}) + config = LearnerConfig(training=tc) + assert config.get_task_config("training", 0) is tc + assert config.get_task_config("training", 5) is tc + assert config.get_task_config("training", 99) is tc + + def test_exact_iteration_match_in_dict(self): + tc_0 = TaskConfig(kwargs={"--n": "100"}) + tc_5 = TaskConfig(kwargs={"--n": "200"}) + config = LearnerConfig(simulation={0: tc_0, 5: tc_5, -1: TaskConfig()}) + assert config.get_task_config("simulation", 0) is tc_0 + assert config.get_task_config("simulation", 5) is tc_5 + + def test_falls_back_to_minus_one_key(self): + default_tc = TaskConfig(kwargs={"--n": "500"}) + config = LearnerConfig(simulation={0: TaskConfig(), -1: default_tc}) + assert config.get_task_config("simulation", 99) is default_tc + assert config.get_task_config("simulation", 1) is default_tc + + def test_returns_none_when_dict_has_no_match_and_no_default(self): + config = LearnerConfig(simulation={0: TaskConfig()}) + assert config.get_task_config("simulation", 7) is None + + def test_works_for_all_field_names(self): + tc = TaskConfig(kwargs={"k": "v"}) + for field in ( + "simulation", + "training", + "active_learn", + "environment", + "update", + "criterion", + ): + config = LearnerConfig(**{field: tc}) + assert config.get_task_config(field, 0) is tc + + +# --------------------------------------------------------------------------- +# Learner base class: state registry, callbacks, build_iteration_state +# --------------------------------------------------------------------------- + + +@pytest.fixture +def learner(): + mock_asyncflow = MagicMock(spec=WorkflowEngine) + return Learner(mock_asyncflow) + + +class TestLearnerStateRegistry: + def test_register_state_stores_value(self, learner): + learner.register_state("loss", 0.5) + assert learner.get_state("loss") == 0.5 + + def test_get_state_returns_default_when_missing(self, learner): + assert learner.get_state("nonexistent", "fallback") == "fallback" + + def test_get_all_state_returns_copy(self, learner): + learner.register_state("a", 1) + snapshot = learner.get_all_state() + snapshot["a"] = 999 # mutate the copy + assert learner.get_state("a") == 1 # original unchanged + + def test_clear_state_empties_registry(self, learner): + learner.register_state("a", 1) + learner.register_state("b", 2) + learner.clear_state() + assert learner.get_all_state() == {} + + def test_on_state_update_callback_invoked(self, learner): + calls = [] + + def cb(k, v): + calls.append((k, v)) + + 
learner.on_state_update(cb) + learner.register_state("x", 42) + assert calls == [("x", 42)] + + def test_multiple_callbacks_all_invoked(self, learner): + calls_a, calls_b = [], [] + + def cb_a(k, v): + calls_a.append((k, v)) + + def cb_b(k, v): + calls_b.append((k, v)) + + learner.on_state_update(cb_a) + learner.on_state_update(cb_b) + learner.register_state("y", 7) + assert calls_a == [("y", 7)] + assert calls_b == [("y", 7)] + + def test_callback_error_does_not_break_register_state(self, learner): + def bad_callback(k, v): + raise RuntimeError("boom") + + learner.on_state_update(bad_callback) + # Should not raise + learner.register_state("z", 99) + assert learner.get_state("z") == 99 + + def test_remove_state_callback(self, learner): + calls = [] + + def cb(k, v): + calls.append((k, v)) + + learner.on_state_update(cb) + learner.remove_state_callback(cb) + learner.register_state("a", 1) + assert calls == [] + + +class TestExtractStateFromResult: + def test_dict_result_registers_all_keys(self, learner): + learner._extract_state_from_result({"loss": 0.1, "acc": 0.9}) + assert learner.get_state("loss") == 0.1 + assert learner.get_state("acc") == 0.9 + + def test_non_dict_result_does_nothing(self, learner): + learner._extract_state_from_result("some_string") + learner._extract_state_from_result(42) + learner._extract_state_from_result(None) + assert learner.get_all_state() == {} + + def test_excluded_keys_are_skipped(self, learner): + learner._extract_state_from_result( + {"loss": 0.1, "metric_value": 0.05, "should_stop": True}, + exclude_keys={"metric_value", "should_stop"}, + ) + assert learner.get_state("loss") == 0.1 + assert learner.get_state("metric_value") is None + assert learner.get_state("should_stop") is None + + +class TestBuildIterationState: + def test_builds_state_with_metric_info_from_criterion(self, learner): + learner.criterion_function = { + "metric_name": "mse", + "threshold": 0.01, + } + state = learner.build_iteration_state(iteration=3, metric_value=0.05, should_stop=False) + assert state.iteration == 3 + assert state.metric_name == "mse" + assert state.metric_threshold == 0.01 + assert state.metric_value == 0.05 + assert state.should_stop is False + + def test_builds_state_with_registered_state(self, learner): + learner.register_state("labeled_count", 200) + state = learner.build_iteration_state(iteration=0) + assert state.state["labeled_count"] == 200 + assert state.labeled_count == 200 # attribute-style access + + def test_metric_history_reflects_recorded_values(self, learner): + learner.metric_values_per_iteration = {0: 0.5, 1: 0.3} + state = learner.build_iteration_state(iteration=2, metric_value=0.1) + assert state.metric_history == [0.5, 0.3] + + def test_current_config_stored_in_state(self, learner): + cfg = LearnerConfig(training=TaskConfig(kwargs={"--lr": "0.001"})) + state = learner.build_iteration_state(iteration=0, current_config=cfg) + assert state.current_config is cfg + + def test_no_criterion_function_yields_none_metric_info(self, learner): + learner.criterion_function = {} + state = learner.build_iteration_state(iteration=0) + assert state.metric_name is None + assert state.metric_threshold is None + + +# --------------------------------------------------------------------------- +# compare_metric +# --------------------------------------------------------------------------- + + +class TestCompareMetric: + def test_less_than(self, learner): + from rose.metrics import MEAN_SQUARED_ERROR_MSE + + assert learner.compare_metric(MEAN_SQUARED_ERROR_MSE, 0.005, 
0.01) is True + assert learner.compare_metric(MEAN_SQUARED_ERROR_MSE, 0.02, 0.01) is False + + def test_greater_than_custom_operator(self, learner): + assert learner.compare_metric("MY_METRIC", 5.0, 3.0, operator=">") is True + assert learner.compare_metric("MY_METRIC", 1.0, 3.0, operator=">") is False + + def test_equal_operator(self, learner): + assert learner.compare_metric("MY_METRIC", 1.0, 1.0, operator="==") is True + assert learner.compare_metric("MY_METRIC", 1.1, 1.0, operator="==") is False + + def test_custom_metric_without_operator_raises(self, learner): + with pytest.raises(ValueError, match="Operator value must be provided"): + learner.compare_metric("UNKNOWN_METRIC", 0.5, 0.1) + + def test_unknown_operator_raises(self, learner): + with pytest.raises(ValueError, match="Unknown comparison operator"): + learner.compare_metric("MY_METRIC", 0.5, 0.1, operator="!=") + + +# --------------------------------------------------------------------------- +# _stream_parallel +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestStreamParallel: + async def test_empty_run_fns_completes_immediately(self): + results = [] + async for state in _stream_parallel([]): + results.append(state) + assert results == [] + + async def test_single_learner_streams_all_states(self): + state_a = IterationState(iteration=0) + state_b = IterationState(iteration=1) + + async def run(queue: asyncio.Queue) -> None: + try: + await queue.put(("state", state_a)) + await queue.put(("state", state_b)) + finally: + await queue.put(("done", None)) + + results = [] + async for s in _stream_parallel([run]): + results.append(s) + + assert results == [state_a, state_b] + + async def test_two_learners_all_states_yielded(self): + s0 = IterationState(iteration=0, learner_id=0) + s1 = IterationState(iteration=0, learner_id=1) + + async def run_0(queue: asyncio.Queue) -> None: + try: + await queue.put(("state", s0)) + finally: + await queue.put(("done", None)) + + async def run_1(queue: asyncio.Queue) -> None: + try: + await queue.put(("state", s1)) + finally: + await queue.put(("done", None)) + + results = [] + async for s in _stream_parallel([run_0, run_1]): + results.append(s) + + assert len(results) == 2 + assert s0 in results + assert s1 in results + + async def test_done_count_terminates_loop(self): + """All N 'done' signals must arrive before _stream_parallel exits.""" + barrier = asyncio.Event() + + async def slow_run(queue: asyncio.Queue) -> None: + await barrier.wait() + try: + await queue.put(("state", IterationState(iteration=0))) + finally: + await queue.put(("done", None)) + + async def fast_run(queue: asyncio.Queue) -> None: + try: + await queue.put(("state", IterationState(iteration=0))) + finally: + await queue.put(("done", None)) + barrier.set() + + results = [] + async for s in _stream_parallel([slow_run, fast_run]): + results.append(s) + + assert len(results) == 2 + + async def test_error_is_propagated_after_all_done(self): + exc = RuntimeError("learner exploded") + + async def failing_run(queue: asyncio.Queue) -> None: + try: + await queue.put(("error", exc)) + finally: + await queue.put(("done", None)) + + with pytest.raises(RuntimeError, match="learner exploded"): + async for _ in _stream_parallel([failing_run]): + pass + + async def test_error_from_one_does_not_suppress_states_from_other(self): + good_state = IterationState(iteration=0, learner_id=0) + + async def good_run(queue: asyncio.Queue) -> None: + try: + await queue.put(("state", 
good_state)) + finally: + await queue.put(("done", None)) + + async def bad_run(queue: asyncio.Queue) -> None: + try: + await queue.put(("error", ValueError("oops"))) + finally: + await queue.put(("done", None)) + + results = [] + with pytest.raises(ValueError, match="oops"): + async for s in _stream_parallel([good_run, bad_run]): + results.append(s) + + # Good learner's state was streamed before exception re-raised + assert good_state in results + + async def test_only_first_error_is_raised(self): + """When two learners both fail, only the first error is re-raised.""" + + async def fail_a(queue: asyncio.Queue) -> None: + try: + await queue.put(("error", ValueError("error-A"))) + finally: + await queue.put(("done", None)) + + async def fail_b(queue: asyncio.Queue) -> None: + try: + await queue.put(("error", ValueError("error-B"))) + finally: + await queue.put(("done", None)) + + with pytest.raises(ValueError): + async for _ in _stream_parallel([fail_a, fail_b]): + pass diff --git a/tests/unit/test_learner_stop.py b/tests/unit/test_learner_stop.py index 07f1df2..9c18ef7 100644 --- a/tests/unit/test_learner_stop.py +++ b/tests/unit/test_learner_stop.py @@ -1,10 +1,11 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from radical.asyncflow import WorkflowEngine -from rose.al.active_learner import SequentialActiveLearner +from rose.al.active_learner import ParallelActiveLearner, SequentialActiveLearner +from rose.learner import IterationState @pytest.mark.asyncio @@ -66,3 +67,41 @@ async def run_learner(): assert count == 1 assert learner.is_stopped + + +@pytest.mark.asyncio +async def test_parallel_learner_stop_terminates_stream(): + """Test that stop() on a ParallelActiveLearner causes the async-for loop to exit.""" + mock_asyncflow = MagicMock(spec=WorkflowEngine) + learner = ParallelActiveLearner(mock_asyncflow) + + learner.simulation_function = AsyncMock(return_value="sim") + learner.training_function = AsyncMock(return_value="train") + learner.active_learn_function = AsyncMock(return_value="acl") + learner.criterion_function = AsyncMock(return_value=False) + + # Each mock sequential learner yields many states so we can verify early stop + async def mock_sequential_start(*args, **kwargs): + for i in range(50): + yield IterationState(iteration=i, should_stop=False) + + mock_seq = MagicMock() + mock_seq.start = mock_sequential_start + mock_seq.metric_values_per_iteration = {} + + count = 0 + + async def run_test(): + nonlocal count + with patch.object(learner, "_create_sequential_learner", return_value=mock_seq): + async for _state in learner.start(parallel_learners=2, max_iter=50): + count += 1 + if count == 1: + learner.stop() + + try: + await asyncio.wait_for(run_test(), timeout=5.0) + except asyncio.TimeoutError: + pytest.fail("ParallelActiveLearner did not terminate after stop() within timeout") + + assert learner.is_stopped diff --git a/tests/unit/test_parallel_learner.py b/tests/unit/test_parallel_learner.py index bd50ded..3bc7c52 100644 --- a/tests/unit/test_parallel_learner.py +++ b/tests/unit/test_parallel_learner.py @@ -88,7 +88,8 @@ async def test_start_validation_errors(self, parallel_learner): parallel_learner.criterion_function = None with pytest.raises(ValueError, match="For single learner, use SequentialActiveLearner"): - await parallel_learner.start(parallel_learners=1) + async for _ in parallel_learner.start(parallel_learners=1): + pass # Test with missing simulation functions it should 
raise error about # simulation first @@ -96,7 +97,8 @@ async def test_start_validation_errors(self, parallel_learner): ValueError, match="Simulation function must be set when not using simulation pool!", ): - await parallel_learner.start(parallel_learners=2, max_iter=1) + async for _ in parallel_learner.start(parallel_learners=2, max_iter=1): + pass # Test with missing simulation functions and skip_simulation_step # it should raise an error about missing train/active_learn tasks @@ -104,7 +106,10 @@ async def test_start_validation_errors(self, parallel_learner): ValueError, match="Training and Active Learning functions must be set!", ): - await parallel_learner.start(parallel_learners=2, max_iter=1, skip_simulation_step=True) + async for _ in parallel_learner.start( + parallel_learners=2, max_iter=1, skip_simulation_step=True + ): + pass # Set functions but test missing stop criteria parallel_learner.simulation_function = AsyncMock() @@ -115,16 +120,18 @@ async def test_start_validation_errors(self, parallel_learner): Exception, match="Either max_iter > 0 or criterion_function must be provided.", ): - await parallel_learner.start(parallel_learners=2, max_iter=0) + async for _ in parallel_learner.start(parallel_learners=2, max_iter=0): + pass # Test learner_configs length mismatch parallel_learner.criterion_function = AsyncMock() learner_configs = [None] # Only 1 config for 2 learners with pytest.raises(ValueError, match="learner_configs length must match parallel_learners"): - await parallel_learner.start( + async for _ in parallel_learner.start( parallel_learners=2, max_iter=1, learner_configs=learner_configs - ) + ): + pass @pytest.mark.asyncio async def test_start_successful_parallel_execution(self, configured_parallel_learner): @@ -143,12 +150,16 @@ async def mock_start(*args, **kwargs): "_create_sequential_learner", return_value=mock_sequential, ): - results = await configured_parallel_learner.start(parallel_learners=2, max_iter=1) + states = [] + async for state in configured_parallel_learner.start(parallel_learners=2, max_iter=1): + states.append(state) - # Verify results - assert len(results) == 2 + # Each learner yields one state, so 2 states total + assert len(states) == 2 # Results are IterationState objects - assert all(isinstance(r, IterationState) for r in results) + assert all(isinstance(s, IterationState) for s in states) + # Each state has learner_id set to the learner index + assert {s.learner_id for s in states} == {0, 1} # Verify metric collection assert "learner-0" in configured_parallel_learner.metric_values_per_iteration @@ -188,7 +199,10 @@ async def fail_start(*args, **kwargs): with patch("builtins.print") as mock_print: # Should raise exception due to learner failure with pytest.raises(Exception, match="Learner failed"): - await configured_parallel_learner.start(parallel_learners=2, max_iter=1) + async for _ in configured_parallel_learner.start( + parallel_learners=2, max_iter=1 + ): + pass # Verify error was printed (learner 1 fails, not 0) - mock_print.assert_any_call("ActiveLearner-1] failed with error: Learner failed") + mock_print.assert_any_call("[ActiveLearner-1] failed with error: Learner failed") diff --git a/tests/unit/test_rl_par_learner.py b/tests/unit/test_rl_par_learner.py index 5b1baf2..d197874 100644 --- a/tests/unit/test_rl_par_learner.py +++ b/tests/unit/test_rl_par_learner.py @@ -85,7 +85,8 @@ async def test_start_missing_environment_function(self, parallel_learner): parallel_learner.update_function = AsyncMock() with pytest.raises(ValueError, 
match="Environment and Update functions"): - await parallel_learner.start(parallel_learners=2, max_iter=1) + async for _ in parallel_learner.start(parallel_learners=2, max_iter=1): + pass @pytest.mark.asyncio async def test_start_missing_update_function(self, parallel_learner): @@ -94,13 +95,15 @@ async def test_start_missing_update_function(self, parallel_learner): parallel_learner.update_function = None with pytest.raises(ValueError, match="Environment and Update functions"): - await parallel_learner.start(parallel_learners=2, max_iter=1) + async for _ in parallel_learner.start(parallel_learners=2, max_iter=1): + pass @pytest.mark.asyncio async def test_start_invalid_parallel_learners_count(self, configured_parallel_learner): """Test that start raises exception when parallel_learners < 2.""" with pytest.raises(ValueError) as excinfo: - await configured_parallel_learner.start(parallel_learners=1, max_iter=1) + async for _ in configured_parallel_learner.start(parallel_learners=1, max_iter=1): + pass assert "For single learner, use SequentialReinforcementLearner" in str(excinfo.value) @@ -111,7 +114,8 @@ async def test_start_without_iterations_or_criterion(self, configured_parallel_l configured_parallel_learner.criterion_function = None with pytest.raises(ValueError, match="Either max_iter > 0 or criterion"): - await configured_parallel_learner.start(parallel_learners=2) + async for _ in configured_parallel_learner.start(parallel_learners=2): + pass @pytest.mark.asyncio async def test_start_mismatched_config_length(self, configured_parallel_learner): @@ -120,11 +124,12 @@ async def test_start_mismatched_config_length(self, configured_parallel_learner) learner_configs = [LearnerConfig(), LearnerConfig()] # Length 2 with pytest.raises(ValueError) as excinfo: - await configured_parallel_learner.start( + async for _ in configured_parallel_learner.start( parallel_learners=3, # Different length max_iter=1, learner_configs=learner_configs, - ) + ): + pass assert "learner_configs length must match parallel_learners" in str(excinfo.value) @@ -145,11 +150,16 @@ async def mock_start(*args, **kwargs): "_create_sequential_learner", return_value=mock_sequential_learner, ): - results = await configured_parallel_learner.start(parallel_learners=2, max_iter=1) + states = [] + async for state in configured_parallel_learner.start(parallel_learners=2, max_iter=1): + states.append(state) - assert len(results) == 2 + # Each learner yields one state, so 2 states total + assert len(states) == 2 # Results are IterationState objects - assert all(isinstance(r, IterationState) for r in results) + assert all(isinstance(s, IterationState) for s in states) + # Each state has learner_id set to the learner index + assert {s.learner_id for s in states} == {0, 1} # Verify metric storage assert "learner-0" in configured_parallel_learner.metric_values_per_iteration @@ -187,7 +197,8 @@ async def fail_start(*args, **kwargs): ): # The exception should propagate up and be raised with pytest.raises(Exception) as excinfo: - await configured_parallel_learner.start(parallel_learners=2, max_iter=1) + async for _ in configured_parallel_learner.start(parallel_learners=2, max_iter=1): + pass assert "Learner failed" in str(excinfo.value) @@ -209,9 +220,10 @@ async def mock_start(*args, **kwargs): "_create_sequential_learner", return_value=mock_sequential_learner, ): - await configured_parallel_learner.start( + async for _ in configured_parallel_learner.start( parallel_learners=2, max_iter=1, skip_pre_loop=True - ) + ): + pass # Verify that 
sequential learners were called with skip_pre_loop=True assert len(start_calls) == 2 diff --git a/tests/unit/test_sequential_learner.py b/tests/unit/test_sequential_learner.py index 43f2c15..4956cc5 100644 --- a/tests/unit/test_sequential_learner.py +++ b/tests/unit/test_sequential_learner.py @@ -191,3 +191,41 @@ async def test_set_next_config(self, configured_learner): # Should be stored as pending assert configured_learner._pending_config == new_config + + @pytest.mark.asyncio + async def test_set_next_config_takes_effect_in_next_iteration(self, configured_learner): + """Test that a config set via set_next_config is consumed in the following iteration.""" + from rose.learner import LearnerConfig, TaskConfig + + new_config = LearnerConfig(training=TaskConfig(kwargs={"--lr": "0.0001"})) + + states = [] + iteration = 0 + async for state in configured_learner.start(max_iter=2, skip_pre_loop=True): + states.append(state) + if iteration == 0: + configured_learner.set_next_config(new_config) + iteration += 1 + + assert len(states) == 2 + # First iteration uses the initial config (None) + assert states[0].current_config is None + # Second iteration uses the new config set after the first yield + assert states[1].current_config is new_config + + @pytest.mark.asyncio + async def test_skip_simulation_step_does_not_register_simulation(self, configured_learner): + """Test that skip_simulation_step=True skips registering simulation tasks.""" + configured_learner._check_stop_criterion.return_value = (True, 0.01) + + async for _ in configured_learner.start( + max_iter=0, skip_pre_loop=True, skip_simulation_step=True + ): + pass + + # Every _register_task call should NOT have used simulation_function + for call in configured_learner._register_task.call_args_list: + args, _ = call + if args: + task_obj = args[0] + assert task_obj is not configured_learner.simulation_function diff --git a/tests/unit/test_uq_learner.py b/tests/unit/test_uq_learner.py index f75d04e..d02398a 100644 --- a/tests/unit/test_uq_learner.py +++ b/tests/unit/test_uq_learner.py @@ -14,9 +14,8 @@ async def mock_start_iterator(*args, **kwargs): """Helper to mock an async iterator.""" - # Create a mock IterationState - state = MagicMock(spec=IterationState) - state.to_dict.return_value = "learner_result" + # Yield a real IterationState so dataclasses.replace works in ParallelUQLearner + state = IterationState(iteration=0, should_stop=True) yield state @@ -111,12 +110,13 @@ async def test_teach_validation_errors(self, parallel_learner): Exception, match="Simulation, Training, and Active Learning functions must be set!", ): - await parallel_learner.start( + async for _ in parallel_learner.start( learner_names=["l1", "l2"], learner_configs={"l1": None, "l2": None}, model_names=["m1"], max_iter=1, - ) + ): + pass # Set functions but test missing stop criteria parallel_learner.simulation_function = AsyncMock() @@ -130,18 +130,22 @@ async def test_teach_validation_errors(self, parallel_learner): Exception, match="learner_configs length must match learner_names", ): - await parallel_learner.start( + async for _ in parallel_learner.start( learner_names=["l1", "l2"], learner_configs={"l1": None}, model_names=["m1"], max_iter=1, - ) + ): + pass with pytest.raises( Exception, match="Either max_iter or stop_criterion_function must be provided.", ): - await parallel_learner.start(learner_names=["l1", "l2"], model_names=["m1"], max_iter=0) + async for _ in parallel_learner.start( + learner_names=["l1", "l2"], model_names=["m1"], max_iter=0 + ): + pass 
with pytest.raises( Exception, @@ -234,41 +238,24 @@ async def test_teach_successful_parallel_execution(self, configured_parallel_lea "_create_sequential_learner", return_value=mock_sequential, ): - with patch.object( - configured_parallel_learner, - "_convert_to_sequential_config", - return_value=None, + states = [] + async for state in configured_parallel_learner.start( + learner_names=["l1", "l2"], + learner_configs={"l1": None, "l2": None}, + model_names=["m1"], + max_iter=1, ): - results = await configured_parallel_learner.start( - learner_names=["l1", "l2"], - learner_configs={"l1": None, "l2": None}, - model_names=["m1"], - max_iter=1, - ) - - # Verify results - assert len(results) == 2 - assert all(result.to_dict() == "learner_result" for result in results) + states.append(state) - # Verify sequential learners were called - # We can't easily check call count for a generator function mock - # but results verify it was called. - print( - "metric_values_per_iteration", - configured_parallel_learner.metric_values_per_iteration, - ) - print( - "uncertainty_values_per_iteration", - configured_parallel_learner.uncertainty_values_per_iteration, - ) + # Each learner yields one state, so 2 states total + assert len(states) == 2 + assert all(isinstance(s, IterationState) for s in states) + # Each state has learner_id set to the learner name + assert {s.learner_id for s in states} == {"l1", "l2"} - # Verify metric collection - assert ( - "learner-l1" in configured_parallel_learner.metric_values_per_iteration.keys() - ) - assert ( - "learner-l2" in configured_parallel_learner.metric_values_per_iteration.keys() - ) + # Verify metric collection + assert "learner-l1" in configured_parallel_learner.metric_values_per_iteration + assert "learner-l2" in configured_parallel_learner.metric_values_per_iteration @pytest.mark.asyncio async def test_teach_learner_failure_handling(self, configured_parallel_learner): @@ -287,23 +274,19 @@ async def failing_start(*args, **kwargs): "_create_sequential_learner", return_value=mock_sequential, ): - with patch.object( - configured_parallel_learner, - "_convert_to_sequential_config", - return_value=None, - ): - # Mock print to capture error message - with patch("builtins.print") as mock_print: - # Should raise exception due to learner failure - with pytest.raises(Exception, match="Learner failed"): - await configured_parallel_learner.start( - learner_names=["l1"], - model_names=["m1"], - learner_configs={"l1": None}, - max_iter=1, - ) - - # Verify error was printed - mock_print.assert_any_call( - "[Parallel-Learner-l1] Failed with error: Learner failed" - ) + # Mock print to capture error message + with patch("builtins.print") as mock_print: + # Should raise exception due to learner failure + with pytest.raises(Exception, match="Learner failed"): + async for _ in configured_parallel_learner.start( + learner_names=["l1"], + model_names=["m1"], + learner_configs={"l1": None}, + max_iter=1, + ): + pass + + # Verify error was printed + mock_print.assert_any_call( + "[Parallel-Learner-l1] Failed with error: Learner failed" + )
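
The tests above drive `_stream_parallel` purely through its queue protocol: each runner coroutine receives an `asyncio.Queue` and puts `("state", ...)`, `("error", ...)`, and, in a `finally` block, `("done", None)`. For orientation, here is a minimal sketch of a fan-in generator consistent with that protocol and with the behavior the tests assert (states are yielded in arrival order; a failure is re-raised only after every runner has signaled `"done"`). It is an illustrative assumption, not the actual implementation in `rose/learner.py`.

```python
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any


async def _stream_parallel(
    runs: list[Callable[[asyncio.Queue], Awaitable[None]]],
) -> AsyncIterator[Any]:
    """Sketch: merge per-runner queue messages into one stream of states."""
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [asyncio.create_task(run(queue)) for run in runs]
    remaining = len(tasks)
    first_error: BaseException | None = None
    try:
        while remaining:
            kind, payload = await queue.get()
            if kind == "state":
                yield payload  # states stream in arrival order
            elif kind == "error" and first_error is None:
                first_error = payload  # remember only the first failure
            elif kind == "done":
                remaining -= 1  # one runner finished, successfully or not
        if first_error is not None:
            raise first_error  # surface the failure once the stream drains
    finally:
        for task in tasks:
            task.cancel()  # no runner outlives the stream
```

In this sketch, the `finally` clause is what lets a consumer break out of the `async for` loop early, as `test_parallel_learner_stop_terminates_stream` does: closing the generator cancels the remaining runner tasks instead of waiting out their 50 mocked iterations.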
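
Relatedly, the `test_uq_learner.py` change that swaps `MagicMock(spec=IterationState)` for a real `IterationState` matters because, per the updated comment, `ParallelUQLearner` tags each state via `dataclasses.replace`, which rejects anything that is not an actual dataclass instance. A hypothetical illustration of that tagging step:

```python
import dataclasses

from rose.learner import IterationState

state = IterationState(iteration=0, should_stop=True)

# Presumed re-tagging inside ParallelUQLearner: stamp the producing learner's name.
tagged = dataclasses.replace(state, learner_id="l1")
assert tagged.learner_id == "l1"

# With MagicMock(spec=IterationState) this step raises TypeError, since
# dataclasses.replace() only works on real dataclass instances.
```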