#!/usr/bin/env python

# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("actlanguage")
@dataclass
class ACTLanguageConfig(PreTrainedConfig):
    """Configuration class for the language-conditioned Action Chunking Transformer policy.

    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

    Notes on the inputs and outputs:
        - Either:
            - At least one image feature is required as an input.
              AND/OR
            - The "observation.environment_state" feature is required as an input.
          (This is what `validate_features` enforces.)
        - If there are multiple keys beginning with "observation.images." they are treated as multiple
          camera views. Right now we only support all images having the same shape.
        - May optionally work without an "observation.state" key for the proprioceptive robot state.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back). Only 1 is accepted for now.
        chunk_size: The size of the action prediction "chunks" in units of environment steps.
        n_action_steps: The number of action steps to run in the environment for one invocation of the
            policy. This should be no greater than the chunk size. For example, if the chunk size is 100,
            you may set this to 50. This would mean that the model predicts 100 steps worth of actions,
            runs 50 in the environment, and throws the other 50 out.
        normalization_mapping: Maps a feature type name ("VISUAL", "STATE", "ACTION") to the
            `NormalizationMode` applied to features of that type. The same mapping is used to normalize
            inputs and training targets and to unnormalize the predicted actions.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
            convolution.
        pre_norm: Whether to use "pre-norm" in the transformer blocks.
        dim_model: The transformer blocks' main hidden dimension.
        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
            layers.
        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
        use_vae: Whether to use a variational objective during training. This introduces another transformer
            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
            ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must
            be 1 when using this feature, as inference needs to happen at every step to form an ensemble.
            For more information on how ensembling works, please see `ACTTemporalEnsembler`.
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational
            objective is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
        optimizer_lr: Learning rate of the AdamW preset returned by `get_optimizer_preset`.
        optimizer_weight_decay: Weight decay of the AdamW preset returned by `get_optimizer_preset`.
        optimizer_lr_backbone: Learning rate intended for the vision backbone parameter group (read by
            the policy when it builds its optimizer parameter groups).
        task_vocab_size: Number of rows of the task-token embedding table. It must be strictly greater
            than every token id the task tokenizer can emit.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 100
    n_action_steps: int = 100

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: int = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
    n_decoder_layers: int = 1
    # VAE.
    use_vae: bool = True
    latent_dim: int = 32
    n_vae_encoder_layers: int = 4

    # Inference.
    # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
    temporal_ensemble_coeff: float | None = None

    # Training and loss computation.
    dropout: float = 0.1
    kl_weight: float = 10.0

    # Training preset
    optimizer_lr: float = 1e-5
    optimizer_weight_decay: float = 1e-4
    optimizer_lr_backbone: float = 1e-5

    # Language conditioning: size of the task-token embedding table.
    task_vocab_size: int = 320000

    def __post_init__(self):
        """Run the parent initialization, then validate the fields (not exhaustive).

        Raises:
            ValueError: If the backbone is not a ResNet variant, `n_action_steps` exceeds `chunk_size`,
                `n_obs_steps` is not 1, or `task_vocab_size` is not positive.
            NotImplementedError: If temporal ensembling is requested with `n_action_steps > 1`.
        """
        super().__post_init__()

        # Input validation (not exhaustive).
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
            raise NotImplementedError(
                "`n_action_steps` must be 1 when using temporal ensembling. This is "
                "because the policy needs to be queried every step to compute the ensembled action."
            )
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
            )
        # An empty or negative embedding table can never hold a task token.
        if self.task_vocab_size <= 0:
            raise ValueError(f"`task_vocab_size` must be a positive integer. Got {self.task_vocab_size}.")

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW preset built from `optimizer_lr` and `optimizer_weight_decay`."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> None:
        """No learning-rate scheduler is used by default."""
        return None

    def validate_features(self) -> None:
        """Check that at least one image feature or the environment state feature is configured.

        Raises:
            ValueError: If neither kind of input is present.
        """
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> None:
        """No observation offsets are requested."""
        return None

    @property
    def action_delta_indices(self) -> list:
        """Offsets `0 .. chunk_size - 1`, i.e. one entry per action of a chunk."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """No reward offsets are requested."""
        return None


# ---------------------------------------------------------------------------
# lerobot/common/policies/actlanguage/modeling_act.py
# ---------------------------------------------------------------------------
#!/usr/bin/env python

# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
# OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""Language-conditioned Action Chunking Transformer Policy

As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705),
extended with task-text tokens that are fed to the transformer encoder.
"""

import math
from collections import deque
from itertools import chain
from typing import Callable, Optional

import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from transformers import AutoTokenizer

from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.actlanguage.configuration_act import ACTLanguageConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy


class ACTLanguagePolicy(PreTrainedPolicy):
    """
    Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
    Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act),
    conditioned on a tokenized task description.
    """

    # Must be the language config: the network reads `task_vocab_size`, which the plain ACT config lacks.
    config_class = ACTLanguageConfig
    # Matches the name the config class is registered under.
    name = "actlanguage"

    def __init__(
        self,
        config: ACTLanguageConfig,
        dataset_stats: Optional[dict[str, dict[str, Tensor]]] = None,
    ):
        """
        Args:
            config: Policy configuration class instance.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.tokenizer = AutoTokenizer.from_pretrained("lerobot/pi0")
        # Batched tokenization pads to the longest text, which requires a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = ACT(config)

        if config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

        self.reset()

    def get_optim_params(self) -> list[dict]:
        """Return the optimizer parameter groups for the policy.

        Trainable parameters are split in two groups: everything outside `model.backbone`, and the
        backbone parameters, which carry their own learning rate (`config.optimizer_lr_backbone`).

        Returns:
            A list of dictionaries specifying parameter groups and their optimizer settings.
        """
        return [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if not n.startswith("model.backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if n.startswith("model.backbone") and p.requires_grad
                ],
                "lr": self.config.optimizer_lr_backbone,
            },
        ]

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler.reset()
        else:
            self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        The batch is normalized, its task descriptions are tokenized and its image features gathered.
        Then either the temporal ensembler is updated with a fresh prediction (when
        `temporal_ensemble_coeff` is set) or the next action is popped from the action queue, which is
        refilled from the model whenever it is empty.

        Args:
            batch: Input observations. Must contain a "task" entry with the task description(s).

        Returns:
            Selected action tensor
        """
        self.eval()

        batch = self._normalize_inputs(batch)
        self._tokenize_tasks(batch)
        self._prepare_image_features(batch)

        if self.config.temporal_ensemble_coeff is not None:
            return self._apply_temporal_ensembling(batch)

        return self._get_queued_action(batch)

    def _normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Normalize input observations using configured normalization parameters.

        Args:
            batch: Raw input observations

        Returns:
            Normalized input batch, as a fresh dict so that keys added later never reach the caller's dict.
        """
        return dict(self.normalize_inputs(batch))

    def _tokenize_tasks(self, batch: dict[str, Tensor]) -> None:
        """Tokenize task descriptions in the input batch.

        Adds 'task.input_ids' and 'task.attention_mask' to the batch, on the same device as the policy
        parameters.

        Args:
            batch: Input batch containing 'task' key with text descriptions (one string or a sequence of
                strings).

        Raises:
            KeyError: If the batch has no 'task' entry.
        """
        task_texts = batch["task"]
        if isinstance(task_texts, str):
            task_texts = [task_texts]
        tokenized = self.tokenizer(
            list(task_texts),
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # The tokenizer returns CPU tensors; the embedding lookup needs them next to the weights.
        device = next(self.parameters()).device
        batch["task.input_ids"] = tokenized.input_ids.to(device)
        batch["task.attention_mask"] = tokenized.attention_mask.to(device)

    def _prepare_image_features(self, batch: dict[str, Tensor]) -> None:
        """Prepare image features from configured input keys.

        Creates 'observation.images' list from specified image feature keys.

        Args:
            batch: Input batch containing image feature tensors
        """
        if self.config.image_features:
            batch["observation.images"] = [batch[key] for key in self.config.image_features]

    def _apply_temporal_ensembling(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk, unnormalize it and feed it to the temporal ensembler.

        Args:
            batch: Processed input batch

        Returns:
            Temporally ensembled action tensor
        """
        actions = self.model(batch)[0]
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return self.temporal_ensembler.update(actions)

    def _get_queued_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Retrieve action from queue or generate new actions if queue is empty.

        Args:
            batch: Processed input batch

        Returns:
            Next action from the queue
        """
        if not self._action_queue:
            self._refill_action_queue(batch)
        return self._action_queue.popleft()

    def _refill_action_queue(self, batch: dict[str, Tensor]) -> None:
        """Generate and queue multiple actions when queue is empty.

        Args:
            batch: Processed input batch
        """
        actions = self.model(batch)[0][:, : self.config.n_action_steps]
        actions = self.unnormalize_outputs({"action": actions})["action"]
        # The model emits (batch, n_action_steps, action_dim); the queue holds one (batch, action_dim)
        # entry per step, hence the transpose.
        self._action_queue.extend(actions.transpose(0, 1))

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training/validation.

        Args:
            batch: Input batch containing observations, actions, and padding masks

        Returns:
            Tuple containing total loss and a dictionary of loss components (plain floats)
        """
        batch = self._prepare_input_batch(batch)
        batch = self.normalize_targets(batch)

        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
        l1_loss = self._compute_l1_loss(batch, actions_hat)
        loss_dict = {"l1_loss": l1_loss.item()}

        if not self.config.use_vae:
            return l1_loss, loss_dict

        kld_loss = self._compute_kl_divergence(mu_hat, log_sigma_x2_hat)
        loss_dict["kld_loss"] = kld_loss.item()
        # Out-of-place sum: `l1_loss` itself is left untouched.
        total_loss = l1_loss + kld_loss * self.config.kl_weight
        return total_loss, loss_dict

    def _prepare_input_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Prepare a training batch: normalize inputs, tokenize tasks and gather image features.

        Task descriptions are tokenized here as well as in `select_action`; otherwise the network would
        only ever receive task tokens at inference time and the task embedding would never be trained.
        Batches that already carry 'task.input_ids', or that have no 'task' entry, are left as they are.

        Args:
            batch: Raw input batch

        Returns:
            Processed batch (a shallow copy) with normalized inputs, task tokens and image features
        """
        batch = dict(self.normalize_inputs(batch))  # Shallow copy to avoid modifying the original
        if "task" in batch and "task.input_ids" not in batch:
            self._tokenize_tasks(batch)
        if self.config.image_features:
            batch["observation.images"] = [batch[key] for key in self.config.image_features]
        return batch

    def _compute_l1_loss(self, batch: dict[str, Tensor], actions_hat: Tensor) -> Tensor:
        """Calculate L1 loss between predicted and ground truth actions.

        Args:
            batch: Processed input batch containing 'action' and 'action_is_pad'
            actions_hat: Predicted actions from the model

        Returns:
            Mean of the element-wise L1 loss after zeroing the padded steps
        """
        return (
            F.l1_loss(batch["action"], actions_hat, reduction="none")
            * ~batch["action_is_pad"].unsqueeze(-1)
        ).mean()

    def _compute_kl_divergence(self, mu_hat: Tensor, log_sigma_x2_hat: Tensor) -> Tensor:
        """Calculate KL divergence loss for VAE regularization.

        The divergence is summed over the last (latent) dimension and averaged over the rest.

        Args:
            mu_hat: Predicted mean values
            log_sigma_x2_hat: Predicted log variance values

        Returns:
            KL divergence loss tensor
        """
        return (
            -0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())
        ).sum(-1).mean()
class ACTTemporalEnsembler:
    """Online weighted average of overlapping action chunks (temporal ensembling)."""

    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
        The running average divides by the cumulative sum of the weights seen so far, which keeps it
        normalized.

        Args:
            temporal_ensemble_coeff: Coefficient controlling the weight decay.
            chunk_size: The number of actions in a chunk.
        """
        self.chunk_size = chunk_size
        steps = torch.arange(chunk_size)
        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * steps)
        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
        self.reset()

    def reset(self):
        """Forget every action accumulated so far."""
        self.ensembled_actions = None
        self.ensembled_actions_count = None

    def update(self, actions: Tensor) -> Tensor:
        """
        Fold a freshly predicted chunk into the ensemble and pop the action to execute now.

        Args:
            actions: Tensor of shape (batch, chunk_size, action_dim) with new actions.

        Returns:
            Tensor with the ensembled action of shape (batch, action_dim)
        """
        self._move_weights_to_device(actions.device)

        first_call = self.ensembled_actions is None
        if first_call:
            self._initialize_ensemble(actions)
        else:
            self._update_existing_ensemble(actions)

        return self._prepare_next_action()

    def _move_weights_to_device(self, device: torch.device) -> None:
        """Keep the weight tables on the same device as the incoming actions."""
        self.ensemble_weights = self.ensemble_weights.to(device)
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device)

    def _initialize_ensemble(self, actions: Tensor) -> None:
        """Seed the ensemble with the first chunk; every step has been seen exactly once."""
        self.ensembled_actions = actions.clone()
        self.ensembled_actions_count = actions.new_ones((self.chunk_size, 1), dtype=torch.long)

    def _update_existing_ensemble(self, actions: Tensor) -> None:
        """Merge a new chunk into the running weighted average, then append its last step."""
        running = self.ensembled_actions
        seen = self.ensembled_actions_count

        # Undo the previous normalization, add the weighted new prediction, renormalize (all in place).
        running.mul_(self.ensemble_weights_cumsum[seen - 1])
        running.add_(actions[:, :-1] * self.ensemble_weights[seen])
        running.div_(self.ensemble_weights_cumsum[seen])

        seen = (seen + 1).clamp(max=self.chunk_size)

        # The final step of the new chunk has no earlier estimate: it joins with a count of one.
        self.ensembled_actions = torch.cat((running, actions[:, -1:]), dim=1)
        self.ensembled_actions_count = torch.cat((seen, torch.ones_like(seen[-1:])))

    def _prepare_next_action(self) -> Tensor:
        """Pop the oldest ensembled action and drop it (and its counter) from the storage."""
        head = self.ensembled_actions[:, 0]
        self.ensembled_actions = self.ensembled_actions[:, 1:]
        self.ensembled_actions_count = self.ensembled_actions_count[1:]
        return head
