Conversation
1. Base tracking capabilities
2. Plug-and-play (PnP) approach for trackers
3. Native file-based experiment tracker
4. MLflow experiment tracker
5. ClearML experiment tracker
6. Extend examples
7. Extend docs
1. Test suite for tracking
2. Examples for tracking
3. Update docs
1. Modify the approach so it aligns with ROSE working on the OUTER loop of the learning, not the INNER loop
2. Update docs
3. Update examples
4. Update tests
/gemini summary
Summary of Changes

This pull request introduces a robust and flexible experiment tracking system to ROSE, significantly enhancing its observability capabilities by implementing a pluggable tracker interface.

Highlights
Changelog
Activity
@gemini-code-assist review
Code Review
This pull request introduces a well-designed and comprehensive experiment tracking system to ROSE, featuring a new TrackerBase protocol and integrations for file-based logging, MLflow, and ClearML. While the design, documentation, and test coverage are strong, a critical security vulnerability has been identified: a high-severity insecure deserialization issue in example scripts due to pickle usage on predictable file paths in /tmp. Additionally, there's a medium-severity concern regarding potential leakage of sensitive information from task decorators to tracking services. Addressing these security concerns is paramount, and further areas for improvement are detailed in specific comments.
```python
def load_state() -> dict:
    with open(DATA_FILE, "rb") as f:
        return pickle.load(f)
```
The use of pickle.load() on a file located in a world-writable directory (/tmp) with a predictable filename is insecure. A local attacker could place a malicious pickle file at this location to achieve arbitrary code execution when this example script is run. This is particularly risky in shared environments like HPC clusters. Consider using a safer serialization format like JSON or ensuring the file is created with secure permissions and a non-predictable name.
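To make the recommendation concrete, here is a sketch of the JSON alternative. `DATA_FILE` and the stored fields are assumptions mirroring the snippet above; the per-user path is illustrative, not part of this PR:

```python
import json
from pathlib import Path

# Hypothetical replacement for the example's DATA_FILE: a per-user,
# non-world-writable location avoids the /tmp tampering risk.
DATA_FILE = Path.home() / ".rose_example" / "state.json"


def save_state(state: dict) -> None:
    DATA_FILE.parent.mkdir(parents=True, exist_ok=True)
    # JSON cannot execute code on load, unlike pickle.
    DATA_FILE.write_text(json.dumps(state))


def load_state() -> dict:
    return json.loads(DATA_FILE.read_text())
```

Because JSON parsing never runs arbitrary code, a tampered file can at worst raise a parse error rather than compromise the host.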
```python
def load_state(name: str) -> dict:
    path = Path(str(DATA_FILE).format(name=name))
    with open(path, "rb") as f:
        return pickle.load(f)
```
The use of pickle.load() on a file located in a world-writable directory (/tmp) with a predictable filename is insecure. A local attacker could place a malicious pickle file at this location to achieve arbitrary code execution when this example script is run. Consider using a safer serialization format like JSON or ensuring the file is created with secure permissions and a non-predictable name.
```python
def load_state() -> dict:
    with open(DATA_FILE, "rb") as f:
        return pickle.load(f)
```
The use of pickle.load() on a file located in a world-writable directory (/tmp) with a predictable filename is insecure. A local attacker could place a malicious pickle file at this location to achieve arbitrary code execution when this example script is run. Consider using a safer serialization format like JSON or ensuring the file is created with secure permissions and a non-predictable name.
rose/integrations/clearml_tracker.py
Outdated
```python
for k, v in task_manifest.decor_kwargs.items():
    params[f"task/{task_key}/{k}"] = v
```
The tracker logs all keyword arguments from task decorators (decor_kwargs) to ClearML. If a user passes sensitive information (e.g., API keys, tokens) to a decorator, it will be logged and stored in the tracking service. It is recommended to filter out common sensitive keys before logging.
```diff
-for k, v in task_manifest.decor_kwargs.items():
-    params[f"task/{task_key}/{k}"] = v
+for k, v in task_manifest.decor_kwargs.items():
+    if k.lower() not in {"api_key", "password", "secret", "token"}:
+        params[f"task/{task_key}/{k}"] = v
```
rose/integrations/mlflow_tracker.py
Outdated
```python
for k, v in task_manifest.decor_kwargs.items():
    params[f"task.{task_key}.{k}"] = v
```
The tracker logs all keyword arguments from task decorators (decor_kwargs) to MLflow. If a user passes sensitive information (e.g., API keys, tokens) to a decorator, it will be logged and stored in the tracking service. It is recommended to filter out common sensitive keys before logging.
```diff
-for k, v in task_manifest.decor_kwargs.items():
-    params[f"task.{task_key}.{k}"] = v
+for k, v in task_manifest.decor_kwargs.items():
+    if k.lower() not in {"api_key", "password", "secret", "token"}:
+        params[f"task.{task_key}.{k}"] = v
```
```python
def on_stop(self, final_state, reason: str) -> None:
    super().on_stop(final_state, reason)
    if final_state and reason in ("criterion_met", "max_iter_reached"):
        model = load_model(final_state.get("checkpoint_path"))
        mlflow.sklearn.log_model(model, artifact_path="surrogate_model")
```
The example for extending MLflowTracker has a small bug. It calls super().on_stop() before logging the model artifact. The base on_stop method calls mlflow.end_run(), which terminates the MLflow run. Any subsequent calls to log artifacts will either fail or start a new, separate run.
To ensure all logging happens within the same active run, the super().on_stop() call should be moved to the end of the method, after the artifact has been logged. This pattern is correctly used in the ClearMLTracker extension example.
```diff
-def on_stop(self, final_state, reason: str) -> None:
-    super().on_stop(final_state, reason)
-    if final_state and reason in ("criterion_met", "max_iter_reached"):
-        model = load_model(final_state.get("checkpoint_path"))
-        mlflow.sklearn.log_model(model, artifact_path="surrogate_model")
+class MLflowArtifactTracker(MLflowTracker):
+    def on_stop(self, final_state, reason: str) -> None:
+        if final_state and reason in ("criterion_met", "max_iter_reached"):
+            model = load_model(final_state.get("checkpoint_path"))
+            mlflow.sklearn.log_model(model, artifact_path="surrogate_model")
+        super().on_stop(final_state, reason)
```
```python
_stop_reason = "max_iter_reached"
try:
    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
        self._notify_trackers_iteration(state)
        yield state
finally:
    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
The _stop_reason for the ParallelActiveLearner is initialized to "max_iter_reached" and is never updated. If the user breaks out of the async for loop, the finally block will execute and _notify_trackers_stop will be called with the incorrect reason. It should report "stopped" in this case.
A better pattern would be to default the reason to "stopped" and only update it to "max_iter_reached" if the loop completes fully.
```diff
-_stop_reason = "max_iter_reached"
-try:
-    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
-        self._notify_trackers_iteration(state)
-        yield state
-finally:
-    self._notify_trackers_stop(self._iteration_state, _stop_reason)
+_stop_reason = "stopped"
+try:
+    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
+        self._notify_trackers_iteration(state)
+        yield state
+    _stop_reason = "max_iter_reached"
+finally:
+    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
```python
_stop_reason = "max_iter_reached"
try:
    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
        self._notify_trackers_iteration(state)
        yield state
finally:
    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
The _stop_reason for the ParallelReinforcementLearner is initialized to "max_iter_reached" and is never updated. If the user breaks out of the async for loop, the finally block will execute and _notify_trackers_stop will be called with the incorrect reason. It should report "stopped" in this case.
A better pattern would be to default the reason to "stopped" and only update it to "max_iter_reached" if the loop completes fully.
```diff
-_stop_reason = "max_iter_reached"
-try:
-    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
-        self._notify_trackers_iteration(state)
-        yield state
-finally:
-    self._notify_trackers_stop(self._iteration_state, _stop_reason)
+_stop_reason = "stopped"
+try:
+    async for state in _stream_parallel([make_run_fn(i) for i in range(parallel_learners)]):
+        self._notify_trackers_iteration(state)
+        yield state
+    _stop_reason = "max_iter_reached"
+finally:
+    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
```python
_stop_reason = "max_iter_reached"
try:
    async for state in _stream_parallel([make_run_fn(name) for name in learner_names]):
        self._notify_trackers_iteration(state)
        yield state
finally:
    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
The _stop_reason for the ParallelUQLearner is initialized to "max_iter_reached" and is never updated. If the user breaks out of the async for loop, the finally block will execute and _notify_trackers_stop will be called with the incorrect reason. It should report "stopped" in this case.
A better pattern would be to default the reason to "stopped" and only update it to "max_iter_reached" if the loop completes fully.
```diff
-_stop_reason = "max_iter_reached"
-try:
-    async for state in _stream_parallel([make_run_fn(name) for name in learner_names]):
-        self._notify_trackers_iteration(state)
-        yield state
-finally:
-    self._notify_trackers_stop(self._iteration_state, _stop_reason)
+_stop_reason = "stopped"
+try:
+    async for state in _stream_parallel([make_run_fn(name) for name in learner_names]):
+        self._notify_trackers_iteration(state)
+        yield state
+    _stop_reason = "max_iter_reached"
+finally:
+    self._notify_trackers_stop(self._iteration_state, _stop_reason)
```
1. Isolate the tracked parameters from the decorator kwargs and use `log_params`.
2. Make the examples more realistic.
3. Update tests.
4. Test with the MLflow and ClearML UIs.
1. Fix 1 (ClearML series names): `ClearMLTracker` adds an optional `learner_names` parameter, improves series naming, and fixes a bug where `learner_id=0` was incorrectly mapped to `value`.
2. Fix 2 (non-numeric config logging): add string config logging for MLflow (via tags) and ClearML (via connected hyperparameters), and move `_TASK_NAMES` to module level to avoid recreating it each iteration.
This PR introduces experiment tracking capabilities.
Tracking is supported only at the learner level, specifically in the outer loop of the learning approach.
MLflow
ClearML
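To illustrate the hook surface at the outer-loop level, a minimal custom tracker might look like the sketch below. The hook names are assumptions inferred from the snippets in this review (`on_stop(final_state, reason)` appears directly; a per-iteration hook is implied by `_notify_trackers_iteration`), and the exact `TrackerBase` protocol in ROSE may differ:

```python
class HistoryTracker:
    """Minimal outer-loop tracker: records each iteration state and the stop reason."""

    def __init__(self):
        self.iterations = []
        self.stop = None

    def on_iteration(self, state) -> None:
        # Called once per outer-loop iteration with the current learner state.
        self.iterations.append(state)

    def on_stop(self, final_state, reason: str) -> None:
        # reason is e.g. "criterion_met", "max_iter_reached", or "stopped".
        self.stop = (final_state, reason)
```

Keeping trackers this small is what makes the plug-and-play design work: the learner drives the hooks, and each backend (file, MLflow, ClearML) only decides where the data goes.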