Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/holosoma/holosoma/config_types/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from pydantic.dataclasses import dataclass

from holosoma.config_types.action import ActionManagerCfg
from holosoma.config_types.command import CommandManagerCfg
from holosoma.config_types.curriculum import CurriculumManagerCfg
Expand All @@ -14,6 +12,7 @@
from holosoma.config_types.simulator import SimulatorConfig
from holosoma.config_types.termination import TerminationManagerCfg
from holosoma.config_types.terrain import TerrainManagerCfg
from pydantic.dataclasses import dataclass


@dataclass(frozen=True)
Expand All @@ -34,6 +33,7 @@ class EnvConfig:
robot: RobotConfig
training: TrainingConfig
logger: LoggerConfig
experiment_dir: str | None = None


def get_tyro_env_config(tyro_config: ExperimentConfig) -> EnvConfig:
Expand Down
21 changes: 13 additions & 8 deletions src/holosoma/holosoma/envs/base_task/base_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import numpy as np
from pathlib import Path

import numpy as np
from holosoma.config_types.env import EnvConfig
from holosoma.config_types.full_sim import FullSimConfig
from holosoma.managers.action import ActionManager
Expand Down Expand Up @@ -76,13 +77,17 @@ def __init__(
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)

# Compute experiment directory from logger config
from holosoma.utils.experiment_paths import get_experiment_dir, get_timestamp

timestamp = get_timestamp()
experiment_dir = get_experiment_dir(
tyro_config.logger, tyro_config.training, timestamp, task_name=self._get_task_name()
)
# Use pre-computed experiment directory if provided (from train_agent.py),
# otherwise compute one (for replay, eval, etc.)
if tyro_config.experiment_dir is not None:
experiment_dir = Path(tyro_config.experiment_dir)
else:
from holosoma.utils.experiment_paths import get_experiment_dir, get_timestamp # noqa: PLC0415

timestamp = get_timestamp()
experiment_dir = get_experiment_dir(
tyro_config.logger, tyro_config.training, timestamp, task_name=self._get_task_name()
)

SimulatorClass = get_class(simulator_config._target_)
full_sim_config = FullSimConfig(
Expand Down
1 change: 1 addition & 0 deletions src/holosoma/holosoma/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def train(tyro_config: ExperimentConfig, training_context: TrainingContext | Non
env_target = tyro_config.env_class

tyro_env_config = get_tyro_env_config(tyro_config)
tyro_env_config = dataclasses.replace(tyro_env_config, experiment_dir=str(experiment_dir))
env = get_class(env_target)(tyro_env_config, device=device)

# For manager system, pre-process config AFTER env creation
Expand Down
Loading