diff --git a/src/holosoma/holosoma/config_types/env.py b/src/holosoma/holosoma/config_types/env.py index a296d5c2..01d20c55 100644 --- a/src/holosoma/holosoma/config_types/env.py +++ b/src/holosoma/holosoma/config_types/env.py @@ -34,6 +34,7 @@ class EnvConfig: robot: RobotConfig training: TrainingConfig logger: LoggerConfig + experiment_dir: str | None = None def get_tyro_env_config(tyro_config: ExperimentConfig) -> EnvConfig: diff --git a/src/holosoma/holosoma/envs/base_task/base_task.py b/src/holosoma/holosoma/envs/base_task/base_task.py index e46675a6..9e23cd5a 100644 --- a/src/holosoma/holosoma/envs/base_task/base_task.py +++ b/src/holosoma/holosoma/envs/base_task/base_task.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pathlib import Path + import numpy as np from holosoma.config_types.env import EnvConfig @@ -14,6 +16,7 @@ from holosoma.managers.termination import TerminationManager from holosoma.managers.terrain import TerrainManager from holosoma.simulator.base_simulator.base_simulator import BaseSimulator +from holosoma.utils.experiment_paths import get_experiment_dir, get_timestamp from holosoma.utils.helpers import get_class from holosoma.utils.safe_torch_import import torch from holosoma.utils.torch_utils import to_torch @@ -76,13 +79,15 @@ def __init__( torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) - # Compute experiment directory from logger config - from holosoma.utils.experiment_paths import get_experiment_dir, get_timestamp - - timestamp = get_timestamp() - experiment_dir = get_experiment_dir( - tyro_config.logger, tyro_config.training, timestamp, task_name=self._get_task_name() - ) + # Use pre-computed experiment directory if provided (from train_agent.py), + # otherwise compute one (for replay, eval, etc.) + if tyro_config.experiment_dir is not None: + experiment_dir = Path(tyro_config.experiment_dir) + else: + timestamp = get_timestamp() + experiment_dir = get_experiment_dir( + tyro_config.logger, tyro_config.training, timestamp, task_name=self._get_task_name() + ) SimulatorClass = get_class(simulator_config._target_) full_sim_config = FullSimConfig( diff --git a/src/holosoma/holosoma/train_agent.py b/src/holosoma/holosoma/train_agent.py index a6a40546..44e918f2 100644 --- a/src/holosoma/holosoma/train_agent.py +++ b/src/holosoma/holosoma/train_agent.py @@ -249,6 +249,7 @@ def train(tyro_config: ExperimentConfig, training_context: TrainingContext | Non env_target = tyro_config.env_class tyro_env_config = get_tyro_env_config(tyro_config) + tyro_env_config = dataclasses.replace(tyro_env_config, experiment_dir=str(experiment_dir)) env = get_class(env_target)(tyro_env_config, device=device) # For manager system, pre-process config AFTER env creation