Add draccus, create MainConfig
aliberts committed Dec 5, 2024
1 parent 32eb0ce commit 3bb5876
Showing 15 changed files with 715 additions and 148 deletions.
114 changes: 62 additions & 52 deletions lerobot/common/datasets/factory.py
@@ -16,27 +16,46 @@
import logging

import torch
from omegaconf import ListConfig, OmegaConf

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.default import MainConfig
from lerobot.configs.policies import PretrainedConfig

IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}

def resolve_delta_timestamps(cfg):

def resolve_delta_timestamps(
cfg: PretrainedConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
"""Resolves delta_timestamps config key (in-place) by using `eval`.
Doesn't do anything if delta_timestamps is not specified or has already been resolved (as evidenced by
the data type of its values).
"""
delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
delta_timestamps = {}
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

if len(delta_timestamps) == 0:
delta_timestamps = None

return delta_timestamps


def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
def make_dataset(cfg: MainConfig, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
@@ -50,67 +69,58 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)

# A soft check to warn if the environment doesn't match the dataset. Don't check if we are using a real-world env (dora).
if cfg.env.name != "dora":
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
if cfg.env.type != "dora":
if isinstance(cfg.dataset.repo_id, str):
dataset_repo_ids = [cfg.dataset.repo_id] # single dataset
elif isinstance(cfg.dataset.repo_id, list):
dataset_repo_ids = cfg.dataset.repo_id # multiple datasets
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
raise ValueError(
"Expected cfg.dataset.repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)

for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
if cfg.env.type not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
f"environment ({cfg.env.type=})."
)

resolve_delta_timestamps(cfg)

image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)

if isinstance(cfg.dataset_repo_id, str):
if isinstance(cfg.dataset.repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset_repo_id,
delta_timestamps=cfg.training.get("delta_timestamps"),
cfg.dataset.repo_id,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.video_backend,
local_files_only=cfg.dataset.local_files_only,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
delta_timestamps=cfg.training.get("delta_timestamps"),
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.video_backend,
)

if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
# for key, stats_dict in cfg.override_dataset_stats.items():
# for stats_type, listconfig in stats_dict.items():
# # example of stats_type: min, max, mean, std
# stats = OmegaConf.to_container(listconfig, resolve=True)
# dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

return dataset
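
As a quick illustration of the index-to-timestamp conversion that the new `resolve_delta_timestamps` performs, here is a minimal self-contained sketch. The index values below are made up for illustration; the real function reads them from a `PretrainedConfig` and takes the fps from `LeRobotDatasetMetadata`.

```python
# Minimal sketch of the index -> timestamp conversion done by resolve_delta_timestamps.
fps = 50                                  # frames per second of the dataset
action_delta_indices = [0, 1, 2]          # current frame and the next two
observation_delta_indices = [-1, 0]       # previous frame and the current one

delta_timestamps = {
    "action": [i / fps for i in action_delta_indices],
    "observation.state": [i / fps for i in observation_delta_indices],
}
print(delta_timestamps)
# {'action': [0.0, 0.02, 0.04], 'observation.state': [-0.02, 0.0]}
```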
115 changes: 115 additions & 0 deletions lerobot/common/datasets/transforms.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Sequence

import torch
@@ -137,6 +138,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


# TODO(aliberts): Remove
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
@@ -195,3 +197,116 @@ def check_value(name, weight, min_max):
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)


@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""

weight: float = 1.0
type: str = "Identity"
kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ImageTransformsConfig:
"""
These transforms all use standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""

# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: list[ImageTransformConfig] = field(
default_factory=lambda: [
ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
]
)


class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""

_registry = {
"Identity": v2.Identity,
"ColorJitter": v2.ColorJitter,
"SharpnessJitter": SharpnessJitter,
}

def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg

weights = []
transforms = []
for tf_cfg in cfg.tfs:
if tf_cfg.weight <= 0.0:
continue

transform_cls = self._registry.get(tf_cfg.type)
if transform_cls is None:
available_transforms = ", ".join(self._registry.keys())
raise ValueError(
f"Transform '{tf_cfg.type}' not found in the registry. "
f"Available transforms are: {available_transforms}"
)

# Instantiate the transform
transform_instance = transform_cls(**tf_cfg.kwargs)
transforms.append(transform_instance)
weights.append(tf_cfg.weight)

n_subset = min(len(transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.transform = v2.Identity()
else:
self.transform = RandomSubsetApply(
transforms=transforms,
p=weights,
n_subset=n_subset,
random_order=cfg.random_order,
)

def forward(self, *inputs: Any) -> Any:
return self.transform(*inputs)
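
For context, a minimal usage sketch of the new config-driven transforms, assuming the imports below match this commit's module layout. The random tensor stands in for a decoded camera frame in (C, H, W) layout.

```python
import torch

from lerobot.common.datasets.transforms import (
    ImageTransformConfig,
    ImageTransforms,
    ImageTransformsConfig,
)

cfg = ImageTransformsConfig(
    enable=True,
    max_num_transforms=2,
    random_order=True,
    tfs=[
        ImageTransformConfig(weight=2.0, type="ColorJitter", kwargs={"brightness": (0.8, 1.2)}),
        ImageTransformConfig(weight=1.0, type="SharpnessJitter", kwargs={"sharpness": (0.5, 1.5)}),
    ],
)

transform = ImageTransforms(cfg)   # builds a RandomSubsetApply under the hood
frame = torch.rand(3, 224, 224)    # dummy CHW float image in [0, 1]
augmented = transform(frame)       # applies at most 2 of the configured transforms
```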
1 change: 1 addition & 0 deletions lerobot/common/envs/__init__.py
@@ -0,0 +1 @@
from .configs import AlohaEnv, EnvConfig, RealEnv # noqa: F401
30 changes: 30 additions & 0 deletions lerobot/common/envs/configs.py
@@ -0,0 +1,30 @@
from dataclasses import dataclass

import draccus


@dataclass
class EnvConfig(draccus.ChoiceRegistry):
task: str | None = None
state_dim: int = 18
action_dim: int = 18
fps: int = 30

@property
def type(self) -> str:
return self.get_choice_name(self.__class__)


@EnvConfig.register_subclass("real_world")
@dataclass
class RealEnv(EnvConfig):
pass


@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
state_dim: int = 14
action_dim: int = 14
fps: int = 50
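
A short sketch of how the ChoiceRegistry pattern above plays out, using only the API visible in this diff: subclasses are registered under a name, and the `type` property recovers that name at runtime. With draccus, the registered name is what a config file or CLI flag can use to select the concrete `EnvConfig` subclass.

```python
from lerobot.common.envs import AlohaEnv, EnvConfig

env_cfg: EnvConfig = AlohaEnv()    # would typically be selected by its registered name, "aloha"
print(env_cfg.type)                # -> "aloha"
print(env_cfg.task, env_cfg.fps)   # -> AlohaInsertion-v0 50
```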
38 changes: 16 additions & 22 deletions lerobot/common/logger.py
@@ -33,6 +33,7 @@

from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.configs.default import MainConfig


def log_output_dir(out_dir):
@@ -42,9 +43,9 @@ def log_output_dir(out_dir):
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"env:{cfg.env.type}",
f"seed:{cfg.seed}",
]
return lst if return_list else "-".join(lst)
@@ -83,25 +84,18 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"

def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
def __init__(self, cfg: MainConfig):
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir = cfg.dir
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
self.job_name = cfg.job_name
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)

# Set up WandB.
self._group = cfg_to_group(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
run_offline = not cfg.wandb.enable or not cfg.wandb.project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
@@ -115,12 +109,12 @@ def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = N

wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=wandb_job_name,
notes=cfg.get("wandb", {}).get("notes"),
project=cfg.wandb.project,
entity=cfg.wandb.entity,
name=self.job_name,
notes=cfg.wandb.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=log_dir,
dir=self.log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True
save_code=False,
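
The new `lerobot/configs/default.py` is not included in the rendered portion of this diff, so the following is only a hypothetical sketch of what `MainConfig` might look like, shaped after the attributes the changed files actually access (`cfg.dir`, `cfg.job_name`, `cfg.seed`, `cfg.dataset.*`, `cfg.env.type`, `cfg.policy`, `cfg.wandb.*`, `cfg.video_backend`). Field names, defaults, and the nested `DatasetConfig`/`WandBConfig` classes are assumptions for illustration, not the commit's actual definition.

```python
# Hypothetical sketch, NOT the actual lerobot/configs/default.py from this commit.
from dataclasses import dataclass, field
from pathlib import Path

from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.envs import AlohaEnv, EnvConfig
from lerobot.configs.policies import PretrainedConfig


@dataclass
class DatasetConfig:
    repo_id: str | list[str] = "lerobot/aloha_sim_insertion_human"
    episodes: list[int] | None = None
    local_files_only: bool = False
    use_imagenet_stats: bool = True
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)


@dataclass
class WandBConfig:
    enable: bool = False
    project: str | None = None
    entity: str | None = None
    notes: str | None = None


@dataclass
class MainConfig:
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    env: EnvConfig = field(default_factory=AlohaEnv)
    policy: PretrainedConfig | None = None
    wandb: WandBConfig = field(default_factory=WandBConfig)
    dir: Path = Path("outputs/train")
    job_name: str = "default"
    seed: int = 1000
    video_backend: str = "pyav"
```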