Conversation
1-Make the ParallelLearner API emit real-time results per iteration, like other learners
2-Update tests
3-Update examples
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the parallel learning APIs by transforming their start methods into async generators that emit per-iteration results in real time.
Code Review
This pull request refactors the parallel learner APIs (ParallelActiveLearner, ParallelReinforcementLearner, ParallelUQLearner) to emit real-time results per iteration using an async generator pattern. This is a significant improvement for monitoring and interacting with parallel learning processes. The implementation correctly uses asyncio.Queue for streaming states and handles errors robustly. The examples and tests have been updated accordingly.
My review includes two main points:
- A suggestion to refactor the duplicated parallel execution logic into a shared utility to improve maintainability.
- An improvement to a test assertion to make it more precise.
Overall, the changes are well-implemented and significantly improve the library's capabilities.
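The async-generator pattern the review describes can be consumed like any other async iterator. A minimal sketch of the consumer side — the `IterationState` fields and the `start()` signature here are simplified stand-ins, not the library's real API:

```python
import asyncio
import dataclasses


@dataclasses.dataclass
class IterationState:
    # Simplified stand-in for the library's per-iteration state object
    learner_id: int
    iteration: int


class FakeParallelLearner:
    """Stand-in showing the async-generator shape of start()."""

    async def start(self, max_iter: int = 2, parallel_learners: int = 2):
        for iteration in range(max_iter):
            for learner_id in range(parallel_learners):
                # Real learners yield as soon as each iteration completes
                yield IterationState(learner_id=learner_id, iteration=iteration)


async def main() -> list[IterationState]:
    states = []
    async for state in FakeParallelLearner().start():
        states.append(state)  # react to each learner's progress in real time
    return states


states = asyncio.run(main())
print(len(states))  # 2 learners x 2 iterations = 4 states
```

This is the key API change: callers see states as they are produced instead of waiting for one gathered result at the end.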
rose/al/active_learner.py
Outdated
```diff
+        queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()

         ...

         Raises:
             Exception: Re-raises any exception from the sequential learner
                 with additional context about which learner failed.
         """
         async def run_learner(learner_id: int) -> None:
             try:
                 # Create and configure the sequential learner
                 sequential_learner: SequentialActiveLearner = self._create_sequential_learner(
                     learner_id, learner_configs[learner_id]
                 )

                 # Convert parallel config to sequential config
                 sequential_config: LearnerConfig | None = self._convert_to_sequential_config(
                     learner_configs[learner_id]
                 )

                 # Run the sequential learner by iterating through start()
                 final_state = None
                 async for state in sequential_learner.start(
                     max_iter=max_iter,
                     skip_pre_loop=skip_pre_loop,
                     skip_simulation_step=skip_simulation_step,
                     initial_config=sequential_config,
                 ):
                     final_state = state
                     if self.is_stopped:
                         sequential_learner.stop()
+                    await queue.put(("state", dataclasses.replace(state, learner_id=learner_id)))

                 # book keep the iteration value from each learner
                 self.metric_values_per_iteration[f"learner-{learner_id}"] = (
                     sequential_learner.metric_values_per_iteration
                 )
-                return final_state
             except Exception as e:
                 print(f"ActiveLearner-{learner_id}] failed with error: {e}")
-                raise
-
-        print(f"Starting Parallel Active Learning with {parallel_learners} learners")
-
-        # Submit all learners asynchronously
-        learners: list[Coroutine] = [active_learner_workflow(i) for i in range(parallel_learners)]
-
-        # Wait for all learners to complete and collect results
-        return await asyncio.gather(*learners)
+                await queue.put(("error", e))
+            finally:
+                await queue.put(("done", None))

+        tasks = [asyncio.create_task(run_learner(i)) for i in range(parallel_learners)]

+        completed = 0
+        first_error: Exception | None = None
+        while completed < parallel_learners:
+            kind, value = await queue.get()
+            if kind == "done":
+                completed += 1
+            elif kind == "state":
+                yield value
+            elif kind == "error":
+                if first_error is None:
+                    first_error = value

+        await asyncio.gather(*tasks, return_exceptions=True)
+        if first_error is not None:
+            raise first_error
```
The core logic for running parallel learners by streaming results via an asyncio.Queue is a great pattern. However, this logic is duplicated across ParallelActiveLearner, ParallelReinforcementLearner, and ParallelUQLearner. This duplication increases maintenance overhead, as any future bug fixes or enhancements to this streaming mechanism will need to be applied in three different places.
To improve maintainability, I suggest refactoring this common logic into a shared helper function or a method in a base class. This centralizes the parallel execution pattern, making the code cleaner and easier to manage.
Here's a conceptual example of how such a helper could be structured:
```python
# In a shared utility module or a base class
async def _run_parallel_stream(
    learner_runners: list[Coroutine]
) -> AsyncIterator[IterationState]:
    """
    Runs multiple learner coroutines in parallel and streams their states.

    Each coroutine is expected to be a wrapper that runs a learner and
    puts ('state', state), ('error', e), and ('done', None) tuples
    into a shared queue.
    """
    # ... implementation with queue, task creation, and result streaming ...
```

Each start method would then be responsible for preparing its specific run_learner coroutines and passing them to this shared helper. This would significantly reduce code duplication.
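Fleshed out, such a helper could look like the following runnable sketch. The name `_run_parallel_stream` and the `('state', …)` / `('error', …)` / `('done', …)` tuple protocol follow the conceptual example above; the `make_run` worker factory is purely illustrative:

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable


async def _run_parallel_stream(
    run_fns: list[Callable[[asyncio.Queue], Awaitable[None]]],
) -> AsyncIterator[Any]:
    """Run learner coroutines in parallel and stream their states.

    Each run_fn receives the shared queue and must put ('state', s) tuples
    while running, ('error', e) on failure, and ('done', None) when finished.
    """
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [asyncio.create_task(fn(queue)) for fn in run_fns]
    completed = 0
    first_error: Exception | None = None
    while completed < len(run_fns):
        kind, value = await queue.get()
        if kind == "done":
            completed += 1
        elif kind == "state":
            yield value  # stream each state as soon as it arrives
        elif kind == "error" and first_error is None:
            first_error = value
    await asyncio.gather(*tasks, return_exceptions=True)
    if first_error is not None:
        raise first_error


def make_run(i: int) -> Callable[[asyncio.Queue], Awaitable[None]]:
    # Hypothetical per-learner wrapper feeding the shared queue
    async def run(queue: asyncio.Queue) -> None:
        try:
            await queue.put(("state", i))
        finally:
            await queue.put(("done", None))
    return run


async def demo() -> list[int]:
    return [s async for s in _run_parallel_stream([make_run(i) for i in range(3)])]


print(sorted(asyncio.run(demo())))  # [0, 1, 2]
```

With this in place, each parallel learner's start method only has to build its own run_fns list; the queue wiring and error propagation live in one spot.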
```python
# Verify we got states from both learners (2 iterations each = 4 states)
assert len(states) > 0
```
The comment on line 63, `# Verify we got states from both learners (2 iterations each = 4 states)`, seems incorrect given the test setup. The stop criterion check_mse returns 0.05, which is less than the threshold of 0.1. This will cause each of the two learners to stop after the first iteration. Therefore, we should expect exactly 2 states in total.
The assertion on line 64, assert len(states) > 0, is also weaker than it could be. A more precise assertion would verify the exact number of expected states.
I suggest updating the comment and strengthening the assertion to reflect the expected behavior.
```diff
-# Verify we got states from both learners (2 iterations each = 4 states)
-assert len(states) > 0
+# Verify we got states from both learners (1 iteration each = 2 states)
+assert len(states) == 2
```
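The expected count can be checked with a small self-contained sketch that mimics the stop criterion; the `fake_start` generator is hypothetical, and the constants mirror the values described in the review:

```python
import asyncio

MSE = 0.05        # value returned by check_mse in the test setup
THRESHOLD = 0.1   # stop criterion threshold


async def fake_start(max_iter: int = 2, parallel_learners: int = 2):
    # Each learner checks the stop criterion after every iteration
    for learner_id in range(parallel_learners):
        for iteration in range(max_iter):
            yield (learner_id, iteration)
            if MSE < THRESHOLD:
                break  # criterion met after the first iteration


async def collect():
    return [s async for s in fake_start()]


states = asyncio.run(collect())
print(len(states))  # 2 learners x 1 iteration each = 2 states
```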
…learners
- Pass learner_configs[learner_id] directly to start(), removing the sequential_config intermediate variable
- Add factory comment explaining the closure-in-loop pitfall
- Fix print typo: ActiveLearner-{id}] → [ActiveLearner-{id}]
- Apply the same changes to reinforcement_learner
- Remove sequential_config and add factory comment in uq_active_learner
- Update unit test assertion to match corrected print format
@gemini-code-assist review
Code Review
This pull request refactors the parallel learners to use an asynchronous streaming API, aligning them with sequential learners and improving real-time result handling. While this improves the API and promotes code reuse with the _stream_parallel helper, it introduces a potential resource exhaustion (DoS) vulnerability due to an unbounded queue and task creation in rose/learner.py. Furthermore, the parallel learner API passes unsanitized configuration arguments to task functions in rose/al/, rose/rl/, and rose/uq/, which could lead to command injection if the configuration is from an untrusted source. It is recommended to implement backpressure and task limits in the streaming logic, add robust sanitization or clear security warnings for task argument handling, and improve the consistency of error logging across different learner implementations.
```python
max_iter=max_iter,
skip_pre_loop=skip_pre_loop,
skip_simulation_step=skip_simulation_step,
initial_config=learner_configs[learner_id],
```
The ParallelActiveLearner.start method accepts learner_configs which contain arbitrary args and kwargs. These arguments are passed directly to sequential learners (line 482) and eventually to task functions without any sanitization. As shown in the provided examples, these task functions often use these arguments to construct shell command strings (e.g., using f-strings). If the configuration is provided by an untrusted source, such as an LLM agent (as suggested in the docstrings), an attacker can inject malicious shell commands. This is a form of Insecure Output Handling (LLM-02) where LLM-generated content is passed to a sensitive sink (shell execution) without validation. Consider implementing sanitization for task arguments (e.g., using shlex.quote) or providing a safe API for command construction that accepts a list of arguments.
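A minimal sketch of the two mitigations mentioned here — note that the `score.py` command and the `sample_path` parameter are hypothetical examples, not part of the library:

```python
import shlex


def unsafe_cmd(sample_path: str) -> str:
    # Vulnerable: an untrusted value is interpolated straight into a shell string
    return f"python score.py {sample_path}"


def safe_cmd(sample_path: str) -> str:
    # shlex.quote neutralizes shell metacharacters in the untrusted value
    return f"python score.py {shlex.quote(sample_path)}"


malicious = "data.csv; rm -rf ~"
print(unsafe_cmd(malicious))  # python score.py data.csv; rm -rf ~
print(safe_cmd(malicious))    # python score.py 'data.csv; rm -rf ~'

# Safer still: build an argv list and skip the shell entirely, e.g.
# subprocess.run(["python", "score.py", sample_path])
# so the untrusted value is always treated as a single argument.
```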
```python
async for state in sequential_learner.start(
    max_iter=max_iter,
    skip_pre_loop=skip_pre_loop,
    initial_config=learner_configs[learner_id],
```
The ParallelReinforcementLearner.start method passes potentially untrusted learner_configs to sequential learners. These configurations contain args and kwargs that are used by task functions to construct commands. Without proper sanitization, this pattern is vulnerable to command injection if the configuration originates from an untrusted source like an LLM agent. Ensure that all task arguments are properly sanitized before being used in shell commands or other sensitive sinks.
```python
num_predictions=num_predictions,
max_iter=max_iter,
skip_pre_loop=skip_pre_loop,
learning_config=learner_configs[learner_name],
```
The ParallelUQLearner.start method passes learner_configs to sequential learners, which eventually use the contained args and kwargs in task functions. If these task functions construct shell commands using these unsanitized arguments, it creates a command injection vulnerability. This is particularly relevant given the intended use with LLM agents, which could be manipulated into providing malicious configurations. Implement strict validation or sanitization for all task arguments.
```python
queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
tasks = [asyncio.create_task(fn(queue)) for fn in run_fns]
```
The _stream_parallel function introduces a potential resource exhaustion (Denial of Service) vulnerability. It creates an unbounded asyncio.Queue (line 135) and an unbounded number of concurrent tasks using asyncio.create_task (line 136). If a large number of parallel learners or iterations are configured, this can lead to memory exhaustion as results accumulate in the queue without backpressure, and CPU/memory exhaustion from the excessive number of concurrent tasks. Additionally, asyncio.gather (line 150) is called on all tasks at once, which can further strain resources. Consider implementing a limit on the number of concurrent tasks (e.g., using a semaphore) and setting a maxsize for the asyncio.Queue to provide backpressure.
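Both mitigations can be sketched as follows. This is an illustrative variant of a queue-streaming helper, not the library's `_stream_parallel`; the concurrency and queue limits are arbitrary example values:

```python
import asyncio

MAX_CONCURRENT = 4   # illustrative cap on simultaneously running learners
QUEUE_MAXSIZE = 16   # bounded queue: producers block when consumers fall behind


async def bounded_stream(run_fns):
    queue: asyncio.Queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE)
    sem = asyncio.Semaphore(MAX_CONCURRENT)

    async def guarded(fn):
        async with sem:  # at most MAX_CONCURRENT learners run at once
            await fn(queue)

    tasks = [asyncio.create_task(guarded(fn)) for fn in run_fns]
    completed = 0
    while completed < len(run_fns):
        kind, value = await queue.get()
        if kind == "done":
            completed += 1
        elif kind == "state":
            yield value
    await asyncio.gather(*tasks, return_exceptions=True)


def make_run(i):
    # Hypothetical worker: put() blocks when the bounded queue is full,
    # which is exactly the backpressure the bounded queue provides
    async def run(queue):
        try:
            await queue.put(("state", i))
        finally:
            await queue.put(("done", None))
    return run


async def demo():
    return [s async for s in bounded_stream([make_run(i) for i in range(8)])]


print(sorted(asyncio.run(demo())))  # [0, 1, 2, 3, 4, 5, 6, 7]
```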
```python
        sequential_learner.metric_values_per_iteration
    )
except Exception as e:
    print(f"[ActiveLearner-{learner_id}] failed with error: {e}")
```
The error logging format is inconsistent across the parallel learners (ParallelActiveLearner, ParallelReinforcementLearner, ParallelUQLearner). For better maintainability, it would be great to standardize it. A good approach is to use the class name dynamically. This change should be applied consistently across all parallel learners and their corresponding tests.
```diff
-print(f"[ActiveLearner-{learner_id}] failed with error: {e}")
+print(f"[{self.__class__.__name__}-{learner_id}] failed with error: {e}")
```
```python
        sequential_learner.metric_values_per_iteration
    )
except Exception as e:
    print(f"[RLLearner-{learner_id}] failed with error: {e}")
```
```python
        sequential_learner.uncertainty_values_per_iteration
    )
except Exception as e:
    print(f"[Parallel-Learner-{learner_name}] Failed with error: {e}")
```
```diff
 # Verify error was printed (learner 1 fails, not 0)
-mock_print.assert_any_call("ActiveLearner-1] failed with error: Learner failed")
+mock_print.assert_any_call("[ActiveLearner-1] failed with error: Learner failed")
```
This assertion needs to be updated to match the suggested change in the error message format in rose/al/active_learner.py.
```diff
-mock_print.assert_any_call("[ActiveLearner-1] failed with error: Learner failed")
+mock_print.assert_any_call("[ParallelActiveLearner-1] failed with error: Learner failed")
```
```python
mock_print.assert_any_call(
    "[Parallel-Learner-l1] Failed with error: Learner failed"
)
```
This assertion needs to be updated to match the suggested change in the error message format in rose/uq/uq_active_learner.py.
```diff
-mock_print.assert_any_call(
-    "[Parallel-Learner-l1] Failed with error: Learner failed"
-)
+mock_print.assert_any_call(
+    "[ParallelUQLearner-l1] failed with error: Learner failed"
+)
```