
Parallel Learner API emits async results#85

Merged
AymenFJA merged 5 commits into main from feature/iter_parallel_learner on Mar 5, 2026

Conversation

@AymenFJA
Collaborator

1. Make the ParallelLearner API emit real-time results per iteration, like other learners
2. Update tests
3. Update examples

@AymenFJA AymenFJA self-assigned this Feb 27, 2026
@AymenFJA AymenFJA added the enhancement (New feature or request), raas (ROSE As A Service), and Q1 labels Feb 27, 2026
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the parallel learning APIs by transforming their start methods into asynchronous iterators. This change allows users to receive and process iteration results from multiple parallel learners in real-time, rather than waiting for all learners to complete. This improves the responsiveness and observability of long-running parallel learning experiments, providing immediate feedback on the progress of each individual learner.

Highlights

  • Real-time Iteration Results: The ParallelActiveLearner.start, ParallelReinforcementLearner.start, and ParallelUQLearner.start methods now return an AsyncIterator that yields IterationState objects in real-time as each parallel learner completes an iteration. This allows for immediate consumption and display of progress.
  • IterationState Enhancement: The IterationState dataclass has been extended with a new learner_id field, enabling identification of which parallel learner produced a specific iteration state.
  • API Usage Update: Examples and unit tests have been updated to reflect the new async for syntax for consuming results from the parallel learner start methods, demonstrating the real-time streaming capability.
  • Deprecated teach Method Update: The deprecated teach methods in ParallelActiveLearner, ParallelReinforcementLearner, and ParallelUQLearner have been updated internally to use the new async for iteration over the start method, ensuring continued functionality while signaling their eventual removal.
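The streaming contract described above can be sketched with a toy stand-in. The names IterationState, learner_id, and start() follow the PR description; the demo learner itself is illustrative, not the actual rose API:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional

# Minimal stand-ins for the PR's types; field names are assumptions
# based on the PR description, not the real library.
@dataclass
class IterationState:
    iteration: int
    learner_id: Optional[int] = None

class DemoLearner:
    """Toy learner whose start() yields one state per finished iteration."""
    async def start(self, max_iter: int = 3) -> AsyncIterator[IterationState]:
        for i in range(max_iter):
            await asyncio.sleep(0)  # stand-in for real async work
            yield IterationState(iteration=i, learner_id=0)

async def collect() -> list:
    states = []
    # The new API shape: consume results as each iteration completes.
    async for state in DemoLearner().start(max_iter=3):
        states.append(state)
    return states

states = asyncio.run(collect())
print(len(states))  # 3
```

Callers observe progress immediately instead of blocking on a final list, which is the whole point of the change.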


Changelog
  • examples/active_learn/parallel/run_me_per_learner_config.py
    • Updated the MSE threshold for the stop criterion from 0.1 to 0.01.
    • Modified the al.start call to use async for to process real-time iteration states and print learner progress.
    • Added max_iter parameter to al.start call.
  • examples/active_learn/parallel/run_me_per_learner_per_iter_config.py
    • Replaced RadicalExecutionBackend with ConcurrentExecutionBackend for the workflow engine.
    • Modified the al.start call to use async for to process real-time iteration states and print learner progress.
  • examples/active_learn/parallel/run_me_with_dynamic_config.py
    • Modified the al.start call to use async for to process real-time iteration states and print learner progress.
  • examples/active_learn/uq/run_me.py
    • Modified the learner.start call to use async for to collect real-time iteration states.
    • Updated the final results collection to store IterationState objects in a dictionary keyed by learner ID and then convert them to dictionaries for JSON output.
  • rose/al/active_learner.py
    • Imported the dataclasses module.
    • Changed the return type of ParallelActiveLearner.start from list[Any] to AsyncIterator[IterationState].
    • Updated docstrings for ParallelActiveLearner.start to reflect the new async iterator behavior, including yield description and an example.
    • Refactored the internal active_learner_workflow into run_learner and implemented an asyncio.Queue to emit states in real-time.
    • Modified the deprecated teach method to internally use async for over self.start() to collect and return final states.
  • rose/learner.py
    • Added a new learner_id field to the IterationState dataclass.
    • Included the learner_id in the dictionary representation returned by the to_dict method of IterationState.
  • rose/rl/reinforcement_learner.py
    • Imported the dataclasses module.
    • Changed the return type of ParallelReinforcementLearner.start from list[Any] to AsyncIterator[IterationState].
    • Updated docstrings for ParallelReinforcementLearner.start to reflect the new async iterator behavior, including yield description and an example.
    • Refactored the internal rl_learner_workflow into run_learner and implemented an asyncio.Queue to emit states in real-time.
    • Modified the deprecated learn method to internally use async for over self.start() to collect and return final states.
  • rose/uq/uq_active_learner.py
    • Imported the dataclasses module.
    • Changed the return type of ParallelUQLearner.start from list[Any] to AsyncIterator[IterationState].
    • Updated docstrings for ParallelUQLearner.start to reflect the new async iterator behavior, including yield description and an example.
    • Refactored the internal _run_sequential_learner into run_learner and implemented an asyncio.Queue to emit states in real-time.
    • Modified the deprecated teach method to internally use async for over self.start() to collect and return final states.
  • tests/integration/test_run_parallel_learner.py
    • Updated the call to learner.start to use async for to consume the yielded states.
  • tests/integration/test_run_uq_learner.py
    • Updated the call to learner.start to use async for to collect states.
    • Modified assertions to verify the number of collected states and the presence of learner_id in the states.
  • tests/unit/test_parallel_learner.py
    • Updated all calls to parallel_learner.start within tests to use async for _ in ... syntax.
    • Adjusted assertions to check for the correct number of yielded states and the learner_id attribute.
  • tests/unit/test_rl_par_learner.py
    • Updated all calls to parallel_learner.start within tests to use async for _ in ... syntax.
    • Adjusted assertions to check for the correct number of yielded states and the learner_id attribute.
  • tests/unit/test_uq_learner.py
    • Modified mock_start_iterator to yield a concrete IterationState instance instead of a MagicMock, ensuring compatibility with dataclasses.replace.
    • Updated all calls to parallel_learner.start within tests to use async for _ in ... syntax.
    • Adjusted assertions to check for the correct number of yielded states and the learner_id attribute.
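The learner_id tagging the changelog mentions can be illustrated with dataclasses.replace, the same call the updated tests rely on. The field names here are assumptions based on this changelog:

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class IterationState:
    iteration: int
    learner_id: Optional[int] = None

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

# A sequential learner yields an untagged state; the parallel wrapper
# tags it with the producing learner's id without mutating the original.
base = IterationState(iteration=2)
tagged = dataclasses.replace(base, learner_id=1)
print(tagged.to_dict())  # {'iteration': 2, 'learner_id': 1}
```

Because replace() returns a new instance, a concrete dataclass (rather than a MagicMock) is needed in tests, which explains the mock_start_iterator change above.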


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the parallel learner APIs (ParallelActiveLearner, ParallelReinforcementLearner, ParallelUQLearner) to emit real-time results per iteration using an async generator pattern. This is a significant improvement for monitoring and interacting with parallel learning processes. The implementation correctly uses asyncio.Queue for streaming states and handles errors robustly. The examples and tests have been updated accordingly.

My review includes two main points:

  1. A suggestion to refactor the duplicated parallel execution logic into a shared utility to improve maintainability.
  2. An improvement to a test assertion to make it more precise.

Overall, the changes are well-implemented and significantly improve the library's capabilities.

Comment on lines 471 to 517
queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()

async def run_learner(learner_id: int) -> None:
    try:
        # Create and configure the sequential learner
        sequential_learner: SequentialActiveLearner = self._create_sequential_learner(
            learner_id, learner_configs[learner_id]
        )

        # Convert parallel config to sequential config
        sequential_config: LearnerConfig | None = self._convert_to_sequential_config(
            learner_configs[learner_id]
        )

        # Run the sequential learner by iterating through start()
        async for state in sequential_learner.start(
            max_iter=max_iter,
            skip_pre_loop=skip_pre_loop,
            skip_simulation_step=skip_simulation_step,
            initial_config=sequential_config,
        ):
            if self.is_stopped:
                sequential_learner.stop()
            await queue.put(("state", dataclasses.replace(state, learner_id=learner_id)))

        # Bookkeep the metric values from each learner
        self.metric_values_per_iteration[f"learner-{learner_id}"] = (
            sequential_learner.metric_values_per_iteration
        )
    except Exception as e:
        print(f"ActiveLearner-{learner_id}] failed with error: {e}")
        await queue.put(("error", e))
    finally:
        await queue.put(("done", None))

tasks = [asyncio.create_task(run_learner(i)) for i in range(parallel_learners)]

completed = 0
first_error: Exception | None = None
while completed < parallel_learners:
    kind, value = await queue.get()
    if kind == "done":
        completed += 1
    elif kind == "state":
        yield value
    elif kind == "error":
        if first_error is None:
            first_error = value

await asyncio.gather(*tasks, return_exceptions=True)
if first_error is not None:
    raise first_error


high

The core logic for running parallel learners by streaming results via an asyncio.Queue is a great pattern. However, this logic is duplicated across ParallelActiveLearner, ParallelReinforcementLearner, and ParallelUQLearner. This duplication increases maintenance overhead, as any future bug fixes or enhancements to this streaming mechanism will need to be applied in three different places.

To improve maintainability, I suggest refactoring this common logic into a shared helper function or a method in a base class. This centralizes the parallel execution pattern, making the code cleaner and easier to manage.

Here's a conceptual example of how such a helper could be structured:

# In a shared utility module or a base class
async def _run_parallel_stream(
    learner_runners: list[Coroutine]
) -> AsyncIterator[IterationState]:
    """
    Runs multiple learner coroutines in parallel and streams their states.
    Each coroutine is expected to be a wrapper that runs a learner and
    puts ('state', state), ('error', e), and ('done', None) tuples
    into a shared queue.
    """
    # ... implementation with queue, task creation, and result streaming ...

Each start method would then be responsible for preparing its specific run_learner coroutines and passing them to this shared helper. This would significantly reduce code duplication.
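Filled out under the reviewer's assumptions, such a helper might look like the following sketch. The function name and the ('state', …) / ('error', …) / ('done', None) queue protocol come from the suggestion above; everything else, including the toy runners, is illustrative:

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable

RunnerFn = Callable[[asyncio.Queue], Awaitable[None]]

async def _run_parallel_stream(run_fns: list) -> AsyncIterator[Any]:
    """Run learner runners in parallel and stream their states.

    Each runner receives the shared queue and is expected to put
    ('state', state) per iteration, ('error', exc) on failure, and
    ('done', None) when finished.
    """
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [asyncio.create_task(fn(queue)) for fn in run_fns]
    completed = 0
    first_error = None
    while completed < len(tasks):
        kind, value = await queue.get()
        if kind == "done":
            completed += 1
        elif kind == "state":
            yield value
        elif kind == "error" and first_error is None:
            first_error = value
    await asyncio.gather(*tasks, return_exceptions=True)
    if first_error is not None:
        raise first_error

# Usage sketch: two toy runners, each emitting two states.
def make_runner(learner_id: int) -> RunnerFn:
    async def run(queue: asyncio.Queue) -> None:
        try:
            for i in range(2):
                await queue.put(("state", (learner_id, i)))
        finally:
            await queue.put(("done", None))
    return run

async def main() -> list:
    return [s async for s in _run_parallel_stream([make_runner(0), make_runner(1)])]

states = asyncio.run(main())
print(len(states))  # 4
```

Each learner class would only need to build its own run_learner closures and delegate the queue/task bookkeeping to the shared helper.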

Comment on lines +63 to +64
# Verify we got states from both learners (2 iterations each = 4 states)
assert len(states) > 0


medium

The comment on line 63, # Verify we got states from both learners (2 iterations each = 4 states), seems incorrect given the test setup. The stop criterion check_mse returns 0.05, which is below the threshold of 0.1, so each of the two learners stops after its first iteration. Therefore, we should expect exactly 2 states in total.

The assertion on line 64, assert len(states) > 0, is also weaker than it could be. A more precise assertion would verify the exact number of expected states.

I suggest updating the comment and strengthening the assertion to reflect the expected behavior.

Suggested change
# Verify we got states from both learners (2 iterations each = 4 states)
assert len(states) > 0
# Verify we got states from both learners (1 iteration each = 2 states)
assert len(states) == 2

AymenFJA added 3 commits March 5, 2026 17:13
…learners

- Pass learner_configs[learner_id] directly to start(), removing the sequential_config intermediate variable
- Add factory comment explaining the closure-in-loop pitfall
- Fix print typo: ActiveLearner-{id}] → [ActiveLearner-{id}]
- Apply the same changes to reinforcement_learner
- Remove sequential_config and add factory comment in uq_active_learner
- Update unit test assertion to match corrected print format
@AymenFJA
Collaborator Author

AymenFJA commented Mar 5, 2026

@gemini-code-assist review

@AymenFJA AymenFJA added the documentation (Improvements or additions to documentation) and High labels Mar 5, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the parallel learners to use an asynchronous streaming API, aligning them with sequential learners and improving real-time result handling. While this improves the API and promotes code reuse with the _stream_parallel helper, it introduces a potential resource exhaustion (DoS) vulnerability due to an unbounded queue and task creation in rose/learner.py. Furthermore, the parallel learner API passes unsanitized configuration arguments to task functions in rose/al/, rose/rl/, and rose/uq/, which could lead to command injection if the configuration is from an untrusted source. It is recommended to implement backpressure and task limits in the streaming logic, add robust sanitization or clear security warnings for task argument handling, and improve the consistency of error logging across different learner implementations.

max_iter=max_iter,
skip_pre_loop=skip_pre_loop,
skip_simulation_step=skip_simulation_step,
initial_config=learner_configs[learner_id],


security-high high

The ParallelActiveLearner.start method accepts learner_configs which contain arbitrary args and kwargs. These arguments are passed directly to sequential learners (line 482) and eventually to task functions without any sanitization. As shown in the provided examples, these task functions often use these arguments to construct shell command strings (e.g., using f-strings). If the configuration is provided by an untrusted source, such as an LLM agent (as suggested in the docstrings), an attacker can inject malicious shell commands. This is a form of Insecure Output Handling (LLM-02) where LLM-generated content is passed to a sensitive sink (shell execution) without validation. Consider implementing sanitization for task arguments (e.g., using shlex.quote) or providing a safe API for command construction that accepts a list of arguments.
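A minimal sketch of the shlex.quote mitigation the comment suggests; the config value and the train.py command template are hypothetical, not taken from the repository:

```python
import shlex

# Hypothetical untrusted value arriving via a learner config (args/kwargs).
untrusted = "data.csv; rm -rf ~"
safe = shlex.quote(untrusted)

# Quoting keeps the whole value a single shell word, so the ';' cannot
# terminate the command and start a new one.
cmd = f"python train.py --input {safe}"
print(cmd)  # python train.py --input 'data.csv; rm -rf ~'
```

Passing an argument list to subprocess.run without shell=True avoids the shell entirely and is the stronger fix when the execution backend allows it.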

async for state in sequential_learner.start(
max_iter=max_iter,
skip_pre_loop=skip_pre_loop,
initial_config=learner_configs[learner_id],


security-high high

The ParallelReinforcementLearner.start method passes potentially untrusted learner_configs to sequential learners. These configurations contain args and kwargs that are used by task functions to construct commands. Without proper sanitization, this pattern is vulnerable to command injection if the configuration originates from an untrusted source like an LLM agent. Ensure that all task arguments are properly sanitized before being used in shell commands or other sensitive sinks.

num_predictions=num_predictions,
max_iter=max_iter,
skip_pre_loop=skip_pre_loop,
learning_config=learner_configs[learner_name],


security-high high

The ParallelUQLearner.start method passes learner_configs to sequential learners, which eventually use the contained args and kwargs in task functions. If these task functions construct shell commands using these unsanitized arguments, it creates a command injection vulnerability. This is particularly relevant given the intended use with LLM agents, which could be manipulated into providing malicious configurations. Implement strict validation or sanitization for all task arguments.

Comment on lines +135 to +136
queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
tasks = [asyncio.create_task(fn(queue)) for fn in run_fns]


security-medium medium

The _stream_parallel function introduces a potential resource exhaustion (Denial of Service) vulnerability. It creates an unbounded asyncio.Queue (line 135) and an unbounded number of concurrent tasks using asyncio.create_task (line 136). If a large number of parallel learners or iterations are configured, this can lead to memory exhaustion as results accumulate in the queue without backpressure, and CPU/memory exhaustion from the excessive number of concurrent tasks. Additionally, asyncio.gather (line 150) is called on all tasks at once, which can further strain resources. Consider implementing a limit on the number of concurrent tasks (e.g., using a semaphore) and setting a maxsize for the asyncio.Queue to provide backpressure.
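The backpressure and concurrency limits the comment recommends could be sketched like this; the specific limits, runner protocol, and counts are illustrative, not the library's actual values:

```python
import asyncio

MAX_CONCURRENT = 2   # cap on simultaneously running learner tasks
QUEUE_MAXSIZE = 8    # bounded queue: producers block when the consumer lags

async def produce(queue: asyncio.Queue, sem: asyncio.Semaphore, learner_id: int) -> None:
    async with sem:  # at most MAX_CONCURRENT runners execute at once
        try:
            for i in range(3):
                await queue.put(("state", (learner_id, i)))  # blocks if queue is full
        finally:
            await queue.put(("done", None))

async def main(n_learners: int = 5) -> int:
    queue: asyncio.Queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE)
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    tasks = [asyncio.create_task(produce(queue, sem, i)) for i in range(n_learners)]
    completed = states = 0
    while completed < n_learners:
        kind, _ = await queue.get()
        if kind == "done":
            completed += 1
        else:
            states += 1
    await asyncio.gather(*tasks)
    return states

count = asyncio.run(main())
print(count)  # 15
```

The semaphore bounds task concurrency while the maxsize on the queue prevents unbounded memory growth when states are produced faster than they are consumed.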

sequential_learner.metric_values_per_iteration
)
except Exception as e:
print(f"[ActiveLearner-{learner_id}] failed with error: {e}")


medium

The error logging format is inconsistent across the parallel learners (ParallelActiveLearner, ParallelReinforcementLearner, ParallelUQLearner). For better maintainability, it would be great to standardize it. A good approach is to use the class name dynamically. This change should be applied consistently across all parallel learners and their corresponding tests.

Suggested change
print(f"[ActiveLearner-{learner_id}] failed with error: {e}")
print(f"[{self.__class__.__name__}-{learner_id}] failed with error: {e}")

sequential_learner.metric_values_per_iteration
)
except Exception as e:
print(f"[RLLearner-{learner_id}] failed with error: {e}")


medium

Applying consistent error logging format for maintainability.

Suggested change
print(f"[RLLearner-{learner_id}] failed with error: {e}")
print(f"[{self.__class__.__name__}-{learner_id}] failed with error: {e}")

sequential_learner.uncertainty_values_per_iteration
)
except Exception as e:
print(f"[Parallel-Learner-{learner_name}] Failed with error: {e}")


medium

Applying consistent error logging format. Note the change from Failed to failed for consistency with other learners.

Suggested change
print(f"[Parallel-Learner-{learner_name}] Failed with error: {e}")
print(f"[{self.__class__.__name__}-{learner_name}] failed with error: {e}")


# Verify error was printed (learner 1 fails, not 0)
mock_print.assert_any_call("[ActiveLearner-1] failed with error: Learner failed")


medium

This assertion needs to be updated to match the suggested change in the error message format in rose/al/active_learner.py.

Suggested change
mock_print.assert_any_call("[ActiveLearner-1] failed with error: Learner failed")
mock_print.assert_any_call("[ParallelActiveLearner-1] failed with error: Learner failed")

Comment on lines +290 to +292
mock_print.assert_any_call(
"[Parallel-Learner-l1] Failed with error: Learner failed"
)


medium

This assertion needs to be updated to match the suggested change in the error message format in rose/uq/uq_active_learner.py.

Suggested change
mock_print.assert_any_call(
"[Parallel-Learner-l1] Failed with error: Learner failed"
)
mock_print.assert_any_call(
"[ParallelUQLearner-l1] failed with error: Learner failed"
)

@AymenFJA AymenFJA merged commit c0e2028 into main Mar 5, 2026
10 checks passed
@AymenFJA AymenFJA deleted the feature/iter_parallel_learner branch March 5, 2026 19:55

Labels

documentation (Improvements or additions to documentation), enhancement (New feature or request), High, Q1, raas (ROSE As A Service)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant