Commit

make load_state_dict work

Cadene committed Apr 24, 2024
1 parent 0660f71 commit 72751b7
Showing 9 changed files with 376 additions and 87 deletions.
36 changes: 32 additions & 4 deletions lerobot/common/policies/act/configuration_act.py
@@ -21,10 +21,24 @@ class ActionChunkingTransformerConfig:
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the type of normalization to apply.
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
between -1 and 1).
unnormalize_output_modes: A dictionary specifying the method used to unnormalize outputs.
The key is the output data name, and the value is its unnormalization mode, which maps results
back from the normalized space to their original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -51,6 +65,7 @@ class ActionChunkingTransformerConfig:
"""

# Environment.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 14
action_dim: int = 14

@@ -60,6 +75,18 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100

input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)

# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
@@ -72,6 +99,7 @@ class ActionChunkingTransformerConfig:
"action": "mean_std",
}
)

# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
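As a quick illustration of the new configuration surface (a sketch, not part of the diff: it assumes lerobot at this commit is installed, and the mode strings chosen for the inputs are illustrative), the shapes and normalization modes can be overridden when the dataclass is constructed:

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Override the defaults shown above for a hypothetical setup with one "top"
# camera and 14-dimensional states/actions.
cfg = ActionChunkingTransformerConfig(
    input_shapes={
        "observation.images.top": [3, 480, 640],
        "observation.state": [14],
    },
    output_shapes={
        "action": [14],
    },
    normalize_input_modes={
        "observation.images.top": "mean_std",
        "observation.state": "mean_std",
    },
    unnormalize_output_modes={
        "action": "mean_std",
    },
)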
20 changes: 8 additions & 12 deletions lerobot/common/policies/act/modeling_act.py
@@ -20,11 +20,7 @@
from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
to_buffer_dict,
unnormalize_outputs,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize


class ActionChunkingTransformerPolicy(nn.Module):
@@ -76,9 +72,10 @@ def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_s
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -174,17 +171,15 @@ def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""
self.eval()

batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)

if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self._forward(batch)[0][: self.cfg.n_action_steps]

# TODO(rcadene): make _forward return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
actions = self.unnormalize_outputs({"action": actions})["action"]

self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@@ -218,9 +213,10 @@ def update(self, batch, **_) -> dict:
start_time = time.time()
self.train()

batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)

loss_dict = self.forward(batch)
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()

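The Normalize and Unnormalize classes themselves are not shown in this diff. Below is a minimal, self-contained stand-in (an assumption about the idea, not lerobot's implementation) showing why building normalization as an nn.Module with registered buffers is what makes load_state_dict work: the statistics become part of the module's state_dict instead of living in a side dictionary.

import torch
from torch import Tensor, nn


class MeanStdNormalize(nn.Module):
    """Hypothetical stand-in for a per-key "mean_std" normalizer."""

    def __init__(self, key: str, mean: Tensor, std: Tensor):
        super().__init__()
        self.key = key
        # Buffers are saved and restored by state_dict()/load_state_dict(),
        # unlike plain Python attributes.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy; leave the caller's dict untouched
        batch[self.key] = (batch[self.key] - self.mean) / (self.std + 1e-8)
        return batch


norm = MeanStdNormalize("observation.state", torch.zeros(14), torch.ones(14))
out = norm({"observation.state": torch.randn(2, 14)})
print(list(norm.state_dict()))  # ['mean', 'std']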
35 changes: 31 additions & 4 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -19,10 +19,24 @@ class DiffusionConfig:
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the type of normalization to apply.
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
between -1 and 1).
unnormalize_output_modes: A dictionary specifying the method used to unnormalize outputs.
The key is the output data name, and the value is its unnormalization mode, which maps results
back from the normalized space to their original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -60,6 +74,7 @@ class DiffusionConfig:

# Environment.
# Inherit these from the environment config.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
@@ -69,6 +84,18 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8

input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)

# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
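For reference, the two mode strings named in the docstrings are usually interpreted as below; the exact formulas lerobot applies are not visible in this diff, so treat this as the conventional reading rather than the library's code.

import torch

x = torch.tensor([0.0, 5.0, 10.0])

# "mean_std": subtract the mean and divide by the standard deviation.
x_mean_std = (x - x.mean()) / x.std()

# "min_max": rescale linearly to [-1, 1] from the observed min and max.
x_min_max = (x - x.min()) / (x.max() - x.min()) * 2 - 1

print(x_mean_std)  # tensor([-1., 0., 1.])
print(x_min_max)   # tensor([-1., 0., 1.])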
17 changes: 7 additions & 10 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -26,13 +26,11 @@
from torch.nn.modules.batchnorm import _BatchNorm

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
to_buffer_dict,
unnormalize_outputs,
)


@@ -58,9 +56,10 @@ def __init__(
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -133,7 +132,7 @@ def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
assert "observation.state" in batch
assert len(batch) == 2

batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)

self._queues = populate_queues(self._queues, batch)

@@ -146,9 +145,7 @@ def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
actions = self.diffusion.generate_actions(batch)

# TODO(rcadene): make above methods return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
actions = self.unnormalize_outputs({"action": actions})["action"]

self._queues["action"].extend(actions.transpose(0, 1))

@@ -166,12 +163,12 @@ def update(self, batch: dict[str, Tensor], **_) -> dict:

self.diffusion.train()

batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)

loss = self.forward(batch)["loss"]
loss.backward()

# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): self.unnormalize_outputs(out_dict)

grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
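Because the statistics now live in module buffers rather than in a separate dataset_stats attribute, a policy checkpoint round-trips through the standard PyTorch save/load path, which is what the commit title refers to. A self-contained stand-in (not the actual lerobot classes) demonstrating the round trip:

import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Stand-in for a policy that owns its normalization statistics."""

    def __init__(self, state_dim: int = 2):
        super().__init__()
        self.register_buffer("state_mean", torch.zeros(state_dim))
        self.register_buffer("state_std", torch.ones(state_dim))
        self.net = nn.Linear(state_dim, state_dim)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net((state - self.state_mean) / self.state_std)


trained = TinyPolicy()
trained.state_mean.copy_(torch.tensor([1.0, -1.0]))  # pretend these came from dataset statistics

torch.save(trained.state_dict(), "policy.pt")

restored = TinyPolicy()  # fresh module with default statistics
restored.load_state_dict(torch.load("policy.pt"))
assert torch.equal(restored.state_mean, trained.state_mean)  # buffers restored too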