Commit

make unit tests pass
Cadene committed Apr 23, 2024
1 parent 42ed7bb commit 0660f71
Showing 13 changed files with 79 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/1_load_hugging_face_dataset.py
@@ -44,7 +44,7 @@
# TODO(rcadene): list available datasets on lerobot page using `datasets`

# download/load hugging face dataset in pyarrow format
-hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10

# display name of dataset and its features
# TODO(rcadene): update to make the print pretty
2 changes: 1 addition & 1 deletion examples/4_train_policy.py
@@ -34,7 +34,7 @@
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
-policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train()
policy.to(device)

8 changes: 6 additions & 2 deletions lerobot/common/datasets/factory.py
@@ -2,6 +2,7 @@
from pathlib import Path

import torch
+from omegaconf import OmegaConf

DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

@@ -43,7 +44,10 @@ def make_dataset(
)

if cfg.get("override_dataset_stats"):
-    for key, val in cfg.override_dataset_stats.items():
-        dataset.stats[key] = torch.tensor(val)
+    for key, stats_dict in cfg.override_dataset_stats.items():
+        for stats_type, listconfig in stats_dict.items():
+            # example of stats_type: min, max, mean, std
+            stats = OmegaConf.to_container(listconfig, resolve=True)
+            dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

return dataset
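
The new loop converts each per-key, per-stat OmegaConf list (such as the nested (c,1,1) lists in the policy YAML configs below) into a float32 tensor before overwriting the matching entry in dataset.stats. A minimal standalone sketch of that conversion, outside of Hydra and with made-up values:

import torch
from omegaconf import OmegaConf

# stand-in for cfg.override_dataset_stats as it would be loaded from YAML
override = OmegaConf.create(
    {
        "observation.image": {
            "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
            "std": [[[0.229]], [[0.224]], [[0.225]]],  # (c,1,1)
        }
    }
)

stats = {"observation.image": {}}
for key, stats_dict in override.items():
    for stats_type, listconfig in stats_dict.items():
        # ListConfig -> plain nested lists -> float32 tensor
        plain = OmegaConf.to_container(listconfig, resolve=True)
        stats[key][stats_type] = torch.tensor(plain, dtype=torch.float32)

print(stats["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])
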
2 changes: 1 addition & 1 deletion lerobot/common/envs/utils.py
@@ -22,7 +22,7 @@ def preprocess_observation(observation):
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

# convert to channel first of type float32 in range [0,1]
-img = einops.rearrange(img, "b h w c -> b c h w")
+img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
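
For a pure permutation pattern like "b h w c -> b c h w", einops.rearrange returns a transposed view whose memory layout is no longer contiguous; the added .contiguous() copies it into a standard layout, which some downstream ops (for example .view()) require. A small illustration, not taken from the repo:

import einops
import torch

obs = torch.randint(0, 256, (2, 96, 96, 3), dtype=torch.uint8)  # (b, h, w, c)

permuted = einops.rearrange(obs, "b h w c -> b c h w")
print(permuted.is_contiguous())  # False: a strided view of the original buffer

img = permuted.contiguous().type(torch.float32) / 255
print(img.is_contiguous(), img.shape, img.dtype)  # True, (2, 3, 96, 96), torch.float32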

20 changes: 12 additions & 8 deletions lerobot/common/policies/act/configuration_act.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


@dataclass
@@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
n_action_steps: int = 100

# Normalization / Unnormalization
-normalize_input_modes: dict[str, str] = {
-    "observation.image": "mean_std",
-    "observation.state": "mean_std",
-}
-unnormalize_output_modes: dict[str, str] = {
-    "action": "mean_std",
-}
+normalize_input_modes: dict[str, str] = field(
+    default_factory=lambda: {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+)
+unnormalize_output_modes: dict[str, str] = field(
+    default_factory=lambda: {
+        "action": "mean_std",
+    }
+)
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
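
The field(default_factory=...) wrapping fixes a dataclass rule rather than anything ACT-specific: a mutable object such as a dict is not allowed as a plain class-level default, and Python raises as soon as the class is defined. A minimal sketch of the failure and the fix, using a toy config unrelated to the repo:

from dataclasses import dataclass, field

try:
    @dataclass
    class BrokenConfig:
        modes: dict[str, str] = {"action": "mean_std"}  # mutable default
except ValueError as err:
    print(err)  # mutable default <class 'dict'> for field modes is not allowed ...

@dataclass
class FixedConfig:
    # default_factory builds a fresh dict per instance, so instances never share state
    modes: dict[str, str] = field(default_factory=lambda: {"action": "mean_std"})

a, b = FixedConfig(), FixedConfig()
a.modes["action"] = "min_max"
print(b.modes["action"])  # mean_std
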
12 changes: 9 additions & 3 deletions lerobot/common/policies/act/modeling_act.py
@@ -22,6 +22,7 @@
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
+to_buffer_dict,
unnormalize_outputs,
)

@@ -75,7 +76,7 @@ def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_s
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
-self.register_buffer("dataset_stats", dataset_stats)
+self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes

@@ -179,7 +180,12 @@ def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
+# TODO(rcadene): make _forward return output dictionary?
+out_dict = {"action": actions}
+out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+actions = out_dict["action"]
+
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

@@ -214,7 +220,7 @@ def update(self, batch, **_) -> dict:

batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss_dict = self.forward(batch)
-# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
loss = loss_dict["loss"]
loss.backward()
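
Both policies now pass dictionaries keyed like the *_modes configs through normalize_inputs/unnormalize_outputs, which is why select_action wraps the raw action tensor in {"action": actions} before unnormalizing and then unpacks it again. A rough sketch of that dict-in/dict-out shape for the mean_std mode (the real helper lives in lerobot/common/policies/utils.py and also handles min_max):

import torch

def unnormalize_outputs(batch, stats, unnormalize_output_modes):
    # sketch: undo mean/std normalization for every key listed in the modes dict
    for key, mode in unnormalize_output_modes.items():
        if mode == "mean_std":
            batch[key] = batch[key] * stats[key]["std"] + stats[key]["mean"]
        else:
            raise ValueError(mode)
    return batch

stats = {"action": {"mean": torch.zeros(2), "std": torch.full((2,), 2.0)}}
out_dict = {"action": torch.ones(3, 2)}  # (n_action_steps, action_dim), still normalized
out_dict = unnormalize_outputs(out_dict, stats, {"action": "mean_std"})
print(out_dict["action"][0])  # tensor([2., 2.])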

20 changes: 12 additions & 8 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


@dataclass
@@ -70,13 +70,17 @@ class DiffusionConfig:
n_action_steps: int = 8

# Normalization / Unnormalization
-normalize_input_modes: dict[str, str] = {
-    "observation.image": "mean_std",
-    "observation.state": "min_max",
-}
-unnormalize_output_modes: dict[str, str] = {
-    "action": "min_max",
-}
+normalize_input_modes: dict[str, str] = field(
+    default_factory=lambda: {
+        "observation.image": "mean_std",
+        "observation.state": "min_max",
+    }
+)
+unnormalize_output_modes: dict[str, str] = field(
+    default_factory=lambda: {
+        "action": "min_max",
+    }
+)

# Architecture / modeling.
# Vision backbone.
11 changes: 8 additions & 3 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -31,6 +31,7 @@
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
+to_buffer_dict,
unnormalize_outputs,
)

@@ -57,7 +58,7 @@ def __init__(
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
-self.register_buffer("dataset_stats", dataset_stats)
+self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes

@@ -144,7 +145,11 @@ def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
else:
actions = self.diffusion.generate_actions(batch)

-actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+# TODO(rcadene): make above methods return output dictionary?
+out_dict = {"action": actions}
+out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+actions = out_dict["action"]
+
self._queues["action"].extend(actions.transpose(0, 1))

action = self._queues["action"].popleft()
@@ -166,7 +171,7 @@ def update(self, batch: dict[str, Tensor], **_) -> dict:
loss = self.forward(batch)["loss"]
loss.backward()

-# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)

grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
17 changes: 17 additions & 0 deletions lerobot/common/policies/utils.py
@@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
else:
raise ValueError(mode)
return batch


+def to_buffer_dict(dataset_stats):
+    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
+    # see: https://github.com/pytorch/pytorch/issues/37386
+    # TODO(rcadene): make `to_buffer_dict` generic and add docstring
+    if dataset_stats is None:
+        return None
+
+    new_ds_stats = {}
+    for key, stats_dict in dataset_stats.items():
+        new_stats_dict = {}
+        for stats_type, value in stats_dict.items():
+            # set requires_grad=False to have the same behavior as a nn.Buffer
+            new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
+        new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
+    return nn.ParameterDict(new_ds_stats)
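
register_buffer only accepts tensors, not nested dicts of stats, so to_buffer_dict instead stores each stat as a frozen nn.Parameter inside nested nn.ParameterDicts. The effect is that the stats are registered on the module: they follow .to(device) and are saved in state_dict(), while requires_grad=False keeps them fixed during training. A small check of that behaviour with made-up stats and a toy module (not code from the repo, and assuming a recent PyTorch where ParameterDict values may themselves be ParameterDicts):

import torch
from torch import nn

class TinyPolicy(nn.Module):
    def __init__(self, dataset_stats):
        super().__init__()
        # same trick as to_buffer_dict: frozen Parameters inside nested ParameterDicts
        self.dataset_stats = nn.ParameterDict(
            {
                key: nn.ParameterDict(
                    {t: nn.Parameter(v, requires_grad=False) for t, v in stats.items()}
                )
                for key, stats in dataset_stats.items()
            }
        )

stats = {"action": {"mean": torch.zeros(2), "std": torch.ones(2)}}
policy = TinyPolicy(stats)
print(sorted(policy.state_dict()))  # ['dataset_stats.action.mean', 'dataset_stats.action.std']
print(policy.dataset_stats["action"]["mean"].requires_grad)  # False

Unlike true buffers, these entries still show up in policy.parameters(), so optimizer setup may want to filter on requires_grad.
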
5 changes: 3 additions & 2 deletions lerobot/configs/policy/act.yaml
@@ -12,7 +12,8 @@ n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon

override_dataset_stats:
-  observation.image:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

@@ -35,7 +36,7 @@ policy:

# Normalization / Unnormalization
normalize_input_modes:
-  observation.image: mean_std
+  observation.images.top: mean_std
  observation.state: mean_std
unnormalize_output_modes:
  action: mean_std
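
The (c,1,1) shapes noted in the config are what make these ImageNet stats broadcast correctly: a per-channel mean/std stored as a 3x1x1 tensor normalizes a whole (c,h,w) image (or a (b,c,h,w) batch) without reshaping. A quick self-contained check with the values from the config above:

import torch

mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]])  # (3,1,1)
std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]])   # (3,1,1)

img = torch.rand(3, 96, 96)       # (c,h,w) image already scaled to [0,1]
normalized = (img - mean) / std   # broadcasts over h and w
print(normalized.shape)           # torch.Size([3, 96, 96])
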
3 changes: 3 additions & 0 deletions lerobot/configs/policy/diffusion.yaml
@@ -19,9 +19,12 @@ online_steps: 0
offline_prioritized_sampler: true

override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
  observation.image:
    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
  observation.state:
    min: [13.456424, 32.938293]
    max: [496.14618, 510.9579]
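
These min/max values feed the min_max mode configured for observation.state and action in the diffusion policy. A rough sketch of what that normalization and its inverse look like, assuming (as in the original diffusion policy codebase) that min_max rescales values into [-1, 1]:

import torch

stat_min = torch.tensor([13.456424, 32.938293])
stat_max = torch.tensor([496.14618, 510.9579])

state = torch.tensor([254.8, 271.9])  # made-up raw state values

# min_max normalization: map [min, max] to [-1, 1]
normalized = (state - stat_min) / (stat_max - stat_min) * 2 - 1
# inverse used when unnormalizing predicted actions
restored = (normalized + 1) / 2 * (stat_max - stat_min) + stat_min
print(normalized, restored)
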
6 changes: 1 addition & 5 deletions lerobot/scripts/visualize_dataset.py
@@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
log_output_dir(out_dir)

logging.info("make_dataset")
-dataset = make_dataset(
-    cfg,
-    # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
-    normalize=False,
-)
+dataset = make_dataset(cfg)

logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
9 changes: 5 additions & 4 deletions tests/test_envs.py
@@ -6,7 +6,6 @@
from gymnasium.utils.env_checker import check_env

import lerobot
-from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
@@ -38,12 +37,14 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)

-dataset = make_dataset(cfg)
-
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs)
-for key in dataset.image_keys:
+
+# test image keys are float32 in range [0,1]
+for key in obs:
+    if "image" not in key:
+        continue
img = obs[key]
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model
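
With this change the env test no longer needs the dataset at all: it scans the preprocessed observation dict for image keys and checks their dtype. A range check along the lines of the "[0,1]" comment might look like the following (a hypothetical sketch, not necessarily what the test file contains):

import torch

def check_image_obs(obs: dict[str, torch.Tensor]) -> None:
    # every image key should be float32 in [0, 1]
    for key, img in obs.items():
        if "image" not in key:
            continue
        assert img.dtype == torch.float32
        assert img.min() >= 0.0 and img.max() <= 1.0

check_image_obs({"observation.image": torch.rand(1, 3, 96, 96)})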
