diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index f6164ed1d..c5919ccb7 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -16,27 +16,46 @@
 import logging
 
 import torch
-from omegaconf import ListConfig, OmegaConf
 
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
-from lerobot.common.datasets.transforms import get_image_transforms
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+    LeRobotDatasetMetadata,
+    MultiLeRobotDataset,
+)
+from lerobot.common.datasets.transforms import ImageTransforms
+from lerobot.configs.default import MainConfig
+from lerobot.configs.policies import PretrainedConfig
+
+IMAGENET_STATS = {
+    "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
+    "std": [[[0.229]], [[0.224]], [[0.225]]],  # (c,1,1)
+}
 
-def resolve_delta_timestamps(cfg):
+
+def resolve_delta_timestamps(
+    cfg: PretrainedConfig, ds_meta: LeRobotDatasetMetadata
+) -> dict[str, list] | None:
-    """Resolves delta_timestamps config key (in-place) by using `eval`.
-
-    Doesn't do anything if delta_timestamps is not specified or has already been resolve (as
-    evidenced by the data type of its values).
-    """
+    """Resolves delta timestamps (in seconds) for each dataset feature by dividing the policy's
+    delta indices by the dataset fps.
+
+    Returns None if the policy defines no delta indices.
+    """
-    delta_timestamps = cfg.training.get("delta_timestamps")
-    if delta_timestamps is not None:
-        for key in delta_timestamps:
-            if isinstance(delta_timestamps[key], str):
-                # TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
-                cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
+    delta_timestamps = {}
+    for key in ds_meta.features:
+        if key == "next.reward" and cfg.reward_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
+        if key == "action" and cfg.action_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
+        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
+            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
+
+    if len(delta_timestamps) == 0:
+        delta_timestamps = None
+
+    return delta_timestamps
 
 
-def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
+def make_dataset(cfg: MainConfig, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
     """
     Args:
-        cfg: A Hydra config as per the LeRobot config scheme.
+        cfg: A MainConfig (draccus dataclass) as per the LeRobot config scheme.
@@ -50,67 +69,58 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     Returns:
         The LeRobotDataset.
     """
-    if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
-        raise ValueError(
-            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
-            "strings to load multiple datasets."
-        )
-
     # A soft check to warn if the environment doesn't match the dataset. Don't check if we are using a real world env (dora).
-    if cfg.env.name != "dora":
-        if isinstance(cfg.dataset_repo_id, str):
-            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
+    if cfg.env.type != "dora":
+        if isinstance(cfg.dataset.repo_id, str):
+            dataset_repo_ids = [cfg.dataset.repo_id]  # single dataset
+        elif isinstance(cfg.dataset.repo_id, list):
+            dataset_repo_ids = cfg.dataset.repo_id  # multiple datasets
         else:
-            dataset_repo_ids = cfg.dataset_repo_id  # multiple datasets
+            raise ValueError(
+                "Expected cfg.dataset.repo_id to be either a single string to load one dataset or a list of "
+                "strings to load multiple datasets."
+            )
 
     for dataset_repo_id in dataset_repo_ids:
-        if cfg.env.name not in dataset_repo_id:
+        if cfg.env.type not in dataset_repo_id:
             logging.warning(
                 f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
-                f"environment ({cfg.env.name=})."
+                f"environment ({cfg.env.type=})."
             )
 
-    resolve_delta_timestamps(cfg)
-
-    image_transforms = None
-    if cfg.training.image_transforms.enable:
-        cfg_tf = cfg.training.image_transforms
-        image_transforms = get_image_transforms(
-            brightness_weight=cfg_tf.brightness.weight,
-            brightness_min_max=cfg_tf.brightness.min_max,
-            contrast_weight=cfg_tf.contrast.weight,
-            contrast_min_max=cfg_tf.contrast.min_max,
-            saturation_weight=cfg_tf.saturation.weight,
-            saturation_min_max=cfg_tf.saturation.min_max,
-            hue_weight=cfg_tf.hue.weight,
-            hue_min_max=cfg_tf.hue.min_max,
-            sharpness_weight=cfg_tf.sharpness.weight,
-            sharpness_min_max=cfg_tf.sharpness.min_max,
-            max_num_transforms=cfg_tf.max_num_transforms,
-            random_order=cfg_tf.random_order,
-        )
+    image_transforms = (
+        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
+    )
 
-    if isinstance(cfg.dataset_repo_id, str):
-        # TODO (aliberts): add 'episodes' arg from config after removing hydra
+    if isinstance(cfg.dataset.repo_id, str):
+        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
+        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
         dataset = LeRobotDataset(
-            cfg.dataset_repo_id,
-            delta_timestamps=cfg.training.get("delta_timestamps"),
+            cfg.dataset.repo_id,
+            episodes=cfg.dataset.episodes,
+            delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
             video_backend=cfg.video_backend,
+            local_files_only=cfg.dataset.local_files_only,
         )
     else:
         dataset = MultiLeRobotDataset(
-            cfg.dataset_repo_id,
-            delta_timestamps=cfg.training.get("delta_timestamps"),
+            cfg.dataset.repo_id,
+            # TODO(aliberts): add proper support for multi dataset
+            # delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
             video_backend=cfg.video_backend,
         )
 
-    if cfg.get("override_dataset_stats"):
-        for key, stats_dict in cfg.override_dataset_stats.items():
-            for stats_type, listconfig in stats_dict.items():
-                # example of stats_type: min, max, mean, std
-                stats = OmegaConf.to_container(listconfig, resolve=True)
+    if cfg.dataset.use_imagenet_stats:
+        for key in dataset.meta.camera_keys:
+            for stats_type, stats in IMAGENET_STATS.items():
                 dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
 
     return dataset
diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index 899f0d66c..696ccb02e 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Sequence
 
 import torch
@@ -137,6 +138,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
 
 
+# TODO(aliberts): Remove
 def get_image_transforms(
     brightness_weight: float = 1.0,
     brightness_min_max: tuple[float, float] | None = None,
@@ -195,3 +197,116 @@ def check_value(name, weight, min_max):
     else:
         # TODO(rcadene, aliberts): add v2.ToDtype float16?
         return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
+
+
+@dataclass
+class ImageTransformConfig:
+    """
+    For each transform, the following parameters are available:
+      weight: The relative weight used when sampling this transform (multinomial, without replacement).
+          If the weights don't sum to 1, they are normalized.
+      type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
+          custom transform defined here.
+      kwargs: Keyword arguments passed to the transform's constructor, e.g. the lower & upper bounds of the
+          parameter range that is sampled (uniformly) each time the transform is applied.
+    """
+
+    weight: float = 1.0
+    type: str = "Identity"
+    kwargs: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class ImageTransformsConfig:
+    """
+    These transforms all use standard torchvision.transforms.v2.
+    You can find out how these transformations affect images here:
+    https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
+    We use a custom RandomSubsetApply container to sample them.
+    """
+
+    # Set this flag to `true` to enable transforms during training
+    enable: bool = False
+    # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
+    # It's an integer in the interval [1, number_of_available_transforms].
+    max_num_transforms: int = 3
+    # By default, transforms are applied in Torchvision's suggested order (shown below).
+    # Set this to True to apply them in a random order.
+    random_order: bool = False
+    tfs: list[ImageTransformConfig] = field(
+        default_factory=lambda: [
+            ImageTransformConfig(
+                weight=1.0,
+                type="ColorJitter",
+                kwargs={"brightness": (0.8, 1.2)},
+            ),
+            ImageTransformConfig(
+                weight=1.0,
+                type="ColorJitter",
+                kwargs={"contrast": (0.8, 1.2)},
+            ),
+            ImageTransformConfig(
+                weight=1.0,
+                type="ColorJitter",
+                kwargs={"saturation": (0.5, 1.5)},
+            ),
+            ImageTransformConfig(
+                weight=1.0,
+                type="ColorJitter",
+                kwargs={"hue": (-0.05, 0.05)},
+            ),
+            ImageTransformConfig(
+                weight=1.0,
+                type="SharpnessJitter",
+                kwargs={"sharpness": (0.5, 1.5)},
+            ),
+        ]
+    )
+
+
+class ImageTransforms(Transform):
+    """A class to compose image transforms based on configuration."""
+
+    _registry = {
+        "Identity": v2.Identity,
+        "ColorJitter": v2.ColorJitter,
+        "SharpnessJitter": SharpnessJitter,
+    }
+
+    def __init__(self, cfg: ImageTransformsConfig) -> None:
+        super().__init__()
+        self._cfg = cfg
+
+        weights = []
+        transforms = []
+        for tf_cfg in cfg.tfs:
+            if tf_cfg.weight <= 0.0:
+                continue
+
+            transform_cls = self._registry.get(tf_cfg.type)
+            if transform_cls is None:
+                available_transforms = ", ".join(self._registry.keys())
+                raise ValueError(
+                    f"Transform '{tf_cfg.type}' not found in the registry. "
" + f"Available transforms are: {available_transforms}" + ) + + # Instantiate the transform + transform_instance = transform_cls(**tf_cfg.kwargs) + transforms.append(transform_instance) + weights.append(tf_cfg.weight) + + n_subset = min(len(transforms), cfg.max_num_transforms) + if n_subset == 0 or not cfg.enable: + self.transform = v2.Identity() + else: + self.transform = RandomSubsetApply( + transforms=transforms, + p=weights, + n_subset=n_subset, + random_order=cfg.random_order, + ) + + def forward(self, *inputs: Any) -> Any: + return self.transform(*inputs) diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py new file mode 100644 index 000000000..490b487c0 --- /dev/null +++ b/lerobot/common/envs/__init__.py @@ -0,0 +1 @@ +from .configs import AlohaEnv, EnvConfig, RealEnv # noqa: F401 diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py new file mode 100644 index 000000000..b4414dee6 --- /dev/null +++ b/lerobot/common/envs/configs.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +import draccus + + +@dataclass +class EnvConfig(draccus.ChoiceRegistry): + task: str | None = None + state_dim: int = 18 + action_dim: int = 18 + fps: int = 30 + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@EnvConfig.register_subclass("real_world") +@dataclass +class RealEnv(EnvConfig): + pass + + +@EnvConfig.register_subclass("aloha") +@dataclass +class AlohaEnv(EnvConfig): + task: str = "AlohaInsertion-v0" + state_dim: int = 14 + action_dim: int = 14 + fps: int = 50 diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 3bd2df89a..b534c92b8 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -33,6 +33,7 @@ from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state +from lerobot.configs.default import MainConfig def log_output_dir(out_dir): @@ -42,9 +43,9 @@ def log_output_dir(out_dir): def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str: """Return a group name for logging. Optionally returns group name as list.""" lst = [ - f"policy:{cfg.policy.name}", - f"dataset:{cfg.dataset_repo_id}", - f"env:{cfg.env.name}", + f"policy:{cfg.policy.type}", + f"dataset:{cfg.dataset.repo_id}", + f"env:{cfg.env.type}", f"seed:{cfg.seed}", ] return lst if return_list else "-".join(lst) @@ -83,25 +84,18 @@ class Logger: pretrained_model_dir_name = "pretrained_model" training_state_file_name = "training_state.pth" - def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None): - """ - Args: - log_dir: The directory to save all logs and training outputs to. - job_name: The WandB job name. - """ + def __init__(self, cfg: MainConfig): self._cfg = cfg - self.log_dir = Path(log_dir) + self.log_dir = cfg.dir self.log_dir.mkdir(parents=True, exist_ok=True) - self.checkpoints_dir = self.get_checkpoints_dir(log_dir) - self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir) - self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir) + self.job_name = cfg.job_name + self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir) + self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir) + self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir) # Set up WandB. 
         self._group = cfg_to_group(cfg)
-        project = cfg.get("wandb", {}).get("project")
-        entity = cfg.get("wandb", {}).get("entity")
-        enable_wandb = cfg.get("wandb", {}).get("enable", False)
-        run_offline = not enable_wandb or not project
+        run_offline = not cfg.wandb.enable or not cfg.wandb.project
         if run_offline:
             logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
             self._wandb = None
@@ -115,12 +109,12 @@
             wandb.init(
                 id=wandb_run_id,
-                project=project,
-                entity=entity,
-                name=wandb_job_name,
-                notes=cfg.get("wandb", {}).get("notes"),
+                project=cfg.wandb.project,
+                entity=cfg.wandb.entity,
+                name=self.job_name,
+                notes=cfg.wandb.notes,
                 tags=cfg_to_group(cfg, return_list=True),
-                dir=log_dir,
+                dir=self.log_dir,
                 config=OmegaConf.to_container(cfg, resolve=True),
                 # TODO(rcadene): try set to True
                 save_code=False,
diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py
new file mode 100644
index 000000000..58db9849f
--- /dev/null
+++ b/lerobot/common/policies/__init__.py
@@ -0,0 +1,4 @@
+from .act.configuration_act import ACTConfig as ACTConfig
+from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
+from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
+from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index a86c359c9..5c0d20934 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -15,9 +15,12 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 
+from lerobot.configs.policies import PretrainedConfig
+
 
+@PretrainedConfig.register_subclass("act")
 @dataclass
-class ACTConfig:
+class ACTConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.
 
     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -169,3 +172,15 @@ def __post_init__(self):
             and "observation.environment_state" not in self.input_shapes
         ):
             raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+    @property
+    def observation_delta_indices(self) -> None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.chunk_size))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 531f49e4d..64a66dfd1 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -16,9 +16,12 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 
+from lerobot.configs.policies import PretrainedConfig
+
 
+@PretrainedConfig.register_subclass("diffusion")
 @dataclass
-class DiffusionConfig:
+class DiffusionConfig(PretrainedConfig):
     """Configuration class for DiffusionPolicy.
 
     Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -207,3 +210,15 @@ def __post_init__(self):
                 "The horizon should be an integer multiple of the downsampling factor (which is determined "
                 f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
             )
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index 4a5415a15..7705a9c0c 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -16,9 +16,12 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 
+from lerobot.configs.policies import PretrainedConfig
+
 
+@PretrainedConfig.register_subclass("tdmpc")
 @dataclass
-class TDMPCConfig:
+class TDMPCConfig(PretrainedConfig):
     """Configuration class for TDMPCPolicy.
 
     Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
@@ -102,6 +105,7 @@ class TDMPCConfig:
     """
 
     # Input / output structure.
+    n_obs_steps: int = 1
     n_action_repeats: int = 2
     horizon: int = 5
     n_action_steps: int = 1
@@ -185,6 +189,10 @@ def __post_init__(self):
                 f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
                 "information."
             )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+            )
         if self.n_action_steps > 1:
             if self.n_action_repeats != 1:
                 raise ValueError(
@@ -194,3 +202,15 @@
             raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
         if self.n_action_steps > self.horizon:
             raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(self.horizon + 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.horizon))
+
+    @property
+    def reward_delta_indices(self) -> list:
+        return list(range(self.horizon))
diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index dfe4684d2..4d8ce94cf 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -18,9 +18,12 @@
 from dataclasses import dataclass, field
 
+from lerobot.configs.policies import PretrainedConfig
+
 
+@PretrainedConfig.register_subclass("vqbet")
 @dataclass
-class VQBeTConfig:
+class VQBeTConfig(PretrainedConfig):
     """Configuration class for VQ-BeT.
 
     Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -165,3 +168,15 @@ def __post_init__(self):
                 f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
                 "expect all image shapes to match."
             )
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py
new file mode 100644
index 000000000..54c2b9302
--- /dev/null
+++ b/lerobot/configs/default.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime as dt
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from pprint import pformat
+
+import draccus
+from deepdiff import DeepDiff
+
+from lerobot.common import (
+    envs,  # noqa: F401
+    policies,  # noqa: F401
+)
+from lerobot.common.datasets.transforms import ImageTransformsConfig
+from lerobot.configs.policies import PretrainedConfig
+
+
+@dataclass
+class OfflineConfig:
+    steps: int = 100_000
+
+
+@dataclass
+class OnlineConfig:
+    """
+    The online training loop looks something like:
+
+    ```python
+    for i in range(steps):
+        do_online_rollout_and_update_online_buffer()
+        for j in range(steps_between_rollouts):
+            batch = next(dataloader_with_offline_and_online_data)
+            loss = policy(batch)
+            loss.backward()
+            optimizer.step()
+    ```
+
+    Note that the online training loop adopts most of the options from the offline loop unless specified
+    otherwise.
+    """
+
+    steps: int = 0
+    # How many episodes to collect at once when we reach the online rollout part of the training loop.
+    rollout_n_episodes: int = 1
+    # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
+    # the policy. Ideally you should set this to be an even divisor of rollout_n_episodes.
+    rollout_batch_size: int = 1
+    # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
+    steps_between_rollouts: int | None = None
+    # The proportion of online samples (vs offline samples) to include in the online training batches.
+    sampling_ratio: float = 0.5
+    # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
+    env_seed: int | None = None
+    # Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
+    # FIFO.
+    buffer_capacity: int | None = None
+    # The minimum number of frames to have in the online buffer before commencing online training.
+    # If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
+    # seed size condition is satisfied.
+    buffer_seed_size: int = 0
+    # Whether to run the online rollouts asynchronously. This means we can run the online training steps in
+    # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
+    # + eval + environment rendering simultaneously.
+    do_online_rollout_async: bool = False
+
+
+@dataclass
+class TrainConfig:
+    # Number of workers for the dataloader.
+    num_workers: int = 4
+    batch_size: int = 8
+    eval_freq: int = 20_000
+    log_freq: int = 200
+    save_checkpoint: bool = True
+    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+    save_freq: int = 20_000
+    offline: OfflineConfig = field(default_factory=OfflineConfig)
+    online: OnlineConfig = field(default_factory=OnlineConfig)
+
+
+@dataclass
+class DatasetConfig:
+    # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only
+    # data keys common between the datasets are kept. Each dataset gets an additional transform that inserts
+    # the "dataset_index" into the returned item. The index mapping is made according to the order in which
+    # the datasets are provided.
+    repo_id: str | list[str]
+    episodes: list[int] | None = None
+    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+    local_files_only: bool = False
+    use_imagenet_stats: bool = True
+
+
+@dataclass
+class EvalConfig:
+    n_episodes: int = 50
+    # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
+    batch_size: int = 50
+    # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
+    use_async_envs: bool = False
+
+    def __post_init__(self):
+        if self.batch_size > self.n_episodes:
+            raise ValueError(
+                "The eval batch size is greater than the number of eval episodes "
+                f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
+                f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
+                "This might significantly slow down evaluation. To fix this, you should update your command "
+                f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
+                f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
+            )
+
+
+@dataclass
+class WandBConfig:
+    enable: bool = False
+    # Set to true to disable saving an artifact despite training.save_checkpoint=True
+    disable_artifact: bool = False
+    project: str = "lerobot"
+    entity: str | None = None
+    notes: str | None = None
+
+
+@dataclass
+class MainConfig:
+    policy: PretrainedConfig
+    dataset: DatasetConfig
+    env: envs.EnvConfig = field(default_factory=envs.RealEnv)
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    dir: Path | None = None
+    job_name: str | None = None
+    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+    # `dir` is the directory of an existing run with at least one checkpoint in it.
+    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+    # regardless of what's provided with the training command at the time of resumption.
+    resume: bool = False
+    device: str = "cuda"  # | cpu | mps
+    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+    # automatic gradient scaling is used.
+    use_amp: bool = False
+    # `seed` is used for training (eg: model initialization, dataset shuffling)
+    # AND for the evaluation environments.
+    seed: int | None = None
+ video_backend: str = "pyav" + training: TrainConfig = field(default_factory=TrainConfig) + eval: EvalConfig = field(default_factory=EvalConfig) + wandb: WandBConfig = field(default_factory=WandBConfig) + + def __post_init__(self): + if not self.job_name: + self.job_name = f"{self.env.type}_{self.policy.type}" + + if not self.dir: + now = dt.datetime.now() + train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" + self.dir = Path("outputs/train") / train_dir + + if self.training.online.steps > 0 and isinstance(self.dataset.repo_id, list): + raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.") + + # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need + # to check for any differences between the provided config and the checkpoint's config. + checkpoint_cfg_path = self.dir / "checkpoints/last/config.yaml" + if self.resume: + if not checkpoint_cfg_path.exists(): + raise RuntimeError( + f"You have set resume=True, but there is no model checkpoint in {self.dir}" + ) + + # Get the configuration file from the last checkpoint. + checkpoint_cfg = self.from_checkpoint(checkpoint_cfg_path) + + # # Check for differences between the checkpoint configuration and provided configuration. + # # Hack to resolve the delta_timestamps ahead of time in order to properly diff. + # resolve_delta_timestamps(cfg) + diff = DeepDiff(checkpoint_cfg, self) + # Ignore the `resume` and parameters. + if "values_changed" in diff and "root['resume']" in diff["values_changed"]: + del diff["values_changed"]["root['resume']"] + # Log a warning about differences between the checkpoint configuration and the provided + # configuration. + if len(diff) > 0: + logging.warning( + "At least one difference was detected between the checkpoint configuration and " + f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration " + "takes precedence.", + ) + # Use the checkpoint config instead of the provided config (but keep `resume` parameter). + self = checkpoint_cfg + self.resume = True + + elif checkpoint_cfg_path.exists(): + raise RuntimeError( + f"The configured output directory {checkpoint_cfg_path} already exists. If " + "you meant to resume training, please use `resume=true` in your command or yaml configuration." + ) + + @classmethod + def from_checkpoint(cls, config_path: Path): + return draccus.load(cls, config_path) diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py new file mode 100644 index 000000000..7c7d80c16 --- /dev/null +++ b/lerobot/configs/policies.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass, field + +import draccus + + +@dataclass +class PretrainedConfig(draccus.ChoiceRegistry): + """ + Base configuration class for policy models. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + input_shapes: A dictionary defining the shapes of the input data for the policy. + output_shapes: A dictionary defining the shapes of the output data for the policy. + input_normalization_modes: A dictionary with key representing the modality and the value specifies the + normalization mode to apply. + output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to + the original scale. 
+ """ + + n_obs_steps: int = 1 + input_shapes: dict[str, list[int]] = field(default_factory=lambda: {}) + output_shapes: dict[str, list[int]] = field(default_factory=lambda: {}) + input_normalization_modes: dict[str, str] = field(default_factory=lambda: {}) + output_normalization_modes: dict[str, str] = field(default_factory=lambda: {}) + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + @property + def observation_delta_indices(self) -> list | None: + raise NotImplementedError + + @property + def action_delta_indices(self) -> list | None: + raise NotImplementedError + + @property + def reward_delta_indices(self) -> list | None: + raise NotImplementedError diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a0b7e4cb..d8d18b6ae 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -18,16 +18,15 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from copy import deepcopy +from dataclasses import asdict from pathlib import Path from pprint import pformat from threading import Lock +import draccus import hydra import numpy as np import torch -from deepdiff import DeepDiff -from omegaconf import DictConfig, ListConfig, OmegaConf -from termcolor import colored from torch import nn from torch.cuda.amp import GradScaler @@ -44,14 +43,14 @@ from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, - init_hydra_config, init_logging, set_global_seed, ) +from lerobot.configs.default import MainConfig from lerobot.scripts.eval import eval_policy -def make_optimizer_and_scheduler(cfg, policy): +def make_optimizer_and_scheduler(cfg: MainConfig, policy): if cfg.policy.name == "act": optimizer_params_dicts = [ { @@ -234,74 +233,76 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online): logger.log_dict(info, step, mode="eval") -def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): - if out_dir is None: - raise NotImplementedError() - if job_name is None: - raise NotImplementedError() +@draccus.wrap() +def train(cfg: MainConfig, out_dir: str | None = None, job_name: str | None = None): + # if out_dir is None: + # raise NotImplementedError() + # if job_name is None: + # raise NotImplementedError() init_logging() - logging.info(pformat(OmegaConf.to_container(cfg))) - - if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig): - raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.") - - # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need - # to check for any differences between the provided config and the checkpoint's config. - if cfg.resume: - if not Logger.get_last_checkpoint_dir(out_dir).exists(): - raise RuntimeError( - "You have set resume=True, but there is no model checkpoint in " - f"{Logger.get_last_checkpoint_dir(out_dir)}" - ) - checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") - logging.info( - colored( - "You have set resume=True, indicating that you wish to resume a run", - color="yellow", - attrs=["bold"], - ) - ) - # Get the configuration file from the last checkpoint. - checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) - # Check for differences between the checkpoint configuration and provided configuration. - # Hack to resolve the delta_timestamps ahead of time in order to properly diff. 
-        resolve_delta_timestamps(cfg)
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-        # Ignore the `resume` and parameters.
-        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-            del diff["values_changed"]["root['resume']"]
-        # Log a warning about differences between the checkpoint configuration and the provided
-        # configuration.
-        if len(diff) > 0:
-            logging.warning(
-                "At least one difference was detected between the checkpoint configuration and "
-                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
-                "takes precedence.",
-            )
-        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
-        cfg = checkpoint_cfg
-        cfg.resume = True
-    elif Logger.get_last_checkpoint_dir(out_dir).exists():
-        raise RuntimeError(
-            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
-            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
-        )
-
-    if cfg.eval.batch_size > cfg.eval.n_episodes:
-        raise ValueError(
-            "The eval batch size is greater than the number of eval episodes "
-            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
-            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
-            "This might significantly slow down evaluation. To fix this, you should update your command "
-            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
-            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
-        )
+    logging.info(pformat(asdict(cfg)))
+
+    # NOTE: The online-steps, resume, and eval-batch-size checks that used to live here have moved to
+    # MainConfig.__post_init__ (see lerobot/configs/default.py).
     # log metrics to terminal and wandb
-    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+    logger = Logger(cfg)
 
-    set_global_seed(cfg.seed)
+    if cfg.seed is not None:
+        set_global_seed(cfg.seed)
 
     # Check device is available
     device = get_safe_torch_device(cfg.device, log=True)
@@ -666,4 +667,5 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
 
 if __name__ == "__main__":
-    train_cli()
+    train()
diff --git a/poetry.lock b/poetry.lock
index 8799e67ca..f8d1cf202 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1303,6 +1303,27 @@ files = [
 [package.dependencies]
 pyarrow = "*"
 
+[[package]]
+name = "draccus"
+version = "0.9.3"
+description = "A slightly opinionated framework for simple dataclass-based configurations based on Pyrallis."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "draccus-0.9.3-py3-none-any.whl", hash = "sha256:04d3fe14d2b7d19290e6f7c76ff29fbfcc9b56e9e7b76d9439a18a26a1dbfe5e"},
+    {file = "draccus-0.9.3.tar.gz", hash = "sha256:41db52347f5513deadfb8d512fed43bb41499ac5e63559530688c1d95a978043"},
+]
+
+[package.dependencies]
+mergedeep = ">=1.3,<2.0"
+pyyaml = ">=6.0,<7.0"
+pyyaml-include = ">=1.4,<2.0"
+toml = ">=0.10,<1.0"
+typing-inspect = ">=0.9.0,<0.10.0"
+
+[package.extras]
+dev = ["black", "mypy", "pre-commit", "pytest", "ruff"]
+
 [[package]]
 name = "drawnow"
 version = "0.72.5"
@@ -3505,6 +3526,17 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]
 
+[[package]]
+name = "mergedeep"
+version = "1.3.4"
+description = "A deep merge function for 🐍."
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"},
+    {file = "mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8"},
+]
+
 [[package]]
 name = "meshio"
 version = "5.3.5"
@@ -3719,6 +3751,17 @@ files = [
 [package.dependencies]
 dill = ">=0.3.8"
 
+[[package]]
+name = "mypy-extensions"
+version = "1.0.0"
+description = "Type system extensions for programs checked with the mypy type checker."
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
+    {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
+]
+
 [[package]]
 name = "nbclient"
 version = "0.10.0"
@@ -5614,6 +5657,23 @@ files = [
     {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
 ]
 
+[[package]]
+name = "pyyaml-include"
+version = "1.4.1"
+description = "Extending PyYAML with a custom constructor for including YAML files within YAML files"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "pyyaml-include-1.4.1.tar.gz", hash = "sha256:1a96e33a99a3e56235f5221273832464025f02ff3d8539309a3bf00dec624471"},
+    {file = "pyyaml_include-1.4.1-py3-none-any.whl", hash = "sha256:323c7f3a19c82fbc4d73abbaab7ef4f793e146a13383866831631b26ccc7fb00"},
+]
+
+[package.dependencies]
+PyYAML = ">=6.0,<7.0"
+
+[package.extras]
+toml = ["toml"]
+
 [[package]]
 name = "pyzmq"
 version = "26.2.0"
@@ -6798,6 +6858,17 @@ webencodings = ">=0.4"
 doc = ["sphinx", "sphinx_rtd_theme"]
 test = ["pytest", "ruff"]
 
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+    {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+    {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
+
 [[package]]
 name = "tomli"
 version = "2.0.2"
@@ -7035,6 +7106,21 @@ files = [
     {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
 ]
 
+[[package]]
+name = "typing-inspect"
+version = "0.9.0"
+description = "Runtime inspection utilities for typing module."
+optional = false
+python-versions = "*"
+files = [
+    {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
+    {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=0.3.0"
+typing-extensions = ">=3.7.4"
+
 [[package]]
 name = "tzdata"
 version = "2024.2"
@@ -7569,4 +7655,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
+content-hash = "0ef550e648ac7eae32b18584d4facb7a83cf1e0ee9a6705daf89783fb56db8fb"
diff --git a/pyproject.toml b/pyproject.toml
index 59c2de8bc..2e7e1328f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,6 +71,7 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
 hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
 pyserial = {version = ">=3.5", optional = true}
 jsonlines = ">=4.0.0"
+draccus = "^0.9.3"
 
 [tool.poetry.extras]
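
--- Reviewer notes (not part of the diff) ---

A minimal sketch of what the new resolve_delta_timestamps computes: each of the policy's delta indices is divided by the dataset fps to get a timestamp offset in seconds. The numbers below assume DiffusionConfig's upstream defaults (n_obs_steps=2, horizon=16) and an illustrative fps of 10; they are plain arithmetic, not measured output.

# Sketch: delta indices -> delta timestamps, mirroring resolve_delta_timestamps().
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig()
fps = 10  # e.g. PushT

obs_indices = cfg.observation_delta_indices  # [-1, 0] with n_obs_steps=2
act_indices = cfg.action_delta_indices       # [-1, 0, 1, ..., 14] with horizon=16

# resolve_delta_timestamps() divides each index by the dataset fps:
obs_timestamps = [i / fps for i in obs_indices]  # [-0.1, 0.0]
act_timestamps = [i / fps for i in act_indices]  # [-0.1, 0.0, 0.1, ..., 1.4]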
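
The config-driven ImageTransforms replaces the old get_image_transforms keyword list. A usage sketch, assuming the default ImageTransformsConfig from this diff with enable flipped on:

import torch
from lerobot.common.datasets.transforms import ImageTransforms, ImageTransformsConfig

tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=2)
tf = ImageTransforms(tf_cfg)

frame = torch.rand(3, 224, 224)  # dummy (c, h, w) float image in [0, 1]
augmented = tf(frame)            # applies a random 2-subset of the configured jitters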
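
Registering a new policy config follows the same ChoiceRegistry pattern the diff applies to ACT/Diffusion/TDMPC/VQ-BeT. The subclass name and fields below are hypothetical, for illustration only:

from dataclasses import dataclass
from lerobot.configs.policies import PretrainedConfig

@PretrainedConfig.register_subclass("my_policy")  # hypothetical choice name
@dataclass
class MyPolicyConfig(PretrainedConfig):
    horizon: int = 8

    @property
    def observation_delta_indices(self) -> list | None:
        return None  # no temporal stacking of observations

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.horizon))  # current step plus horizon - 1 future actions

    @property
    def reward_delta_indices(self) -> list | None:
        return None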
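
With @draccus.wrap(), train() now parses the command line directly into MainConfig. The flag syntax in the comment below is an assumption based on draccus' dotted-path arguments; this diff does not pin it down.

import draccus
from lerobot.configs.default import MainConfig

@draccus.wrap()
def main(cfg: MainConfig):
    # cfg.policy and cfg.env are resolved through their ChoiceRegistry `type` keys.
    print(cfg.policy.type, cfg.env.type, cfg.dir)

# Assumed invocation (dotted-path flags):
#   python lerobot/scripts/train.py --policy.type=act --dataset.repo_id=lerobot/aloha_sim_insertion_human
if __name__ == "__main__":
    main()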