self.ensembled_actions_count + 1, + max=self.chunk_size + ) + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat([ + self.ensembled_actions_count, + torch.ones_like(self.ensembled_actions_count[-1:]) + ]) + + def _prepare_next_action(self) -> Tensor: + """Extract next action and update ensemble storage.""" + action = self.ensembled_actions[:, 0] + self.ensembled_actions = self.ensembled_actions[:, 1:] + self.ensembled_actions_count = self.ensembled_actions_count[1:] + return action + + +class ACT(nn.Module): + """Action Chunking Transformer: The underlying neural network for ACTPolicy. + + This model includes an optional VAE encoder, task embeddings, image feature backbone, transformer encoder/decoder, + and a final regression head for action prediction. + """ + + def __init__(self, config: ACTConfig): + """ + Initializes the Action Chunking Transformer model. + + Args: + config: Configuration object containing model parameters + """ + super().__init__() + self.config = config + + self._init_vae_components() + self._init_task_embedding() + self._init_image_backbone() + self._init_encoder_decoder() + self._init_input_projections() + self._init_position_embeddings() + self._init_action_head() + self._reset_parameters() + + def _init_vae_components(self) -> None: + """Initialize VAE-specific components if enabled in configuration.""" + if not self.config.use_vae: + return + + self.vae_encoder = ACTEncoder(self.config, is_vae_encoder=True) + self.vae_encoder_cls_embed = nn.Embedding(1, self.config.dim_model) + + if self.config.robot_state_feature: + self.vae_encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], self.config.dim_model + ) + + self.vae_encoder_action_input_proj = nn.Linear( + self.config.action_feature.shape[0], self.config.dim_model + ) + self.vae_encoder_latent_output_proj = nn.Linear( + self.config.dim_model, self.config.latent_dim * 2 + ) + + 
num_tokens = 1 + self.config.chunk_size # CLS token + action tokens + if self.config.robot_state_feature: + num_tokens += 1 + self.register_buffer( + "vae_encoder_pos_enc", + create_sinusoidal_pos_embedding(num_tokens, self.config.dim_model).unsqueeze(0), + ) + + def _init_task_embedding(self) -> None: + """Initialize task embedding layer.""" + self.task_embedding = nn.Embedding( + self.config.task_vocab_size, + self.config.dim_model + ) + + def _init_image_backbone(self) -> None: + """Initialize vision backbone if image features are enabled.""" + if not self.config.image_features: + return + + backbone_model = getattr(torchvision.models, self.config.vision_backbone)( + replace_stride_with_dilation=[False, False, self.config.replace_final_stride_with_dilation], + weights=self.config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + def _init_encoder_decoder(self) -> None: + """Initialize main encoder and decoder components.""" + self.encoder = ACTEncoder(self.config) + self.decoder = ACTDecoder(self.config) + + def _init_input_projections(self) -> None: + """Initialize input projection layers for different modalities.""" + if self.config.robot_state_feature: + self.encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], self.config.dim_model + ) + + if self.config.env_state_feature: + self.encoder_env_state_input_proj = nn.Linear( + self.config.env_state_feature.shape[0], self.config.dim_model + ) + + self.encoder_latent_input_proj = nn.Linear( + self.config.latent_dim, self.config.dim_model + ) + + if self.config.image_features: + backbone_out_features = getattr(torchvision.models, self.config.vision_backbone)().fc.in_features + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_out_features, self.config.dim_model, kernel_size=1 + ) + + def _init_position_embeddings(self) -> None: + """Initialize positional 
embeddings for different components.""" + # Encoder 1D feature positional embedding + n_1d_tokens = 1 # Task embedding + if self.config.robot_state_feature: + n_1d_tokens += 1 + if self.config.env_state_feature: + n_1d_tokens += 1 + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, self.config.dim_model) + + # Image feature positional embedding + if self.config.image_features: + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(self.config.dim_model // 2) + + # Decoder positional embedding + self.decoder_pos_embed = nn.Embedding(self.config.chunk_size, self.config.dim_model) + + def _init_action_head(self) -> None: + """Initialize final action prediction layer.""" + self.action_head = nn.Linear( + self.config.dim_model, self.config.action_feature.shape[0] + ) + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Optional[tuple[Tensor, Tensor]]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder).""" + batch = self._prepare_batch(batch) + batch_size = self._get_batch_size(batch) + latent_sample, mu, log_sigma_x2 = self._prepare_latent(batch, batch_size) + encoder_in_tokens, encoder_in_pos_embed = self._prepare_encoder_inputs(batch, latent_sample, batch_size) + encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) + decoder_in = self._prepare_decoder_input(batch_size, encoder_in_pos_embed.dtype, encoder_in_pos_embed.device) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=encoder_in_pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + decoder_out = decoder_out.transpose(0, 1) + actions = self.action_head(decoder_out) + return actions, (mu, log_sigma_x2) + + def 
_prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare and normalize the batch input and add image features if available.""" + if self.config.image_features: + batch = dict(batch) + batch["observation.images"] = [batch[key] for key in self.config.image_features] + return batch + + def _get_batch_size(self, batch: dict[str, Tensor]) -> int: + """Extract the batch size from the batch dictionary.""" + if "observation.images" in batch: + return batch["observation.images"][0].shape[0] + return batch["observation.environment_state"].shape[0] + + def _prepare_latent(self, batch: dict[str, Tensor], batch_size: int) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Prepare latent token for transformer encoder with VAE handling. + + Args: + batch: Input batch containing observations and actions + batch_size: Batch size of the input data + + Returns: + latent_sample: Sampled latent representation (batch_size, latent_dim) + mu: Latent distribution mean (or None if not using VAE) + log_sigma_x2: Latent log variance (or None if not using VAE) + """ + if not self._should_use_vae(batch): + return self._create_dummy_latent(batch_size), None, None + + vae_input = self._create_vae_input_embeddings(batch, batch_size) + pos_embed, key_padding_mask = self._prepare_vae_positioning(batch, batch_size) + + cls_token_out = self._run_vae_encoder(vae_input, pos_embed, key_padding_mask) + mu, log_sigma_x2 = self._extract_latent_parameters(cls_token_out) + latent_sample = self._sample_latent_distribution(mu, log_sigma_x2) + + return latent_sample, mu, log_sigma_x2 + + def _should_use_vae(self, batch: dict[str, Tensor]) -> bool: + """Check if VAE should be used for current batch.""" + return self.config.use_vae and "action" in batch + + def _create_dummy_latent(self, batch_size: int) -> Tensor: + """Create zero-initialized latent tensor for non-VAE cases.""" + return torch.zeros( + (batch_size, self.config.latent_dim), + dtype=torch.float32, + 
device=self._get_device() + ) + + def _create_vae_input_embeddings(self, batch: dict[str, Tensor], batch_size: int) -> Tensor: + """Construct VAE encoder input embeddings.""" + components = [self._get_cls_embedding(batch_size)] + + if self.config.robot_state_feature: + components.append(self._get_robot_state_embedding(batch)) + + components.append(self._get_action_embedding(batch)) + return torch.cat(components, dim=1) + + def _get_cls_embedding(self, batch_size: int) -> Tensor: + """Generate CLS token embedding.""" + return einops.repeat( + self.vae_encoder_cls_embed.weight, + "1 d -> b 1 d", + b=batch_size + ) + + def _get_robot_state_embedding(self, batch: dict[str, Tensor]) -> Tensor: + """Generate robot state embedding if configured.""" + return self.vae_encoder_robot_state_input_proj( + batch["observation.state"] + ).unsqueeze(1) + + def _get_action_embedding(self, batch: dict[str, Tensor]) -> Tensor: + """Generate action sequence embeddings.""" + return self.vae_encoder_action_input_proj(batch["action"]) + + def _prepare_vae_positioning(self, batch: dict[str, Tensor], batch_size: int) -> tuple[Tensor, Tensor]: + """Prepare positional embeddings and attention mask for VAE encoder.""" + pos_embed = self.vae_encoder_pos_enc.expand(batch_size, -1, -1) + return pos_embed.permute(1, 0, 2), self._create_key_padding_mask(batch, batch_size) + + def _create_key_padding_mask(self, batch: dict[str, Tensor], batch_size: int) -> Tensor: + """Create attention mask for padded action sequences.""" + cls_joint_mask = torch.full( + (batch_size, self._get_cls_joint_size()), + False, + device=self._get_device() + ) + return torch.cat([cls_joint_mask, batch["action_is_pad"]], dim=1) + + def _get_cls_joint_size(self) -> int: + """Calculate CLS joint size based on configuration.""" + return 2 if self.config.robot_state_feature else 1 + + def _run_vae_encoder(self, vae_input: Tensor, pos_embed: Tensor, mask: Tensor) -> Tensor: + """Execute VAE encoder forward pass.""" + return 
self.vae_encoder( + vae_input.permute(1, 0, 2), + pos_embed=pos_embed, + key_padding_mask=mask + )[0] + + def _extract_latent_parameters(self, cls_token_out: Tensor) -> tuple[Tensor, Tensor]: + """Extract mean and log variance from encoder output.""" + params = self.vae_encoder_latent_output_proj(cls_token_out) + return params.chunk(2, dim=-1) + + def _sample_latent_distribution(self, mu: Tensor, log_sigma_x2: Tensor) -> Tensor: + """Sample latent variable using reparameterization trick.""" + return mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu) + + def _get_device(self) -> torch.device: + """Get current device from model parameters.""" + return next(self.parameters()).device + + def _prepare_encoder_inputs(self, batch: dict[str, Tensor], latent_sample: Tensor, batch_size: int) -> tuple[Tensor, Tensor]: + """Prepare input tokens and positional embeddings for transformer encoder. + + Args: + batch: Input batch containing observations and task information + latent_sample: Latent representation from VAE + batch_size: Batch size of the input data + + Returns: + encoder_in_tokens: Combined input tokens (total_tokens, batch_size, dim_model) + encoder_in_pos_embed: Corresponding positional embeddings (total_tokens, 1, dim_model) + """ + tokens, pos_embeds = [], [] + + self._add_latent_token(tokens, pos_embeds, latent_sample) + self._add_robot_state_token(tokens, pos_embeds, batch) + self._add_task_tokens(tokens, pos_embeds, batch) + self._add_env_state_token(tokens, batch) + self._add_image_tokens(tokens, pos_embeds, batch) + + tokens = self._adjust_batch_dimensions(tokens, batch_size) + pos_embeds = self._adjust_batch_dimensions(pos_embeds, batch_size) + + return torch.cat(tokens, dim=0), torch.cat(pos_embeds, dim=0) + + def _add_latent_token(self, tokens: list[Tensor], pos_embeds: list[Tensor], latent_sample: Tensor) -> None: + """Add latent token and its positional embedding.""" + latent_token = self.encoder_latent_input_proj(latent_sample).unsqueeze(0) + 
# Methods of the model class whose `class` header precedes this chunk; the
# collapsed diff carries no indentation, so they are written at column 0 here.
def _add_robot_state_token(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None:
    """Append the projected robot state and 1D positional embedding 1.

    Does nothing when `config.robot_state_feature` is not set.
    """
    if self.config.robot_state_feature:
        state = batch["observation.state"]
        tokens.append(self.encoder_robot_state_input_proj(state).unsqueeze(0))
        pos_embeds.append(self._get_1d_pos_embedding(1))


def _add_task_tokens(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None:
    """Append embedded task tokens and their positional embeddings.

    Does nothing when the batch has no "task.input_ids" entry. The embeddings
    are rearranged from "b s d" to "s b d" before being appended.
    """
    if "task.input_ids" in batch:
        embedded = self.task_embedding(batch["task.input_ids"])
        seq_first = einops.rearrange(embedded, "b s d -> s b d")
        tokens.append(seq_first)
        pos_embeds.append(self._create_task_pos_embedding(seq_first))


def _add_env_state_token(self, tokens: list[Tensor], batch: dict[str, Tensor]) -> None:
    """Append the projected environment state token.

    Does nothing when `config.env_state_feature` is not set. Only `tokens` is
    extended; this helper does not add a positional embedding.
    """
    if self.config.env_state_feature:
        env_state = batch["observation.environment_state"]
        tokens.append(self.encoder_env_state_input_proj(env_state).unsqueeze(0))


def _add_image_tokens(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None:
    """Append per-camera feature tokens and their positional embeddings.

    Does nothing when `config.image_features` is not set. Each entry of
    `batch["observation.images"]` is passed through the backbone; its
    "feature_map" output is projected and flattened, and the positional
    embedding computed from the same feature map is flattened the same way.
    """
    if self.config.image_features:
        for image in batch["observation.images"]:
            feature_map = self.backbone(image)["feature_map"]
            # Positional embedding is cast to the feature map's dtype.
            cam_pos = self.encoder_cam_feat_pos_embed(feature_map).to(dtype=feature_map.dtype)
            tokens.append(self._process_image_features(feature_map))
            pos_embeds.append(self._process_image_pos_embed(cam_pos))


def _process_image_features(self, features: Tensor) -> Tensor:
    """Project a feature map and flatten it from "b c h w" to "(h w) b c"."""
    projected = self.encoder_img_feat_input_proj(features)
    return einops.rearrange(projected, "b c h w -> (h w) b c")
# Methods of the model class whose `class` header precedes this chunk; the
# collapsed diff carries no indentation, so they are written at column 0 here.
def _process_image_pos_embed(self, pos_embed: Tensor) -> Tensor:
    """Flatten an image positional embedding from "b c h w" to "(h w) b c"."""
    flattened = einops.rearrange(pos_embed, "b c h w -> (h w) b c")
    return flattened


def _adjust_batch_dimensions(self, tensors: list[Tensor], batch_size: int) -> list[Tensor]:
    """Apply `_adjust_tensor` to every tensor and return the results as a new list."""
    adjusted = []
    for item in tensors:
        adjusted.append(self._adjust_tensor(item, batch_size))
    return adjusted


def _adjust_tensor(self, tensor: Tensor, batch_size: int) -> Tensor:
    """Make dim 1 of a 3D tensor equal to `batch_size`.

    A tensor that already matches is returned unchanged; a size-1 dim is
    expanded to `batch_size`.

    Raises:
        ValueError: if dim 1 is neither 1 nor `batch_size`.
    """
    current = tensor.shape[1]
    if current != batch_size and current != 1:
        raise ValueError(f"Unexpected batch size {tensor.shape[1]}, expected {batch_size}")
    if current == batch_size:
        return tensor
    return tensor.expand(-1, batch_size, -1)


def _get_1d_pos_embedding(self, index: int) -> Tensor:
    """Return row `index` of the 1D positional embedding table, shaped (1, 1, d)."""
    row = self.encoder_1d_feature_pos_embed.weight[index]
    return row[None, None]


def _create_task_pos_embedding(self, task_embeds: Tensor) -> Tensor:
    """Create sinusoidal positional embeddings for the task tokens.

    One position is generated per entry along dim 0 of `task_embeds`; the
    result gets a length-1 axis at dim 1 and is moved to `task_embeds`' device.
    """
    n_positions = task_embeds.size(0)
    table = create_sinusoidal_pos_embedding(n_positions, self.config.dim_model)
    table = table.unsqueeze(1)
    return table.to(device=task_embeds.device)


def _prepare_decoder_input(self, batch_size: int, dtype, device) -> Tensor:
    """Create the all-zeros decoder input.

    Returns:
        Tensor of shape (chunk_size, batch_size, dim_model) with the given
        dtype on the given device.
    """
    shape = (self.config.chunk_size, batch_size, self.config.dim_model)
    return torch.zeros(shape, dtype=dtype, device=device)
class ACTEncoder(nn.Module):
    """Stack of `ACTEncoderLayer`s followed by a final LayerNorm (pre-norm) or identity."""

    def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
        """
        Build the layer stack.

        Args:
            config: Model configuration; `n_vae_encoder_layers` or `n_encoder_layers`
                selects the depth, `pre_norm` selects the final normalization.
            is_vae_encoder: When true the VAE-encoder depth is used.
        """
        super().__init__()
        self.is_vae_encoder = is_vae_encoder
        if self.is_vae_encoder:
            depth = config.n_vae_encoder_layers
        else:
            depth = config.n_encoder_layers
        stack = []
        for _ in range(depth):
            stack.append(ACTEncoderLayer(config))
        self.layers = nn.ModuleList(stack)
        # A closing LayerNorm is only applied in the pre-norm variant.
        if config.pre_norm:
            self.norm = nn.LayerNorm(config.dim_model)
        else:
            self.norm = nn.Identity()

    def forward(self, tensor: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Run the input through every layer in order, then through `self.norm`.

        Args:
            tensor: Input tensor of shape (seq_length, batch_size, embedding_dim).
            pos_embed: Optional positional embeddings, forwarded to every layer.
            key_padding_mask: Optional padding mask, forwarded to every layer.
        Returns:
            The encoded tensor.
        """
        hidden = tensor
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
        hidden = self.norm(hidden)
        return hidden
class ACTEncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention then feed-forward, pre- or post-norm."""

    def __init__(self, config: ACTConfig):
        """
        Initializes an ACTEncoderLayer with self-attention and feed-forward network.

        Args:
            config: An instance of ACTConfig containing layer configuration parameters.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
        self.dropout = nn.Dropout(config.dropout)
        self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
        self.norm1 = nn.LayerNorm(config.dim_model)
        self.norm2 = nn.LayerNorm(config.dim_model)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def _apply_self_attention(self, tensor: Tensor, pos_embed: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        """
        Applies self-attention with a residual connection around it.

        Args:
            tensor: Input tensor.
            pos_embed: Optional positional embeddings, added to query and key only.
            key_padding_mask: Optional mask for padded tokens.
        Returns:
            `tensor` plus the (dropped-out) attention output.
        """
        normed = self.norm1(tensor) if self.pre_norm else tensor
        # Positions steer *where* to attend (query/key); the attended content
        # (value) stays free of the positional embedding.
        query_key = normed if pos_embed is None else normed + pos_embed
        attn_output = self.self_attn(query_key, query_key, value=normed, key_padding_mask=key_padding_mask)[0]
        return tensor + self.dropout1(attn_output)

    def _apply_feed_forward(self, tensor: Tensor) -> Tensor:
        """
        Applies the feed-forward network (linear -> activation -> dropout -> linear).

        Args:
            tensor: Input tensor.
        Returns:
            Tensor after feed-forward processing (no residual added here).
        """
        return self.linear2(self.dropout(self.activation(self.linear1(tensor))))

    def forward(self, input_tensor: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass for the encoder layer.

        Args:
            input_tensor: Input tensor of shape (seq_length, batch_size, embedding_dim).
            pos_embed: Optional positional embeddings.
            key_padding_mask: Optional padding mask.
        Returns:
            Output tensor with the same shape as `input_tensor`.
        """
        attn_out = self._apply_self_attention(input_tensor, pos_embed, key_padding_mask)
        # The feed-forward residual must branch off *after* attention, otherwise
        # the attention result never reaches the skip path.
        if self.pre_norm:
            residual = attn_out
            ff_in = self.norm2(attn_out)
        else:
            ff_in = self.norm1(attn_out)
            residual = ff_in
        output = residual + self.dropout2(self._apply_feed_forward(ff_in))
        if not self.pre_norm:
            output = self.norm2(output)
        return output
class ACTDecoder(nn.Module):
    """Stack of `ACTDecoderLayer`s followed by a final LayerNorm."""

    def __init__(self, config: ACTConfig):
        """Create `config.n_decoder_layers` decoder layers and the closing LayerNorm."""
        super().__init__()
        stack = []
        for _ in range(config.n_decoder_layers):
            stack.append(ACTDecoderLayer(config))
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(config.dim_model)

    def forward(
        self,
        tensor: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Optional[Tensor] = None,
        encoder_pos_embed: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run the input through every decoder layer in order, then normalize.

        Args:
            tensor: Input tensor of shape (decoder_seq_length, batch_size, embedding_dim).
            encoder_out: Encoder output tensor, passed unchanged to every layer.
            decoder_pos_embed: Optional decoder positional embeddings.
            encoder_pos_embed: Optional encoder positional embeddings.
        Returns:
            The decoded tensor; normalization is skipped if `self.norm` is None.
        """
        hidden = tensor
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden,
                encoder_out,
                decoder_pos_embed=decoder_pos_embed,
                encoder_pos_embed=encoder_pos_embed,
            )
        if self.norm is None:
            return hidden
        return self.norm(hidden)
class ACTDecoderLayer(nn.Module):
    """Transformer decoder layer: self-attention, cross-attention, feed-forward (pre- or post-norm)."""

    def __init__(self, config: ACTConfig):
        """
        Initializes an ACTDecoderLayer with self-attention, cross-attention, and feed-forward network.

        Args:
            config: An instance of ACTConfig containing layer configuration parameters.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
        self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
        self.dropout = nn.Dropout(config.dropout)
        self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
        self.norm1 = nn.LayerNorm(config.dim_model)
        self.norm2 = nn.LayerNorm(config.dim_model)
        self.norm3 = nn.LayerNorm(config.dim_model)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        self.dropout3 = nn.Dropout(config.dropout)
        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def _apply_self_attention(self, tensor: Tensor, decoder_pos_embed: Optional[Tensor]) -> Tensor:
        """
        Applies decoder self-attention with a residual connection around it.

        Args:
            tensor: Input tensor.
            decoder_pos_embed: Optional positional embeddings, added to query and key only.
        Returns:
            `tensor` plus the (dropped-out) attention output.
        """
        normed = self.norm1(tensor) if self.pre_norm else tensor
        # Positional information goes into query/key; the value stays un-embedded.
        query_key = normed if decoder_pos_embed is None else normed + decoder_pos_embed
        attn_output = self.self_attn(query_key, query_key, value=normed)[0]
        return tensor + self.dropout1(attn_output)

    def _apply_cross_attention(
        self,
        tensor: Tensor,
        encoder_output: Tensor,
        decoder_pos_embed: Optional[Tensor],
        encoder_pos_embed: Optional[Tensor],
        residual: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Applies cross-attention between decoder input and encoder output.

        Args:
            tensor: Decoder tensor used to form the query.
            encoder_output: Encoder output tensor, used for key and value.
            decoder_pos_embed: Optional decoder positional embeddings (added to the query).
            encoder_pos_embed: Optional encoder positional embeddings (added to the key).
            residual: Tensor the attention output is added to. Defaults to `tensor`;
                the pre-norm path passes the un-normalized tensor instead.
        Returns:
            The residual plus the (dropped-out) cross-attention output.
        """
        query = tensor if decoder_pos_embed is None else tensor + decoder_pos_embed
        key = encoder_output if encoder_pos_embed is None else encoder_output + encoder_pos_embed
        cross_attn_out = self.multihead_attn(query=query, key=key, value=encoder_output)[0]
        skip = tensor if residual is None else residual
        return skip + self.dropout2(cross_attn_out)

    def _apply_feed_forward(self, tensor: Tensor) -> Tensor:
        """
        Applies the feed-forward network (linear -> activation -> dropout -> linear).

        Args:
            tensor: Input tensor.
        Returns:
            Tensor after feed-forward processing (no residual added here).
        """
        return self.linear2(self.dropout(self.activation(self.linear1(tensor))))

    def forward(
        self,
        input_tensor: Tensor,
        encoder_output: Tensor,
        decoder_pos_embed: Optional[Tensor] = None,
        encoder_pos_embed: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Forward pass for the decoder layer.

        Args:
            input_tensor: Input tensor (decoder sequence, batch, channel).
            encoder_output: Encoder output tensor.
            decoder_pos_embed: Optional decoder positional embeddings.
            encoder_pos_embed: Optional encoder positional embeddings.
        Returns:
            Output tensor with the same shape as `input_tensor`.
        """
        hidden = self._apply_self_attention(input_tensor, decoder_pos_embed)

        # Each sub-block's residual branches off the output of the previous
        # sub-block: before normalization for pre-norm, after it for post-norm.
        if self.pre_norm:
            skip = hidden
            hidden = self.norm2(hidden)
        else:
            hidden = self.norm1(hidden)
            skip = hidden
        hidden = self._apply_cross_attention(
            hidden, encoder_output, decoder_pos_embed, encoder_pos_embed, residual=skip
        )

        if self.pre_norm:
            skip = hidden
            hidden = self.norm3(hidden)
        else:
            hidden = self.norm2(hidden)
            skip = hidden
        output = skip + self.dropout3(self._apply_feed_forward(hidden))
        if not self.pre_norm:
            output = self.norm3(output)
        return output
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """1D sinusoidal positional embeddings as in Attention is All You Need.

    Channel pairs (2k, 2k + 1) share the divisor 10000 ** (2k / dimension);
    even channels hold the sine and odd channels the cosine of the angle.

    Args:
        num_positions: Number of token positions.
        dimension: Embedding dimension.
    Returns:
        Float32 tensor of shape (num_positions, dimension) with positional embeddings.
    """
    exponents = [2 * (channel // 2) / dimension for channel in range(dimension)]
    angles = np.array(
        [[position / np.power(10000, exponent) for exponent in exponents] for position in range(num_positions)]
    )
    # Overwrite in place: sine on even channels, cosine on odd channels.
    angles[:, 0::2] = np.sin(angles[:, 0::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])
    return torch.from_numpy(angles).float()
class ACTSinusoidalPositionEmbedding2d(nn.Module):
    """2D sinusoidal positional embeddings similar to Attention Is All You Need.

    Row and column positions are normalized to (0, 2*pi] and encoded separately
    with `dimension` channels each; the two encodings are concatenated.
    """

    def __init__(self, dimension: int):
        """
        Args:
            dimension: Number of channels used for each of the two axes.
        """
        super().__init__()
        self.dimension = dimension
        self._two_pi = 2 * math.pi
        self._eps = 1e-6
        self._temperature = 10000

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor: A (B, C, H, W) tensor; only its spatial size, dtype and device are used.
        Returns:
            Tensor of shape (1, 2 * dimension, H, W) for even `dimension`:
            the row (y) encoding first, then the column (x) encoding.
        """
        ones = torch.ones_like(tensor[0, :1])  # (1, H, W)
        # Running sums give 1-based row / column indices.
        row_pos = ones.cumsum(1, dtype=torch.float32)
        col_pos = ones.cumsum(2, dtype=torch.float32)
        # Normalize by the last index so positions fall in (0, 2*pi].
        row_pos = row_pos / (row_pos[:, -1:, :] + self._eps) * self._two_pi
        col_pos = col_pos / (col_pos[:, :, -1:] + self._eps) * self._two_pi

        channel_idx = torch.arange(self.dimension, dtype=torch.float32, device=tensor.device)
        divisor = self._temperature ** (2 * (channel_idx // 2) / self.dimension)
        col_angles = col_pos.unsqueeze(-1) / divisor
        row_angles = row_pos.unsqueeze(-1) / divisor

        def interleave(angles: Tensor) -> Tensor:
            # sin on even channels, cos on odd channels, interleaved back together.
            return torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1).flatten(3)

        embed_x = interleave(col_angles)
        embed_y = interleave(row_angles)
        return torch.cat((embed_y, embed_x), dim=3).permute(0, 3, 1, 2)
def get_activation_fn(activation: str) -> Callable:
    """Map an activation name ("relu", "gelu" or "glu") to the matching `torch.nn.functional` callable.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
lerobot.common.policies.act.modeling_act import ACTPolicy return ACTPolicy + elif name == "actlanguage": + from lerobot.common.policies.actlanguage.modeling_act import ACTLanguagePolicy + + return ACTLanguagePolicy elif name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy @@ -65,6 +70,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return DiffusionConfig(**kwargs) elif policy_type == "act": return ACTConfig(**kwargs) + elif policy_type == "actlanguage": + return ACTLanguageConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) elif policy_type == "pi0":