From 8755a63a0a922d6dfb634c1bc54d6ccd0b4814ba Mon Sep 17 00:00:00 2001 From: Gabriel Cordeiro Bosse Date: Mon, 24 Mar 2025 16:30:41 -0300 Subject: [PATCH 1/3] Adicionei a policy act-language --- .../policies/actlanguage/configuration_act.py | 187 ++++ .../policies/actlanguage/modeling_act.py | 810 ++++++++++++++++++ lerobot/common/policies/factory.py | 7 + 3 files changed, 1004 insertions(+) create mode 100644 lerobot/common/policies/actlanguage/configuration_act.py create mode 100644 lerobot/common/policies/actlanguage/modeling_act.py diff --git a/lerobot/common/policies/actlanguage/configuration_act.py b/lerobot/common/policies/actlanguage/configuration_act.py new file mode 100644 index 0000000000..a7bb4f38a6 --- /dev/null +++ b/lerobot/common/policies/actlanguage/configuration_act.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass("actlanguage") +@dataclass +class ACTLanguageConfig(PreTrainedConfig): + """Configuration class for the Action Chunking Transformers policy. + + Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes` and 'output_shapes`. + + Notes on the inputs and outputs: + - Either: + - At least one key starting with "observation.image is required as an input. + AND/OR + - The key "observation.environment_state" is required as input. + - If there are multiple keys beginning with "observation.images." they are treated as multiple camera + views. Right now we only support all images having the same shape. + - May optionally work without an "observation.state" key for the proprioceptive robot state. + - "action" is required as an output key. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + chunk_size: The size of the action prediction "chunks" in units of environment steps. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + This should be no greater than the chunk size. For example, if the chunk size size 100, you may + set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the + environment, and throws the other 50 out. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. 
+ For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. + `None` means no pretrained weights. + replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated + convolution. + pre_norm: Whether to use "pre-norm" in the transformer blocks. + dim_model: The transformer blocks' main hidden dimension. + n_heads: The number of heads to use in the transformer blocks' multi-head attention. + dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward + layers. + feedforward_activation: The activation to use in the transformer block's feed-forward layers. + n_encoder_layers: The number of transformer layers to use for the transformer encoder. + n_decoder_layers: The number of transformer layers to use for the transformer decoder. + use_vae: Whether to use a variational objective during training. This introduces another transformer + which is used as the VAE's encoder (not to be confused with the transformer encoder - see + documentation in the policy class). + latent_dim: The VAE's latent dimension. + n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. + temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal + ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be + 1 when using this feature, as inference needs to happen at every step to form an ensemble. For + more information on how ensembling works, please see `ACTTemporalEnsembler`. + dropout: Dropout to use in the transformer layers (see code for details). + kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective + is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. + """ + + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 100 + n_action_steps: int = 100 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Architecture. + # Vision backbone. 
+ vision_backbone: str = "resnet18" + pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" + replace_final_stride_with_dilation: int = False + # Transformer layers. + pre_norm: bool = False + dim_model: int = 512 + n_heads: int = 8 + dim_feedforward: int = 3200 + feedforward_activation: str = "relu" + n_encoder_layers: int = 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: int = 1 + # VAE. + use_vae: bool = True + latent_dim: int = 32 + n_vae_encoder_layers: int = 4 + + # Inference. + # Note: the value used in ACT when temporal ensembling is enabled is 0.01. + temporal_ensemble_coeff: float | None = None + + # Training and loss computation. + dropout: float = 0.1 + kl_weight: float = 10.0 + + # Training preset + optimizer_lr: float = 1e-5 + optimizer_weight_decay: float = 1e-4 + optimizer_lr_backbone: float = 1e-5 + task_vocab_size: int = 512 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1: + raise NotImplementedError( + "`n_action_steps` must be 1 when using temporal ensembling. This is " + "because the policy needs to be queried every step to compute the ensembled action." + ) + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + if not self.image_features and not self.env_state_feature: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/actlanguage/modeling_act.py b/lerobot/common/policies/actlanguage/modeling_act.py new file mode 100644 index 0000000000..4791d8a79a --- /dev/null +++ b/lerobot/common/policies/actlanguage/modeling_act.py @@ -0,0 +1,810 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Action Chunking Transformer Policy + +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. +""" + +import math +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d + +from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from transformers import AutoTokenizer + +class ACTLanguagePolicy(PreTrainedPolicy): + """ + Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + """ + + config_class = ACTConfig + name = "act" + + def __init__( + self, + config: ACTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__(config) + config.validate_features() + self.tokenizer = AutoTokenizer.from_pretrained("lerobot/pi0") + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = ACT(config) + + if config.temporal_ensemble_coeff is not None: + self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.reset() + + def get_optim_params(self) -> dict: + # TODO(aliberts, rcadene): As of now, lr_backbone == lr + # Should we remove this and just `return self.parameters()`? + return [ + { + "params": [ + p + for n, p in self.named_parameters() + if not n.startswith("model.backbone") and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in self.named_parameters() + if n.startswith("model.backbone") and p.requires_grad + ], + "lr": self.config.optimizer_lr_backbone, + }, + ] + + def reset(self): + """This should be called whenever the environment is reset.""" + if self.config.temporal_ensemble_coeff is not None: + self.temporal_ensembler.reset() + else: + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. 
It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + batch = self.normalize_inputs(batch) + + # Tokenize task + task_texts = batch["task"] # Assume batch["task"] is a list of strings + tokenized = self.tokenizer( + task_texts, + padding=True, + return_tensors="pt", + return_attention_mask=True + ).to(self.device) + batch["task.input_ids"] = tokenized.input_ids + batch["task.attention_mask"] = tokenized.attention_mask + + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] + + # If we are doing temporal ensembling, do online updates where we keep track of the number of actions + # we are ensembling over. + if self.config.temporal_ensemble_coeff is not None: + actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) + actions = self.unnormalize_outputs({"action": actions})["action"] + action = self.temporal_ensembler.update(actions) + return action + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + actions = self.model(batch)[0][:, : self.config.n_action_steps] + + # TODO(rcadene): make _forward return output dictionary? + actions = self.unnormalize_outputs({"action": actions})["action"] + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] + + batch = self.normalize_targets(batch) + actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + loss_dict = {"l1_loss": l1_loss.item()} + if self.config.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kld_loss"] = mean_kld.item() + loss = l1_loss + mean_kld * self.config.kl_weight + else: + loss = l1_loss + + return loss, loss_dict + + +class ACTTemporalEnsembler: + def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: + """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + + The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. + They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the + coefficient works: + - Setting it to 0 uniformly weighs all actions. + - Setting it positive gives more weight to older actions. 
+ - Setting it negative gives more weight to newer actions. + NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This + results in older actions being weighed more highly than newer actions (the experiments documented in + https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be + detrimental: doing so aggressively may diminish the benefits of action chunking). + + Here we use an online method for computing the average rather than caching a history of actions in + order to compute the average offline. For a simple 1D sequence it looks something like: + + ``` + import torch + + seq = torch.linspace(8, 8.5, 100) + print(seq) + + m = 0.01 + exp_weights = torch.exp(-m * torch.arange(len(seq))) + print(exp_weights) + + # Calculate offline + avg = (exp_weights * seq).sum() / exp_weights.sum() + print("offline", avg) + + # Calculate online + for i, item in enumerate(seq): + if i == 0: + avg = item + continue + avg *= exp_weights[:i].sum() + avg += item * exp_weights[i] + avg /= exp_weights[:i+1].sum() + print("online", avg) + ``` + """ + self.chunk_size = chunk_size + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + self.reset() + + def reset(self): + """Resets the online computation variables.""" + self.ensembled_actions = None + # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. + self.ensembled_actions_count = None + + def update(self, actions: Tensor) -> Tensor: + """ + Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all + time steps, and pop/return the next batch of actions in the sequence. + """ + self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + if self.ensembled_actions is None: + # Initializes `self._ensembled_action` to the sequence of actions predicted during the first + # time step of the episode. + self.ensembled_actions = actions.clone() + # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor + # operations later. + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + ) + else: + # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the online update for those entries. + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + # The last action, which has no prior online average, needs to get concatenated onto the end. + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat( + [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + ) + # "Consume" the first action. 
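+        # The ensembled action at index 0 is returned for execution, and both the action buffer and its
+        # per-step counts are shifted left by one so the next call keeps averaging over what remains.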
+ action, self.ensembled_actions, self.ensembled_actions_count = ( + self.ensembled_actions[:, 0], + self.ensembled_actions[:, 1:], + self.ensembled_actions_count[1:], + ) + return action + + +class ACT(nn.Module): + """Action Chunking Transformer: The underlying neural network for ACTPolicy. + + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). + + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ + └───────────────────────┘ + """ + + def __init__(self, config: ACTConfig): + # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + super().__init__() + self.config = config + + if self.config.use_vae: + self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) + self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) + # Projection layer for joint-space configuration to hidden dimension. + if self.config.robot_state_feature: + self.vae_encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear( + self.config.action_feature.shape[0], + config.dim_model, + ) + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch + # dimension. + num_input_token_encoder = 1 + config.chunk_size + if self.config.robot_state_feature: + num_input_token_encoder += 1 + self.register_buffer( + "vae_encoder_pos_enc", + create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + ) + self.task_embedding = nn.Embedding( + config.task_vocab_size, # Set this in ACTConfig + config.dim_model + ) + n_1d_tokens = 1 # latent + if config.robot_state_feature: + n_1d_tokens += 1 + if config.env_state_feature: + n_1d_tokens += 1 + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) + + # Backbone for image feature extraction. 
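+        # Only the ResNet trunk up to `layer4` is kept (via the IntermediateLayerGetter below). Batch-norm
+        # statistics are frozen with FrozenBatchNorm2d, and the final stride can optionally be replaced by
+        # dilation so the output feature map keeps a higher spatial resolution.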
+ if self.config.image_features: + backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final + # feature map). + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = ACTEncoder(config) + self.decoder = ACTDecoder(config) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. + if self.config.robot_state_feature: + self.encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + if self.config.env_state_feature: + self.encoder_env_state_input_proj = nn.Linear( + self.config.env_state_feature.shape[0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) + if self.config.image_features: + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, config.dim_model, kernel_size=1 + ) + # Transformer encoder positional embeddings. + n_1d_tokens = 1 # for the latent + if self.config.robot_state_feature: + n_1d_tokens += 1 + if self.config.env_state_feature: + n_1d_tokens += 1 + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) + if self.config.image_features: + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) + + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) + + self._reset_parameters() + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder). + + `batch` should have the following structure: + { + [robot_state_feature] (optional): (B, state_dim) batch of robot states. + + [image_features]: (B, n_cameras, C, H, W) batch of images. + AND/OR + [env_state_feature]: (B, env_dim) batch of environment states. + + [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. + } + + Returns: + (B, chunk_size, action_dim) batch of action sequences + Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the + latent dimension. + """ + if self.config.use_vae and self.training: + assert "action" in batch, ( + "actions must be provided when using the variational objective in training mode." 
+ ) + + if "observation.images" in batch: + batch_size = batch["observation.images"][0].shape[0] + else: + batch_size = batch["observation.environment_state"].shape[0] + + # Prepare the latent for input to the transformer encoder. + if self.config.use_vae and "action" in batch: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat( + self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size + ) # (B, 1, D) + if self.config.robot_state_feature: + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + + if self.config.robot_state_feature: + vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + else: + vae_encoder_input = [cls_embed, action_embed] + vae_encoder_input = torch.cat(vae_encoder_input, axis=1) + + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + + # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the + # sequence depending whether we use the input states or not (cls and robot state) + # False means not a padding token. + cls_joint_is_pad = torch.full( + (batch_size, 2 if self.config.robot_state_feature else 1), + False, + device=batch["observation.state"].device, + ) + key_padding_mask = torch.cat( + [cls_joint_is_pad, batch["action_is_pad"]], axis=1 + ) # (bs, seq+1 or 2) + + # Forward pass through VAE encoder to get the latent PDF parameters. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), + pos_embed=pos_embed.permute(1, 0, 2), + key_padding_mask=key_padding_mask, + )[0] # select the class token, with shape (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.config.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. + log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] + + # Sample the latent with the reparameterization trick. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer + latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( + batch["observation.state"].device + ) + + # Prepare transformer encoder inputs. + encoder_in_tokens = [] + encoder_in_pos_embed = [] + # 1. Latent token + latent_token = self.encoder_latent_input_proj(latent_sample).unsqueeze(0) # (1, B, D) + encoder_in_tokens.append(latent_token) + encoder_in_pos_embed.append(self.encoder_1d_feature_pos_embed.weight[0].unsqueeze(0).unsqueeze(1)) + + # 2. Robot state token + if self.config.robot_state_feature: + robot_state_token = self.encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(0) # (1, B, D) + encoder_in_tokens.append(robot_state_token) + encoder_in_pos_embed.append(self.encoder_1d_feature_pos_embed.weight[1].unsqueeze(0).unsqueeze(1)) + + # 3. 
Task tokens + if "task.input_ids" in batch: + task_embeds = self.task_embedding(batch["task.input_ids"]) # (B, S_task, D) + task_embeds = einops.rearrange(task_embeds, "b s d -> s b d") # (S_task, B, D) + encoder_in_tokens.append(task_embeds) + + # Generate positional embeddings for task tokens + task_pos = create_sinusoidal_pos_embedding( + task_embeds.size(0), + self.config.dim_model + ).to(device=task_embeds.device).unsqueeze(1) # (S_task, 1, D) + encoder_in_pos_embed.append(task_pos) + # Environment state token. + if self.config.env_state_feature: + encoder_in_tokens.append( + self.encoder_env_state_input_proj(batch["observation.environment_state"]) + ) + + # Camera observation features and positional embeddings. + if self.config.image_features: + all_cam_features = [] + all_cam_pos_embeds = [] + + # For a list of images, the H and W may vary but H*W is constant. + for img in batch["observation.images"]: + cam_features = self.backbone(img)["feature_map"] + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) + + # Rearrange features to (sequence, batch, dim). + cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") + cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") + + all_cam_features.append(cam_features) + all_cam_pos_embeds.append(cam_pos_embed) + + encoder_in_tokens.extend(all_cam_features) + encoder_in_pos_embed.extend(all_cam_pos_embeds) + + # Concatenate all tokens along sequence dimension + encoder_in_tokens = torch.cat(encoder_in_tokens, dim=0) # (Total_Tokens, B, D) + encoder_in_pos_embed = torch.cat(encoder_in_pos_embed, dim=0) # (Total_Tokens, 1, D) + + # Rest of original forward pass... + encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) + + # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer + decoder_in = torch.zeros( + (self.config.chunk_size, batch_size, self.config.dim_model), + dtype=encoder_in_pos_embed.dtype, + device=encoder_in_pos_embed.device, + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=encoder_in_pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + + # Move back to (B, S, C). + decoder_out = decoder_out.transpose(0, 1) + + actions = self.action_head(decoder_out) + + return actions, (mu, log_sigma_x2) + + +class ACTEncoder(nn.Module): + """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + + def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): + super().__init__() + self.is_vae_encoder = is_vae_encoder + num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() + + def forward( + self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: + for layer in self.layers: + x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) + x = self.norm(x) + return x + + +class ACTEncoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. 
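+        # Position-wise feed-forward network: expand from `dim_model` to `dim_feedforward`, apply the
+        # configured activation and dropout, then project back down to `dim_model`.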
+ self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) + x = x[0] # note: [0] to select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.pre_norm: + x = self.norm2(x) + return x + + +class ACTDecoder(nn.Module): + def __init__(self, config: ACTConfig): + """Convenience module for running multiple decoder layers followed by normalization.""" + super().__init__() + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class ACTDecoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.norm3 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + self.dropout3 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: (Decoder Sequence, Batch, Channel) tensor of input tokens. + encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are + cross-attending with. + decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). + encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + Returns: + (DS, B, C) tensor of decoder output features. 
+ """ + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.multihead_attn( + query=self.maybe_add_pos_embed(x, decoder_pos_embed), + key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), + value=encoder_out, + )[0] # select just the output, not the attention weights + x = skip + self.dropout2(x) + if self.pre_norm: + skip = x + x = self.norm3(x) + else: + x = self.norm2(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout3(x) + if not self.pre_norm: + x = self.norm3(x) + return x + + +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: + """1D sinusoidal positional embeddings as in Attention is All You Need. + + Args: + num_positions: Number of token positions required. + Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). + + """ + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.from_numpy(sinusoid_table).float() + + +class ACTSinusoidalPositionEmbedding2d(nn.Module): + """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. + + The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H + for the vertical direction, and 1/W for the horizontal direction. + """ + + def __init__(self, dimension: int): + """ + Args: + dimension: The desired dimension of the embeddings. + """ + super().__init__() + self.dimension = dimension + self._two_pi = 2 * math.pi + self._eps = 1e-6 + # Inverse "common ratio" for the geometric progression in sinusoid frequencies. + self._temperature = 10000 + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + Returns: + A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + """ + not_mask = torch.ones_like(x[0, :1]) # (1, H, W) + # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations + # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. + y_range = not_mask.cumsum(1, dtype=torch.float32) + x_range = not_mask.cumsum(2, dtype=torch.float32) + + # "Normalize" the position index such that it ranges in [0, 2π]. + # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range + # are non-zero by construction. This is an artifact of the original code. 
+ y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi + x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + + inverse_frequency = self._temperature ** ( + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + ) + + x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + + # Note: this stack then flatten operation results in interleaved sine and cosine terms. + # pos_embed_x and pos_embed_y are (1, H, W, C // 2). + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + + return pos_embed + + +def get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string.""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 5d2f6cb5fe..14bb63f93f 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -23,6 +23,7 @@ from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.actlanguage.configuration_act import ACTLanguageConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -46,6 +47,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.act.modeling_act import ACTPolicy return ACTPolicy + elif name == "actlanguage": + from lerobot.common.policies.actlanguage.modeling_act import ACTLanguagePolicy + + return ACTLanguagePolicy elif name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy @@ -65,6 +70,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return DiffusionConfig(**kwargs) elif policy_type == "act": return ACTConfig(**kwargs) + elif policy_type == "actlanguage": + return ACTLanguageConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) elif policy_type == "pi0": From bef3b70727b3e9593c5ae3d3a3261fba88aca4fb Mon Sep 17 00:00:00 2001 From: Gabriel Cordeiro Bosse Date: Sat, 5 Apr 2025 14:59:29 -0300 Subject: [PATCH 2/3] Melhorei a policy e arrumei alguns Bugs. 
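
A minimal sketch (with illustrative keyword arguments) of how the new policy type is resolved through
the factory helpers registered in the first patch:

    from lerobot.common.policies.factory import get_policy_class, make_policy_config

    config = make_policy_config("actlanguage", chunk_size=100, n_action_steps=100)
    policy_cls = get_policy_class("actlanguage")  # resolves to ACTLanguagePolicy
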
--- .../policies/actlanguage/configuration_act.py | 2 +- .../policies/actlanguage/modeling_act.py | 693 +++++++++--------- 2 files changed, 334 insertions(+), 361 deletions(-) diff --git a/lerobot/common/policies/actlanguage/configuration_act.py b/lerobot/common/policies/actlanguage/configuration_act.py index a7bb4f38a6..d443b79ee9 100644 --- a/lerobot/common/policies/actlanguage/configuration_act.py +++ b/lerobot/common/policies/actlanguage/configuration_act.py @@ -136,7 +136,7 @@ class ACTLanguageConfig(PreTrainedConfig): optimizer_lr: float = 1e-5 optimizer_weight_decay: float = 1e-4 optimizer_lr_backbone: float = 1e-5 - task_vocab_size: int = 512 + task_vocab_size: int = 320000 def __post_init__(self): super().__post_init__() diff --git a/lerobot/common/policies/actlanguage/modeling_act.py b/lerobot/common/policies/actlanguage/modeling_act.py index 4791d8a79a..77c21e9aa0 100644 --- a/lerobot/common/policies/actlanguage/modeling_act.py +++ b/lerobot/common/policies/actlanguage/modeling_act.py @@ -1,18 +1,16 @@ #!/usr/bin/env python -# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. -# +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS +# OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and limitations under the License. """Action Chunking Transformer Policy As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). @@ -22,12 +20,12 @@ import math from collections import deque from itertools import chain -from typing import Callable +from typing import Callable, Optional import einops import numpy as np import torch -import torch.nn.functional as F # noqa: N812 +import torch.nn.functional as F import torchvision from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter @@ -38,6 +36,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from transformers import AutoTokenizer + class ACTLanguagePolicy(PreTrainedPolicy): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost @@ -50,7 +49,7 @@ class ACTLanguagePolicy(PreTrainedPolicy): def __init__( self, config: ACTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, + dataset_stats: Optional[dict[str, dict[str, Tensor]]] = None, ): """ Args: @@ -81,20 +80,25 @@ def __init__( self.reset() def get_optim_params(self) -> dict: - # TODO(aliberts, rcadene): As of now, lr_backbone == lr - # Should we remove this and just `return self.parameters()`? + """ + Returns optimizer parameters for the policy. + + This method groups parameters into different optimizer parameter groups, + separating those belonging to the model backbone with a distinct learning rate. 
+ + Returns: + A list of dictionaries specifying parameter groups and their optimizer settings. + """ return [ { "params": [ - p - for n, p in self.named_parameters() + p for n, p in self.named_parameters() if not n.startswith("model.backbone") and p.requires_grad ] }, { "params": [ - p - for n, p in self.named_parameters() + p for n, p in self.named_parameters() if n.startswith("model.backbone") and p.requires_grad ], "lr": self.config.optimizer_lr_backbone, @@ -108,7 +112,7 @@ def reset(self): else: self._action_queue = deque([], maxlen=self.config.n_action_steps) - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -127,40 +131,33 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: padding=True, return_tensors="pt", return_attention_mask=True - ).to(self.device) + ) batch["task.input_ids"] = tokenized.input_ids batch["task.attention_mask"] = tokenized.attention_mask if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) # shallow copy to avoid modifying original batch["observation.images"] = [batch[key] for key in self.config.image_features] - # If we are doing temporal ensembling, do online updates where we keep track of the number of actions - # we are ensembling over. + # Temporal ensembling logic. if self.config.temporal_ensemble_coeff is not None: actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] action = self.temporal_ensembler.update(actions) return action - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. + # Action queue logic for n_action_steps > 1. if len(self._action_queue) == 0: actions = self.model(batch)[0][:, : self.config.n_action_steps] - - # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Optional[tuple[Tensor, Tensor]]]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) batch["observation.images"] = [batch[key] for key in self.config.image_features] batch = self.normalize_targets(batch) @@ -172,10 +169,6 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: - # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for - # each dimension independently, we sum over the latent dimension to get the total - # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). 
mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) @@ -192,43 +185,11 @@ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. - They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the - coefficient works: - - Setting it to 0 uniformly weighs all actions. - - Setting it positive gives more weight to older actions. - - Setting it negative gives more weight to newer actions. - NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This - results in older actions being weighed more highly than newer actions (the experiments documented in - https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be - detrimental: doing so aggressively may diminish the benefits of action chunking). - - Here we use an online method for computing the average rather than caching a history of actions in - order to compute the average offline. For a simple 1D sequence it looks something like: - - ``` - import torch - - seq = torch.linspace(8, 8.5, 100) - print(seq) - - m = 0.01 - exp_weights = torch.exp(-m * torch.arange(len(seq))) - print(exp_weights) - - # Calculate offline - avg = (exp_weights * seq).sum() / exp_weights.sum() - print("offline", avg) - - # Calculate online - for i, item in enumerate(seq): - if i == 0: - avg = item - continue - avg *= exp_weights[:i].sum() - avg += item * exp_weights[i] - avg /= exp_weights[:i+1].sum() - print("online", avg) - ``` + They are then normalized to sum to 1. + + Args: + temporal_ensemble_coeff: Coefficient controlling the weight decay. + chunk_size: The number of actions in a chunk. """ self.chunk_size = chunk_size self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) @@ -236,106 +197,70 @@ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: self.reset() def reset(self): - """Resets the online computation variables.""" + """Reset the ensembling variables.""" self.ensembled_actions = None - # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. self.ensembled_actions_count = None def update(self, actions: Tensor) -> Tensor: """ - Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all - time steps, and pop/return the next batch of actions in the sequence. + Update the temporal ensemble with new actions and return the next ensembled action. + + Args: + actions: Tensor of shape (batch, chunk_size, action_dim) with new actions. + Returns: + Tensor with the ensembled action. """ self.ensemble_weights = self.ensemble_weights.to(device=actions.device) self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) if self.ensembled_actions is None: - # Initializes `self._ensembled_action` to the sequence of actions predicted during the first - # time step of the episode. self.ensembled_actions = actions.clone() - # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor - # operations later. self.ensembled_actions_count = torch.ones( (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device ) else: - # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). 
Compute - # the online update for those entries. self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) - # The last action, which has no prior online average, needs to get concatenated onto the end. self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) self.ensembled_actions_count = torch.cat( [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] ) - # "Consume" the first action. - action, self.ensembled_actions, self.ensembled_actions_count = ( - self.ensembled_actions[:, 0], - self.ensembled_actions[:, 1:], - self.ensembled_actions_count[1:], - ) + action = self.ensembled_actions[:, 0] + self.ensembled_actions = self.ensembled_actions[:, 1:] + self.ensembled_actions_count = self.ensembled_actions_count[1:] return action class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. - Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. - - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the - model that encodes the target data (a sequence of actions), and the condition (the robot - joint-space). - - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with - cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we - have an option to train this model without the variational objective (in which case we drop the - `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). - - Transformer - Used alone for inference - (acts as VAE decoder - during training) - ┌───────────────────────┐ - │ Outputs │ - │ ▲ │ - │ ┌─────►┌───────┐ │ - ┌──────┐ │ │ │Transf.│ │ - │ │ │ ├─────►│decoder│ │ - ┌────┴────┐ │ │ │ │ │ │ - │ │ │ │ ┌───┴───┬─►│ │ │ - │ VAE │ │ │ │ │ └───────┘ │ - │ encoder │ │ │ │Transf.│ │ - │ │ │ │ │encoder│ │ - └───▲─────┘ │ │ │ │ │ - │ │ │ └▲──▲─▲─┘ │ - │ │ │ │ │ │ │ - inputs └─────┼──┘ │ image emb. │ - │ state emb. │ - └───────────────────────┘ + This model includes an optional VAE encoder, task embeddings, image feature backbone, transformer encoder/decoder, + and a final regression head for action prediction. """ def __init__(self, config: ACTConfig): - # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. - # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + """ + Initializes the Action Chunking Transformer model. + + Args: + config: An instance of ACTConfig containing model configuration parameters. + """ super().__init__() self.config = config if self.config.use_vae: self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) - # Projection layer for joint-space configuration to hidden dimension. if self.config.robot_state_feature: self.vae_encoder_robot_state_input_proj = nn.Linear( self.config.robot_state_feature.shape[0], config.dim_model ) - # Projection layer for action (joint-space target) to hidden dimension. 
self.vae_encoder_action_input_proj = nn.Linear( self.config.action_feature.shape[0], config.dim_model, ) - # Projection layer from the VAE encoder's output to the latent distribution's parameter space. self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) - # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch - # dimension. num_input_token_encoder = 1 + config.chunk_size if self.config.robot_state_feature: num_input_token_encoder += 1 @@ -344,34 +269,27 @@ def __init__(self, config: ACTConfig): create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), ) self.task_embedding = nn.Embedding( - config.task_vocab_size, # Set this in ACTConfig + config.task_vocab_size, config.dim_model ) - n_1d_tokens = 1 # latent + n_1d_tokens = 1 if config.robot_state_feature: n_1d_tokens += 1 if config.env_state_feature: n_1d_tokens += 1 self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) - # Backbone for image feature extraction. if self.config.image_features: backbone_model = getattr(torchvision.models, config.vision_backbone)( replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], weights=config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, ) - # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final - # feature map). - # Note: The forward method of this returns a dict: {"feature_map": output}. self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) - # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) self.decoder = ACTDecoder(config) - # Transformer encoder input projections. The tokens will be structured like - # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. if self.config.robot_state_feature: self.encoder_robot_state_input_proj = nn.Linear( self.config.robot_state_feature.shape[0], config.dim_model @@ -385,8 +303,7 @@ def __init__(self, config: ACTConfig): self.encoder_img_feat_input_proj = nn.Conv2d( backbone_model.fc.in_features, config.dim_model, kernel_size=1 ) - # Transformer encoder positional embeddings. - n_1d_tokens = 1 # for the latent + n_1d_tokens = 1 if self.config.robot_state_feature: n_1d_tokens += 1 if self.config.env_state_feature: @@ -395,11 +312,7 @@ def __init__(self, config: ACTConfig): if self.config.image_features: self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) - # Transformer decoder. - # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) - - # Final action regression head on the output of the transformer's decoder. self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) self._reset_parameters() @@ -410,59 +323,59 @@ def _reset_parameters(self): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: - """A forward pass through the Action Chunking Transformer (with optional VAE encoder). - - `batch` should have the following structure: - { - [robot_state_feature] (optional): (B, state_dim) batch of robot states. 
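The image pathway set up above (a ResNet backbone wrapped in `IntermediateLayerGetter`, followed by a 1x1 conv down to `dim_model`) ultimately yields one encoder token per spatial location of the final feature map. A rough shape walkthrough, assuming a resnet18 backbone and illustrative image sizes (neither is fixed by this hunk):

```python
import einops
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

dim_model = 512  # illustrative; comes from ACTConfig in the patch

backbone_model = torchvision.models.resnet18(weights=None, norm_layer=FrozenBatchNorm2d)
backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
proj = torch.nn.Conv2d(backbone_model.fc.in_features, dim_model, kernel_size=1)

img = torch.randn(2, 3, 480, 640)           # (B, C, H, W)
feature_map = backbone(img)["feature_map"]  # (B, 512, H/32, W/32) for resnet18
tokens = einops.rearrange(proj(feature_map), "b c h w -> (h w) b c")
print(feature_map.shape, tokens.shape)      # ..., torch.Size([300, 2, 512])
```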
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Optional[tuple[Tensor, Tensor]]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder).""" + batch = self._prepare_batch(batch) + batch_size = self._get_batch_size(batch) + latent_sample, mu, log_sigma_x2 = self._prepare_latent(batch, batch_size) + encoder_in_tokens, encoder_in_pos_embed = self._prepare_encoder_inputs(batch, latent_sample, batch_size) + encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) + decoder_in = self._prepare_decoder_input(batch_size, encoder_in_pos_embed.dtype, encoder_in_pos_embed.device) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=encoder_in_pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + decoder_out = decoder_out.transpose(0, 1) + actions = self.action_head(decoder_out) + return actions, (mu, log_sigma_x2) - [image_features]: (B, n_cameras, C, H, W) batch of images. - AND/OR - [env_state_feature]: (B, env_dim) batch of environment states. + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare and normalize the batch input and add image features if available.""" + if self.config.image_features: + batch = dict(batch) + batch["observation.images"] = [batch[key] for key in self.config.image_features] + return batch - [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. - } + def _get_batch_size(self, batch: dict[str, Tensor]) -> int: + """Extract the batch size from the batch dictionary.""" + if "observation.images" in batch: + return batch["observation.images"][0].shape[0] + return batch["observation.environment_state"].shape[0] + def _prepare_latent(self, batch: dict[str, Tensor], batch_size: int) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Prepare the latent token for transformer encoder. + Returns: - (B, chunk_size, action_dim) batch of action sequences - Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the - latent dimension. + latent_sample: Tensor of shape (batch_size, latent_dim) + mu: Mean of the latent distribution (or None) + log_sigma_x2: Log variance (or None) """ - if self.config.use_vae and self.training: - assert "action" in batch, ( - "actions must be provided when using the variational objective in training mode." - ) - - if "observation.images" in batch: - batch_size = batch["observation.images"][0].shape[0] - else: - batch_size = batch["observation.environment_state"].shape[0] - - # Prepare the latent for input to the transformer encoder. if self.config.use_vae and "action" in batch: - # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. 
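Before the `cls_embed` construction that follows, it may help to see the target token layout for the VAE encoder, `[cls, (robot_state), *action_sequence]`, together with the matching key padding mask. A shape-only sketch with illustrative sizes:

```python
import torch

# Illustrative sizes; chunk_size and dim_model come from ACTConfig, the rest from the batch.
batch_size, chunk_size, dim_model = 2, 100, 512
use_robot_state = True

cls_embed = torch.zeros(batch_size, 1, dim_model)              # learned CLS token, repeated per sample
robot_state_embed = torch.zeros(batch_size, 1, dim_model)      # projected proprioceptive state
action_embed = torch.zeros(batch_size, chunk_size, dim_model)  # projected action sequence

vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)

# CLS and state tokens are never padding; action padding comes from "action_is_pad".
action_is_pad = torch.zeros(batch_size, chunk_size, dtype=torch.bool)
cls_joint_is_pad = torch.full((batch_size, 2 if use_robot_state else 1), False)
key_padding_mask = torch.cat([cls_joint_is_pad, action_is_pad], dim=1)
print(vae_encoder_input.shape, key_padding_mask.shape)  # (2, 102, 512), (2, 102)
```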
cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size - ) # (B, 1, D) + ) if self.config.robot_state_feature: robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) - robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) - + robot_state_embed = robot_state_embed.unsqueeze(1) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) if self.config.robot_state_feature: - vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + vae_encoder_input = [cls_embed, robot_state_embed, action_embed] else: vae_encoder_input = [cls_embed, action_embed] vae_encoder_input = torch.cat(vae_encoder_input, axis=1) - - # Prepare fixed positional embedding. - # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. - pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) - - # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the - # sequence depending whether we use the input states or not (cls and robot state) - # False means not a padding token. + pos_embed = self.vae_encoder_pos_enc.clone().detach() cls_joint_is_pad = torch.full( (batch_size, 2 if self.config.robot_state_feature else 1), False, @@ -470,166 +383,202 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso ) key_padding_mask = torch.cat( [cls_joint_is_pad, batch["action_is_pad"]], axis=1 - ) # (bs, seq+1 or 2) - - # Forward pass through VAE encoder to get the latent PDF parameters. + ) cls_token_out = self.vae_encoder( vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2), key_padding_mask=key_padding_mask, - )[0] # select the class token, with shape (B, D) + )[0] latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) mu = latent_pdf_params[:, : self.config.latent_dim] - # This is 2log(sigma). Done this way to match the original implementation. log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] - - # Sample the latent with the reparameterization trick. latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + return latent_sample, mu, log_sigma_x2 else: - # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None - # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( batch["observation.state"].device ) + return latent_sample, mu, log_sigma_x2 - # Prepare transformer encoder inputs. - encoder_in_tokens = [] - encoder_in_pos_embed = [] - # 1. Latent token - latent_token = self.encoder_latent_input_proj(latent_sample).unsqueeze(0) # (1, B, D) - encoder_in_tokens.append(latent_token) - encoder_in_pos_embed.append(self.encoder_1d_feature_pos_embed.weight[0].unsqueeze(0).unsqueeze(1)) - - # 2. Robot state token + def _prepare_encoder_inputs(self, batch: dict[str, Tensor], latent_sample: Tensor, batch_size: int) -> tuple[Tensor, Tensor]: + """Prepare tokens and positional embeddings for transformer encoder. 
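The latent preparation above parameterizes the posterior as `(mu, 2*log(sigma))` and samples it with the reparameterization trick; the same quantities feed the KL term in the training loss. A minimal sketch of both steps (shapes are illustrative):

```python
import torch

batch_size, latent_dim = 4, 32
latent_pdf_params = torch.randn(batch_size, 2 * latent_dim)  # stand-in for vae_encoder_latent_output_proj

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2 * log(sigma), matching the original ACT convention

# Reparameterization trick: sample = mu + sigma * eps, eps ~ N(0, I), so gradients flow through mu and sigma.
latent_sample = mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu)

# KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and averaged over the batch.
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
print(latent_sample.shape, mean_kld.item())
```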
+ + Returns: + encoder_in_tokens: Tensor of shape (total_tokens, batch_size, dim_model) + encoder_in_pos_embed: Tensor of shape (total_tokens, 1, dim_model) + """ + tokens = [] + pos_embeds = [] + latent_token = self.encoder_latent_input_proj(latent_sample).unsqueeze(0) + tokens.append(latent_token) + pos_embeds.append(self.encoder_1d_feature_pos_embed.weight[0].unsqueeze(0).unsqueeze(1)) if self.config.robot_state_feature: - robot_state_token = self.encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(0) # (1, B, D) - encoder_in_tokens.append(robot_state_token) - encoder_in_pos_embed.append(self.encoder_1d_feature_pos_embed.weight[1].unsqueeze(0).unsqueeze(1)) - - # 3. Task tokens + robot_state_token = self.encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(0) + tokens.append(robot_state_token) + pos_embeds.append(self.encoder_1d_feature_pos_embed.weight[1].unsqueeze(0).unsqueeze(1)) if "task.input_ids" in batch: - task_embeds = self.task_embedding(batch["task.input_ids"]) # (B, S_task, D) - task_embeds = einops.rearrange(task_embeds, "b s d -> s b d") # (S_task, B, D) - encoder_in_tokens.append(task_embeds) - - # Generate positional embeddings for task tokens - task_pos = create_sinusoidal_pos_embedding( - task_embeds.size(0), - self.config.dim_model - ).to(device=task_embeds.device).unsqueeze(1) # (S_task, 1, D) - encoder_in_pos_embed.append(task_pos) - # Environment state token. + task_embeds = self.task_embedding(batch["task.input_ids"]) + task_embeds = einops.rearrange(task_embeds, "b s d -> s b d") + tokens.append(task_embeds) + task_pos = create_sinusoidal_pos_embedding(task_embeds.size(0), self.config.dim_model).to( + device=task_embeds.device + ).unsqueeze(1) + pos_embeds.append(task_pos) if self.config.env_state_feature: - encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) - ) - - # Camera observation features and positional embeddings. + tokens.append(self.encoder_env_state_input_proj(batch["observation.environment_state"]).unsqueeze(0)) if self.config.image_features: - all_cam_features = [] - all_cam_pos_embeds = [] - - # For a list of images, the H and W may vary but H*W is constant. for img in batch["observation.images"]: cam_features = self.backbone(img)["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) - - # Rearrange features to (sequence, batch, dim). cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") - - all_cam_features.append(cam_features) - all_cam_pos_embeds.append(cam_pos_embed) - - encoder_in_tokens.extend(all_cam_features) - encoder_in_pos_embed.extend(all_cam_pos_embeds) - - # Concatenate all tokens along sequence dimension - encoder_in_tokens = torch.cat(encoder_in_tokens, dim=0) # (Total_Tokens, B, D) - encoder_in_pos_embed = torch.cat(encoder_in_pos_embed, dim=0) # (Total_Tokens, 1, D) - - # Rest of original forward pass... 
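The `adjust_batch_dim` helper above exists because the 1D positional embeddings enter the list with a singleton batch dimension, `(tokens, 1, dim)`, while feature tokens are `(tokens, batch, dim)`; expanding the singleton lets everything be concatenated along the token axis. A quick check of that behaviour:

```python
import torch

def adjust_batch_dim(t: torch.Tensor, desired_bs: int) -> torch.Tensor:
    # Same logic as the helper in the patch: keep matching batch sizes,
    # broadcast a singleton batch dimension, reject anything else.
    if t.shape[1] == desired_bs:
        return t
    if t.shape[1] == 1:
        return t.expand(t.shape[0], desired_bs, t.shape[2])
    raise ValueError(f"Incompatible batch size: expected {desired_bs}, got {t.shape[1]}")

dim_model, batch_size = 512, 8
pos_embed = torch.randn(3, 1, dim_model)              # positional embeddings for the 1D tokens
cam_tokens = torch.randn(300, batch_size, dim_model)  # flattened image feature tokens

stacked = torch.cat(
    [adjust_batch_dim(pos_embed, batch_size), adjust_batch_dim(cam_tokens, batch_size)], dim=0
)
print(stacked.shape)  # torch.Size([303, 8, 512])
```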
- encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) - - # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer - decoder_in = torch.zeros( + tokens.append(cam_features) + pos_embeds.append(cam_pos_embed) + + # DEBUG: shows original shapes of the tokens + # print(f"[DEBUG] token shapes before adjustment: {[t.shape for t in tokens]}") + # print(f"[DEBUG] pos_embeds shapes before adjustment: {[t.shape for t in pos_embeds]}") + + # Helper function to adjust the batch dimension + def adjust_batch_dim(t: torch.Tensor, desired_bs: int) -> torch.Tensor: + if t.shape[1] == desired_bs: + return t + elif t.shape[1] == 1: + return t.expand(t.shape[0], desired_bs, t.shape[2]) + else: + raise ValueError(f"Incompatible batch size: expected {desired_bs}, but got {t.shape[1]} in tensor {t}") + + # Adjust the tokens and pos_embeds tensors to have the expected batch size + tokens = [adjust_batch_dim(t, batch_size) for t in tokens] + pos_embeds = [adjust_batch_dim(t, batch_size) for t in pos_embeds] + + # print(f"[DEBUG] adjusted token shapes: {[t.shape for t in tokens]}") + # print(f"[DEBUG] adjusted pos_embeds shapes: {[t.shape for t in pos_embeds]}") + + encoder_in_tokens = torch.cat(tokens, dim=0) + encoder_in_pos_embed = torch.cat(pos_embeds, dim=0) + return encoder_in_tokens, encoder_in_pos_embed + + def _prepare_decoder_input(self, batch_size: int, dtype, device) -> Tensor: + """Prepare the zero-initialized input for the decoder. + + Returns: + Tensor of shape (chunk_size, batch_size, dim_model) + """ + return torch.zeros( (self.config.chunk_size, batch_size, self.config.dim_model), - dtype=encoder_in_pos_embed.dtype, - device=encoder_in_pos_embed.device, - ) - decoder_out = self.decoder( - decoder_in, - encoder_out, - encoder_pos_embed=encoder_in_pos_embed, - decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + dtype=dtype, + device=device, ) - # Move back to (B, S, C). - decoder_out = decoder_out.transpose(0, 1) - - actions = self.action_head(decoder_out) - - return actions, (mu, log_sigma_x2) - class ACTEncoder(nn.Module): - """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + """Convenience module for running multiple encoder layers, optionally with normalization.""" def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): + """ + Initializes the ACTEncoder module. + + Args: + config: An instance of ACTConfig with model configuration parameters. + is_vae_encoder: Flag indicating whether the encoder is used as a VAE encoder. + """ super().__init__() self.is_vae_encoder = is_vae_encoder num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() - def forward( - self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None - ) -> Tensor: + def forward(self, x: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor: + """ + Pass the input through encoder layers and apply normalization. + + Args: + x: Input tensor of shape (seq_length, batch_size, embedding_dim). + pos_embed: Optional positional embeddings. + key_padding_mask: Optional mask for padded tokens. + Returns: + Encoded tensor. 
+ """ for layer in self.layers: x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) - x = self.norm(x) - return x + return self.norm(x) class ACTEncoderLayer(nn.Module): def __init__(self, config: ACTConfig): + """ + Initializes an ACTEncoderLayer with self-attention and feed-forward network. + + Args: + config: An instance of ACTConfig containing layer configuration parameters. + """ super().__init__() self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) - - # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) self.dropout = nn.Dropout(config.dropout) self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(config.dim_model) self.norm2 = nn.LayerNorm(config.dim_model) self.dropout1 = nn.Dropout(config.dropout) self.dropout2 = nn.Dropout(config.dropout) - self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: - skip = x - if self.pre_norm: - x = self.norm1(x) - q = k = x if pos_embed is None else x + pos_embed - x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) - x = x[0] # note: [0] to select just the output, not the attention weights - x = skip + self.dropout1(x) + def _apply_self_attention(self, x: Tensor, pos_embed: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: + """ + Applies self-attention to the input tensor. + + Args: + x: Input tensor. + pos_embed: Optional positional embeddings. + key_padding_mask: Optional mask for padded tokens. + Returns: + Tensor after self-attention and dropout. + """ + norm_x = self.norm1(x) if self.pre_norm else x + if pos_embed is not None: + norm_x = norm_x + pos_embed + attn_output = self.self_attn(norm_x, norm_x, value=norm_x, key_padding_mask=key_padding_mask)[0] + return x + self.dropout1(attn_output) + + def _apply_feed_forward(self, x: Tensor) -> Tensor: + """ + Applies the feed-forward network. + + Args: + x: Input tensor. + Returns: + Tensor after feed-forward processing. + """ + ff_output = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return ff_output + + def forward(self, input_tensor: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor: + """ + Forward pass for the encoder layer. + + Args: + input_tensor: Input tensor of shape (seq_length, batch_size, embedding_dim). + pos_embed: Optional positional embeddings. + key_padding_mask: Optional padding mask. + Returns: + Output tensor. 
+ """ + residual = input_tensor + attn_out = self._apply_self_attention(input_tensor, pos_embed, key_padding_mask) + # Apply normalization after self-attention if self.pre_norm: - skip = x - x = self.norm2(x) + intermediate = self.norm2(attn_out) else: - x = self.norm1(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout2(x) + intermediate = self.norm1(attn_out) + ff_out = self._apply_feed_forward(intermediate) + output = residual + self.dropout2(ff_out) if not self.pre_norm: - x = self.norm2(x) - return x + output = self.norm2(output) + return output class ACTDecoder(nn.Module): @@ -643,159 +592,183 @@ def forward( self, x: Tensor, encoder_out: Tensor, - decoder_pos_embed: Tensor | None = None, - encoder_pos_embed: Tensor | None = None, + decoder_pos_embed: Optional[Tensor] = None, + encoder_pos_embed: Optional[Tensor] = None, ) -> Tensor: + """ + Pass the input through decoder layers and apply normalization. + + Args: + x: Input tensor of shape (decoder_seq_length, batch_size, embedding_dim). + encoder_out: Encoder output tensor. + decoder_pos_embed: Optional decoder positional embeddings. + encoder_pos_embed: Optional encoder positional embeddings. + Returns: + Decoded tensor. + """ for layer in self.layers: - x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed - ) - if self.norm is not None: - x = self.norm(x) - return x + x = layer(x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed) + return self.norm(x) if self.norm is not None else x class ACTDecoderLayer(nn.Module): def __init__(self, config: ACTConfig): + """ + Initializes an ACTDecoderLayer with self-attention, cross-attention, and feed-forward network. + + Args: + config: An instance of ACTConfig containing layer configuration parameters. + """ super().__init__() self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) - - # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) self.dropout = nn.Dropout(config.dropout) self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(config.dim_model) self.norm2 = nn.LayerNorm(config.dim_model) self.norm3 = nn.LayerNorm(config.dim_model) self.dropout1 = nn.Dropout(config.dropout) self.dropout2 = nn.Dropout(config.dropout) self.dropout3 = nn.Dropout(config.dropout) - self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: - return tensor if pos_embed is None else tensor + pos_embed + def _apply_self_attention(self, x: Tensor, decoder_pos_embed: Optional[Tensor]) -> Tensor: + """ + Applies self-attention for the decoder. + + Args: + x: Input tensor. + decoder_pos_embed: Optional positional embeddings. + Returns: + Tensor after self-attention. + """ + norm_x = self.norm1(x) if self.pre_norm else x + if decoder_pos_embed is not None: + norm_x = norm_x + decoder_pos_embed + attn_output = self.self_attn(norm_x, norm_x, value=norm_x)[0] + return x + self.dropout1(attn_output) + + def _apply_cross_attention( + self, x: Tensor, encoder_output: Tensor, decoder_pos_embed: Optional[Tensor], encoder_pos_embed: Optional[Tensor] + ) -> Tensor: + """ + Applies cross-attention between decoder input and encoder output. 
+ + Args: + x: Decoder input tensor. + encoder_output: Encoder output tensor. + decoder_pos_embed: Optional decoder positional embeddings. + encoder_pos_embed: Optional encoder positional embeddings. + Returns: + Tensor after cross-attention. + """ + query = x if decoder_pos_embed is None else x + decoder_pos_embed + key = encoder_output if encoder_pos_embed is None else encoder_output + encoder_pos_embed + cross_attn_out = self.multihead_attn(query=query, key=key, value=encoder_output)[0] + return x + self.dropout2(cross_attn_out) + + def _apply_feed_forward(self, x: Tensor) -> Tensor: + """ + Applies the feed-forward network. + + Args: + x: Input tensor. + Returns: + Tensor after feed-forward processing. + """ + ff_out = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return ff_out def forward( self, - x: Tensor, - encoder_out: Tensor, - decoder_pos_embed: Tensor | None = None, - encoder_pos_embed: Tensor | None = None, + input_tensor: Tensor, + encoder_output: Tensor, + decoder_pos_embed: Optional[Tensor] = None, + encoder_pos_embed: Optional[Tensor] = None, ) -> Tensor: """ + Forward pass for the decoder layer. + Args: - x: (Decoder Sequence, Batch, Channel) tensor of input tokens. - encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are - cross-attending with. - decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). - encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + input_tensor: Input tensor (decoder sequence, batch, channel). + encoder_output: Encoder output tensor. + decoder_pos_embed: Optional decoder positional embeddings. + encoder_pos_embed: Optional encoder positional embeddings. Returns: - (DS, B, C) tensor of decoder output features. + Output tensor. """ - skip = x + residual = input_tensor + self_attn_out = self._apply_self_attention(input_tensor, decoder_pos_embed) + # Normalize after self-attention if self.pre_norm: - x = self.norm1(x) - q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights - x = skip + self.dropout1(x) - if self.pre_norm: - skip = x - x = self.norm2(x) + intermediate = self.norm2(self_attn_out) else: - x = self.norm1(x) - skip = x - x = self.multihead_attn( - query=self.maybe_add_pos_embed(x, decoder_pos_embed), - key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), - value=encoder_out, - )[0] # select just the output, not the attention weights - x = skip + self.dropout2(x) + intermediate = self.norm1(self_attn_out) + cross_attn_out = self._apply_cross_attention(intermediate, encoder_output, decoder_pos_embed, encoder_pos_embed) if self.pre_norm: - skip = x - x = self.norm3(x) + intermediate = self.norm3(cross_attn_out) else: - x = self.norm2(x) - skip = x - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout3(x) + intermediate = self.norm2(cross_attn_out) + ff_out = self._apply_feed_forward(intermediate) + output = residual + self.dropout3(ff_out) if not self.pre_norm: - x = self.norm3(x) - return x + output = self.norm3(output) + return output def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: """1D sinusoidal positional embeddings as in Attention is All You Need. Args: - num_positions: Number of token positions required. - Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). 
- + num_positions: Number of token positions. + dimension: Embedding dimension. + Returns: + Tensor of shape (num_positions, dimension) with positional embeddings. """ - def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) return torch.from_numpy(sinusoid_table).float() class ACTSinusoidalPositionEmbedding2d(nn.Module): - """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. - - The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H - for the vertical direction, and 1/W for the horizontal direction. - """ + """2D sinusoidal positional embeddings similar to Attention Is All You Need.""" def __init__(self, dimension: int): """ Args: - dimension: The desired dimension of the embeddings. + dimension: Embedding dimension. """ super().__init__() self.dimension = dimension self._two_pi = 2 * math.pi self._eps = 1e-6 - # Inverse "common ratio" for the geometric progression in sinusoid frequencies. self._temperature = 10000 def forward(self, x: Tensor) -> Tensor: """ Args: - x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + x: A (B, C, H, W) tensor. Returns: - A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + Tensor of shape (1, C, H, W) with sinusoidal positional embeddings. """ not_mask = torch.ones_like(x[0, :1]) # (1, H, W) - # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations - # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. y_range = not_mask.cumsum(1, dtype=torch.float32) x_range = not_mask.cumsum(2, dtype=torch.float32) - - # "Normalize" the position index such that it ranges in [0, 2π]. - # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range - # are non-zero by construction. This is an artifact of the original code. y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi - inverse_frequency = self._temperature ** ( 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension ) - - x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) - y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) - - # Note: this stack then flatten operation results in interleaved sine and cosine terms. - # pos_embed_x and pos_embed_y are (1, H, W, C // 2). 
+ x_range = x_range.unsqueeze(-1) / inverse_frequency + y_range = y_range.unsqueeze(-1) / inverse_frequency pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) - + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) return pos_embed From 358f37bba23ba28d9eefaf41f0c482725cbe7642 Mon Sep 17 00:00:00 2001 From: Gabriel Cordeiro Bosse Date: Sat, 5 Apr 2025 16:07:14 -0300 Subject: [PATCH 3/3] mensagem do commit --- .../policies/actlanguage/modeling_act.py | 737 ++++++++++++------ 1 file changed, 504 insertions(+), 233 deletions(-) diff --git a/lerobot/common/policies/actlanguage/modeling_act.py b/lerobot/common/policies/actlanguage/modeling_act.py index 77c21e9aa0..245a3e9642 100644 --- a/lerobot/common/policies/actlanguage/modeling_act.py +++ b/lerobot/common/policies/actlanguage/modeling_act.py @@ -116,68 +116,178 @@ def reset(self): def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. + Manages action queue and delegates to appropriate sub-methods for: + 1. Input normalization + 2. Task tokenization + 3. Image feature handling + 4. Temporal ensembling + 5. Action queue management + + Args: + batch: Input observations including task descriptions and environment states + + Returns: + Selected action tensor """ self.eval() + + batch = self._normalize_inputs(batch) + self._tokenize_tasks(batch) + self._prepare_image_features(batch) + + if self.config.temporal_ensemble_coeff is not None: + return self._apply_temporal_ensembling(batch) + + return self._get_queued_action(batch) - batch = self.normalize_inputs(batch) + def _normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Normalize input observations using configured normalization parameters. - # Tokenize task - task_texts = batch["task"] # Assume batch["task"] is a list of strings + Args: + batch: Raw input observations + + Returns: + Normalized input batch + """ + return self.normalize_inputs(batch) + + def _tokenize_tasks(self, batch: dict[str, Tensor]) -> None: + """Tokenize task descriptions in the input batch. + + Adds 'task.input_ids' and 'task.attention_mask' to the batch. + + Args: + batch: Input batch containing 'task' key with text descriptions + """ + task_texts = batch["task"] tokenized = self.tokenizer( task_texts, padding=True, - return_tensors="pt", + return_tensors="pt", return_attention_mask=True ) batch["task.input_ids"] = tokenized.input_ids batch["task.attention_mask"] = tokenized.attention_mask + def _prepare_image_features(self, batch: dict[str, Tensor]) -> None: + """Prepare image features from configured input keys. + + Creates 'observation.images' list from specified image feature keys. + + Args: + batch: Input batch containing image feature tensors + """ if self.config.image_features: - batch = dict(batch) # shallow copy to avoid modifying original - batch["observation.images"] = [batch[key] for key in self.config.image_features] + batch["observation.images"] = [ + batch[key] for key in self.config.image_features + ] - # Temporal ensembling logic. 
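`_tokenize_tasks` above assumes `self.tokenizer` behaves like a Hugging Face tokenizer (its construction is not shown in these hunks), and the resulting `task.input_ids` are later looked up in the `nn.Embedding(task_vocab_size, dim_model)` added by this patch. A hedged sketch of that path; the tokenizer checkpoint is a hypothetical choice, not something the patch pins down:

```python
import torch
from torch import nn
from transformers import AutoTokenizer  # assumed backend; the patch does not show how self.tokenizer is built

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical checkpoint
task_texts = ["pick up the red block", "open the drawer"]
tokenized = tokenizer(task_texts, padding=True, return_tensors="pt", return_attention_mask=True)

dim_model = 512
task_embedding = nn.Embedding(tokenizer.vocab_size, dim_model)  # task_vocab_size must cover every id

task_embeds = task_embedding(tokenized.input_ids)  # (batch, seq, dim_model)
task_tokens = task_embeds.permute(1, 0, 2)         # (seq, batch, dim_model), as the encoder expects
print(task_tokens.shape, tokenized.attention_mask.shape)
```

In the hunks shown here, `task.attention_mask` is stored but never passed to the transformer encoder, so padded task tokens are attended to like real ones.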
- if self.config.temporal_ensemble_coeff is not None: - actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) - actions = self.unnormalize_outputs({"action": actions})["action"] - action = self.temporal_ensembler.update(actions) - return action - - # Action queue logic for n_action_steps > 1. - if len(self._action_queue) == 0: - actions = self.model(batch)[0][:, : self.config.n_action_steps] - actions = self.unnormalize_outputs({"action": actions})["action"] - self._action_queue.extend(actions.transpose(0, 1)) + def _apply_temporal_ensembling(self, batch: dict[str, Tensor]) -> Tensor: + """Apply temporal ensembling to model outputs. + + Args: + batch: Processed input batch + + Returns: + Temporally ensembled action tensor + """ + actions = self.model(batch)[0] + actions = self.unnormalize_outputs({"action": actions})["action"] + return self.temporal_ensembler.update(actions) + + def _get_queued_action(self, batch: dict[str, Tensor]) -> Tensor: + """Retrieve action from queue or generate new actions if queue is empty. + + Args: + batch: Processed input batch + + Returns: + Next action from the queue + """ + if not self._action_queue: + self._refill_action_queue(batch) return self._action_queue.popleft() + def _refill_action_queue(self, batch: dict[str, Tensor]) -> None: + """Generate and queue multiple actions when queue is empty. + + Args: + batch: Processed input batch + """ + actions = self.model(batch)[0][:, :self.config.n_action_steps] + actions = self.unnormalize_outputs({"action": actions})["action"] + self._action_queue.extend(actions.transpose(0, 1)) + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, Optional[tuple[Tensor, Tensor]]]: - """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) - if self.config.image_features: - batch = dict(batch) - batch["observation.images"] = [batch[key] for key in self.config.image_features] + """Run the batch through the model and compute the loss for training/validation. + Args: + batch: Input batch containing observations, actions, and padding masks + + Returns: + Tuple containing total loss and a dictionary of loss components + """ + batch = self._prepare_input_batch(batch) batch = self.normalize_targets(batch) + actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) + l1_loss = self._compute_l1_loss(batch, actions_hat) + + loss_dict = {"l1_loss": l1_loss.item()} + total_loss = l1_loss + + if self.config.use_vae: + kld_loss = self._compute_kl_divergence(mu_hat, log_sigma_x2_hat) + loss_dict["kld_loss"] = kld_loss.item() + total_loss += kld_loss * self.config.kl_weight + + return total_loss, loss_dict + + def _prepare_input_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare input batch by normalizing inputs and formatting image features. + + Args: + batch: Raw input batch + + Returns: + Processed batch with normalized inputs and formatted image features + """ + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # Create shallow copy to avoid modifying original + batch["observation.images"] = [ + batch[key] for key in self.config.image_features + ] + return batch - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + def _compute_l1_loss(self, batch: dict[str, Tensor], actions_hat: Tensor) -> Tensor: + """Calculate L1 loss between predicted and ground truth actions. 
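`_compute_l1_loss`, continued just below, multiplies the elementwise L1 by the inverted padding mask and then takes a plain mean, so padded steps contribute zero to the numerator but still count in the denominator. A small standalone check of that reduction:

```python
import torch
import torch.nn.functional as F

batch, chunk, action_dim = 2, 100, 14
actions = torch.randn(batch, chunk, action_dim)
actions_hat = torch.randn(batch, chunk, action_dim)

action_is_pad = torch.zeros(batch, chunk, dtype=torch.bool)
action_is_pad[:, 80:] = True  # pretend the last 20 steps run past the end of the episode

# Same reduction as the patch: zero out padded steps, then mean over all elements.
l1 = (F.l1_loss(actions, actions_hat, reduction="none") * ~action_is_pad.unsqueeze(-1)).mean()
print(l1)
```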
+ + Args: + batch: Processed input batch containing 'action' and 'action_is_pad' + actions_hat: Predicted actions from the model + + Returns: + L1 loss tensor with padding mask applied + """ + return ( + F.l1_loss(batch["action"], actions_hat, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) ).mean() - loss_dict = {"l1_loss": l1_loss.item()} - if self.config.use_vae: - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kld_loss"] = mean_kld.item() - loss = l1_loss + mean_kld * self.config.kl_weight - else: - loss = l1_loss + def _compute_kl_divergence(self, mu_hat: Tensor, log_sigma_x2_hat: Tensor) -> Tensor: + """Calculate KL divergence loss for VAE regularization. - return loss, loss_dict + Args: + mu_hat: Predicted mean values + log_sigma_x2_hat: Predicted log variance values + + Returns: + KL divergence loss tensor + """ + return ( + -0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp()) + ).sum(-1).mean() class ACTTemporalEnsembler: @@ -207,25 +317,53 @@ def update(self, actions: Tensor) -> Tensor: Args: actions: Tensor of shape (batch, chunk_size, action_dim) with new actions. + Returns: - Tensor with the ensembled action. + Tensor with the ensembled action of shape (batch, action_dim) """ - self.ensemble_weights = self.ensemble_weights.to(device=actions.device) - self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + self._move_weights_to_device(actions.device) + if self.ensembled_actions is None: - self.ensembled_actions = actions.clone() - self.ensembled_actions_count = torch.ones( - (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device - ) + self._initialize_ensemble(actions) else: - self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] - self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] - self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] - self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) - self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) - self.ensembled_actions_count = torch.cat( - [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] - ) + self._update_existing_ensemble(actions) + + return self._prepare_next_action() + + def _move_weights_to_device(self, device: torch.device) -> None: + """Move ensemble weights to the specified device.""" + self.ensemble_weights = self.ensemble_weights.to(device=device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=device) + + def _initialize_ensemble(self, actions: Tensor) -> None: + """Initialize ensemble storage for first update.""" + self.ensembled_actions = actions.clone() + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), + dtype=torch.long, + device=actions.device + ) + + def _update_existing_ensemble(self, actions: Tensor) -> None: + """Update existing ensemble with new actions using exponential weighting.""" + # Apply exponential moving average update + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + + # Update counters and append new action + self.ensembled_actions_count = torch.clamp( + 
self.ensembled_actions_count + 1, + max=self.chunk_size + ) + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat([ + self.ensembled_actions_count, + torch.ones_like(self.ensembled_actions_count[-1:]) + ]) + + def _prepare_next_action(self) -> Tensor: + """Extract next action and update ensemble storage.""" action = self.ensembled_actions[:, 0] self.ensembled_actions = self.ensembled_actions[:, 1:] self.ensembled_actions_count = self.ensembled_actions_count[1:] @@ -242,80 +380,118 @@ class ACT(nn.Module): def __init__(self, config: ACTConfig): """ Initializes the Action Chunking Transformer model. - + Args: - config: An instance of ACTConfig containing model configuration parameters. + config: Configuration object containing model parameters """ super().__init__() self.config = config - if self.config.use_vae: - self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) - self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) - if self.config.robot_state_feature: - self.vae_encoder_robot_state_input_proj = nn.Linear( - self.config.robot_state_feature.shape[0], config.dim_model - ) - self.vae_encoder_action_input_proj = nn.Linear( - self.config.action_feature.shape[0], - config.dim_model, - ) - self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) - num_input_token_encoder = 1 + config.chunk_size - if self.config.robot_state_feature: - num_input_token_encoder += 1 - self.register_buffer( - "vae_encoder_pos_enc", - create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + self._init_vae_components() + self._init_task_embedding() + self._init_image_backbone() + self._init_encoder_decoder() + self._init_input_projections() + self._init_position_embeddings() + self._init_action_head() + self._reset_parameters() + + def _init_vae_components(self) -> None: + """Initialize VAE-specific components if enabled in configuration.""" + if not self.config.use_vae: + return + + self.vae_encoder = ACTEncoder(self.config, is_vae_encoder=True) + self.vae_encoder_cls_embed = nn.Embedding(1, self.config.dim_model) + + if self.config.robot_state_feature: + self.vae_encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], self.config.dim_model ) + + self.vae_encoder_action_input_proj = nn.Linear( + self.config.action_feature.shape[0], self.config.dim_model + ) + self.vae_encoder_latent_output_proj = nn.Linear( + self.config.dim_model, self.config.latent_dim * 2 + ) + + num_tokens = 1 + self.config.chunk_size # CLS token + action tokens + if self.config.robot_state_feature: + num_tokens += 1 + self.register_buffer( + "vae_encoder_pos_enc", + create_sinusoidal_pos_embedding(num_tokens, self.config.dim_model).unsqueeze(0), + ) + + def _init_task_embedding(self) -> None: + """Initialize task embedding layer.""" self.task_embedding = nn.Embedding( - config.task_vocab_size, - config.dim_model + self.config.task_vocab_size, + self.config.dim_model ) - n_1d_tokens = 1 - if config.robot_state_feature: - n_1d_tokens += 1 - if config.env_state_feature: - n_1d_tokens += 1 - self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) - if self.config.image_features: - backbone_model = getattr(torchvision.models, config.vision_backbone)( - replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], - weights=config.pretrained_backbone_weights, - norm_layer=FrozenBatchNorm2d, - ) - self.backbone = 
IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + def _init_image_backbone(self) -> None: + """Initialize vision backbone if image features are enabled.""" + if not self.config.image_features: + return + + backbone_model = getattr(torchvision.models, self.config.vision_backbone)( + replace_stride_with_dilation=[False, False, self.config.replace_final_stride_with_dilation], + weights=self.config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) - self.encoder = ACTEncoder(config) - self.decoder = ACTDecoder(config) + def _init_encoder_decoder(self) -> None: + """Initialize main encoder and decoder components.""" + self.encoder = ACTEncoder(self.config) + self.decoder = ACTDecoder(self.config) + def _init_input_projections(self) -> None: + """Initialize input projection layers for different modalities.""" if self.config.robot_state_feature: self.encoder_robot_state_input_proj = nn.Linear( - self.config.robot_state_feature.shape[0], config.dim_model + self.config.robot_state_feature.shape[0], self.config.dim_model ) + if self.config.env_state_feature: self.encoder_env_state_input_proj = nn.Linear( - self.config.env_state_feature.shape[0], config.dim_model + self.config.env_state_feature.shape[0], self.config.dim_model ) - self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) + + self.encoder_latent_input_proj = nn.Linear( + self.config.latent_dim, self.config.dim_model + ) + if self.config.image_features: + backbone_out_features = getattr(torchvision.models, self.config.vision_backbone)().fc.in_features self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, config.dim_model, kernel_size=1 + backbone_out_features, self.config.dim_model, kernel_size=1 ) - n_1d_tokens = 1 + + def _init_position_embeddings(self) -> None: + """Initialize positional embeddings for different components.""" + # Encoder 1D feature positional embedding + n_1d_tokens = 1 # Task embedding if self.config.robot_state_feature: n_1d_tokens += 1 if self.config.env_state_feature: n_1d_tokens += 1 - self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, self.config.dim_model) + + # Image feature positional embedding if self.config.image_features: - self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) - - self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) - self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) - - self._reset_parameters() + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(self.config.dim_model // 2) + + # Decoder positional embedding + self.decoder_pos_embed = nn.Embedding(self.config.chunk_size, self.config.dim_model) + + def _init_action_head(self) -> None: + """Initialize final action prediction layer.""" + self.action_head = nn.Linear( + self.config.dim_model, self.config.action_feature.shape[0] + ) def _reset_parameters(self): """Xavier-uniform initialization of the transformer parameters as in the original code.""" @@ -355,111 +531,206 @@ def _get_batch_size(self, batch: dict[str, Tensor]) -> int: return batch["observation.environment_state"].shape[0] def _prepare_latent(self, batch: dict[str, Tensor], batch_size: int) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: - """Prepare the latent token for transformer 
encoder. - + """Prepare latent token for transformer encoder with VAE handling. + + Args: + batch: Input batch containing observations and actions + batch_size: Batch size of the input data + Returns: - latent_sample: Tensor of shape (batch_size, latent_dim) - mu: Mean of the latent distribution (or None) - log_sigma_x2: Log variance (or None) + latent_sample: Sampled latent representation (batch_size, latent_dim) + mu: Latent distribution mean (or None if not using VAE) + log_sigma_x2: Latent log variance (or None if not using VAE) """ - if self.config.use_vae and "action" in batch: - cls_embed = einops.repeat( - self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size - ) - if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) - robot_state_embed = robot_state_embed.unsqueeze(1) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) - if self.config.robot_state_feature: - vae_encoder_input = [cls_embed, robot_state_embed, action_embed] - else: - vae_encoder_input = [cls_embed, action_embed] - vae_encoder_input = torch.cat(vae_encoder_input, axis=1) - pos_embed = self.vae_encoder_pos_enc.clone().detach() - cls_joint_is_pad = torch.full( - (batch_size, 2 if self.config.robot_state_feature else 1), - False, - device=batch["observation.state"].device, - ) - key_padding_mask = torch.cat( - [cls_joint_is_pad, batch["action_is_pad"]], axis=1 - ) - cls_token_out = self.vae_encoder( - vae_encoder_input.permute(1, 0, 2), - pos_embed=pos_embed.permute(1, 0, 2), - key_padding_mask=key_padding_mask, - )[0] - latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) - mu = latent_pdf_params[:, : self.config.latent_dim] - log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] - latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) - return latent_sample, mu, log_sigma_x2 - else: - mu = log_sigma_x2 = None - latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device - ) - return latent_sample, mu, log_sigma_x2 + if not self._should_use_vae(batch): + return self._create_dummy_latent(batch_size), None, None + + vae_input = self._create_vae_input_embeddings(batch, batch_size) + pos_embed, key_padding_mask = self._prepare_vae_positioning(batch, batch_size) + + cls_token_out = self._run_vae_encoder(vae_input, pos_embed, key_padding_mask) + mu, log_sigma_x2 = self._extract_latent_parameters(cls_token_out) + latent_sample = self._sample_latent_distribution(mu, log_sigma_x2) + + return latent_sample, mu, log_sigma_x2 + + def _should_use_vae(self, batch: dict[str, Tensor]) -> bool: + """Check if VAE should be used for current batch.""" + return self.config.use_vae and "action" in batch + + def _create_dummy_latent(self, batch_size: int) -> Tensor: + """Create zero-initialized latent tensor for non-VAE cases.""" + return torch.zeros( + (batch_size, self.config.latent_dim), + dtype=torch.float32, + device=self._get_device() + ) + + def _create_vae_input_embeddings(self, batch: dict[str, Tensor], batch_size: int) -> Tensor: + """Construct VAE encoder input embeddings.""" + components = [self._get_cls_embedding(batch_size)] + + if self.config.robot_state_feature: + components.append(self._get_robot_state_embedding(batch)) + + components.append(self._get_action_embedding(batch)) + return torch.cat(components, dim=1) + + def _get_cls_embedding(self, batch_size: int) -> Tensor: + """Generate CLS token embedding.""" + return 
einops.repeat( + self.vae_encoder_cls_embed.weight, + "1 d -> b 1 d", + b=batch_size + ) + + def _get_robot_state_embedding(self, batch: dict[str, Tensor]) -> Tensor: + """Generate robot state embedding if configured.""" + return self.vae_encoder_robot_state_input_proj( + batch["observation.state"] + ).unsqueeze(1) + + def _get_action_embedding(self, batch: dict[str, Tensor]) -> Tensor: + """Generate action sequence embeddings.""" + return self.vae_encoder_action_input_proj(batch["action"]) + + def _prepare_vae_positioning(self, batch: dict[str, Tensor], batch_size: int) -> tuple[Tensor, Tensor]: + """Prepare positional embeddings and attention mask for VAE encoder.""" + pos_embed = self.vae_encoder_pos_enc.expand(batch_size, -1, -1) + return pos_embed.permute(1, 0, 2), self._create_key_padding_mask(batch, batch_size) + + def _create_key_padding_mask(self, batch: dict[str, Tensor], batch_size: int) -> Tensor: + """Create attention mask for padded action sequences.""" + cls_joint_mask = torch.full( + (batch_size, self._get_cls_joint_size()), + False, + device=self._get_device() + ) + return torch.cat([cls_joint_mask, batch["action_is_pad"]], dim=1) + + def _get_cls_joint_size(self) -> int: + """Calculate CLS joint size based on configuration.""" + return 2 if self.config.robot_state_feature else 1 + + def _run_vae_encoder(self, vae_input: Tensor, pos_embed: Tensor, mask: Tensor) -> Tensor: + """Execute VAE encoder forward pass.""" + return self.vae_encoder( + vae_input.permute(1, 0, 2), + pos_embed=pos_embed, + key_padding_mask=mask + )[0] + + def _extract_latent_parameters(self, cls_token_out: Tensor) -> tuple[Tensor, Tensor]: + """Extract mean and log variance from encoder output.""" + params = self.vae_encoder_latent_output_proj(cls_token_out) + return params.chunk(2, dim=-1) + + def _sample_latent_distribution(self, mu: Tensor, log_sigma_x2: Tensor) -> Tensor: + """Sample latent variable using reparameterization trick.""" + return mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu) + + def _get_device(self) -> torch.device: + """Get current device from model parameters.""" + return next(self.parameters()).device def _prepare_encoder_inputs(self, batch: dict[str, Tensor], latent_sample: Tensor, batch_size: int) -> tuple[Tensor, Tensor]: - """Prepare tokens and positional embeddings for transformer encoder. - + """Prepare input tokens and positional embeddings for transformer encoder. 
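`_extract_latent_parameters` above swaps the original slicing for `Tensor.chunk`; the two are equivalent whenever the projection output has exactly `2 * latent_dim` columns, which a quick check confirms:

```python
import torch

latent_dim = 32
params = torch.randn(4, 2 * latent_dim)  # stand-in for the vae_encoder_latent_output_proj output

mu_a, log_sigma_x2_a = params[:, :latent_dim], params[:, latent_dim:]  # original slicing
mu_b, log_sigma_x2_b = params.chunk(2, dim=-1)                         # refactored version
print(torch.equal(mu_a, mu_b), torch.equal(log_sigma_x2_a, log_sigma_x2_b))  # True True
```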
+ + Args: + batch: Input batch containing observations and task information + latent_sample: Latent representation from VAE + batch_size: Batch size of the input data + Returns: - encoder_in_tokens: Tensor of shape (total_tokens, batch_size, dim_model) - encoder_in_pos_embed: Tensor of shape (total_tokens, 1, dim_model) + encoder_in_tokens: Combined input tokens (total_tokens, batch_size, dim_model) + encoder_in_pos_embed: Corresponding positional embeddings (total_tokens, 1, dim_model) """ - tokens = [] - pos_embeds = [] + tokens, pos_embeds = [], [] + + self._add_latent_token(tokens, pos_embeds, latent_sample) + self._add_robot_state_token(tokens, pos_embeds, batch) + self._add_task_tokens(tokens, pos_embeds, batch) + self._add_env_state_token(tokens, batch) + self._add_image_tokens(tokens, pos_embeds, batch) + + tokens = self._adjust_batch_dimensions(tokens, batch_size) + pos_embeds = self._adjust_batch_dimensions(pos_embeds, batch_size) + + return torch.cat(tokens, dim=0), torch.cat(pos_embeds, dim=0) + + def _add_latent_token(self, tokens: list[Tensor], pos_embeds: list[Tensor], latent_sample: Tensor) -> None: + """Add latent token and its positional embedding.""" latent_token = self.encoder_latent_input_proj(latent_sample).unsqueeze(0) tokens.append(latent_token) - pos_embeds.append(self.encoder_1d_feature_pos_embed.weight[0].unsqueeze(0).unsqueeze(1)) - if self.config.robot_state_feature: - robot_state_token = self.encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(0) - tokens.append(robot_state_token) - pos_embeds.append(self.encoder_1d_feature_pos_embed.weight[1].unsqueeze(0).unsqueeze(1)) - if "task.input_ids" in batch: - task_embeds = self.task_embedding(batch["task.input_ids"]) - task_embeds = einops.rearrange(task_embeds, "b s d -> s b d") - tokens.append(task_embeds) - task_pos = create_sinusoidal_pos_embedding(task_embeds.size(0), self.config.dim_model).to( - device=task_embeds.device - ).unsqueeze(1) - pos_embeds.append(task_pos) - if self.config.env_state_feature: - tokens.append(self.encoder_env_state_input_proj(batch["observation.environment_state"]).unsqueeze(0)) - if self.config.image_features: - for img in batch["observation.images"]: - cam_features = self.backbone(img)["feature_map"] - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) - cam_features = self.encoder_img_feat_input_proj(cam_features) - cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") - cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") - tokens.append(cam_features) - pos_embeds.append(cam_pos_embed) - - # DEBUG: shows original shapes of the tokens - # print(f"[DEBUG] token shapes before adjustment: {[t.shape for t in tokens]}") - # print(f"[DEBUG] pos_embeds shapes before adjustment: {[t.shape for t in pos_embeds]}") - - # Helper function to adjust the batch dimension - def adjust_batch_dim(t: torch.Tensor, desired_bs: int) -> torch.Tensor: - if t.shape[1] == desired_bs: - return t - elif t.shape[1] == 1: - return t.expand(t.shape[0], desired_bs, t.shape[2]) - else: - raise ValueError(f"Incompatible batch size: expected {desired_bs}, but got {t.shape[1]} in tensor {t}") - - # Adjust the tokens and pos_embeds tensors to have the expected batch size - tokens = [adjust_batch_dim(t, batch_size) for t in tokens] - pos_embeds = [adjust_batch_dim(t, batch_size) for t in pos_embeds] - - # print(f"[DEBUG] adjusted token shapes: {[t.shape for t in tokens]}") - # print(f"[DEBUG] adjusted pos_embeds shapes: 
{[t.shape for t in pos_embeds]}") - - encoder_in_tokens = torch.cat(tokens, dim=0) - encoder_in_pos_embed = torch.cat(pos_embeds, dim=0) - return encoder_in_tokens, encoder_in_pos_embed + pos_embeds.append(self._get_1d_pos_embedding(0)) + + def _add_robot_state_token(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None: + """Add robot state token if configured.""" + if not self.config.robot_state_feature: + return + robot_token = self.encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(0) + tokens.append(robot_token) + pos_embeds.append(self._get_1d_pos_embedding(1)) + + def _add_task_tokens(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None: + """Add task tokens and their positional embeddings if present.""" + if "task.input_ids" not in batch: + return + task_embeds = self.task_embedding(batch["task.input_ids"]) + task_embeds = einops.rearrange(task_embeds, "b s d -> s b d") + tokens.append(task_embeds) + task_pos = self._create_task_pos_embedding(task_embeds) + pos_embeds.append(task_pos) + + def _add_env_state_token(self, tokens: list[Tensor], batch: dict[str, Tensor]) -> None: + """Add environment state token if configured.""" + if not self.config.env_state_feature: + return + env_token = self.encoder_env_state_input_proj(batch["observation.environment_state"]).unsqueeze(0) + tokens.append(env_token) + + def _add_image_tokens(self, tokens: list[Tensor], pos_embeds: list[Tensor], batch: dict[str, Tensor]) -> None: + """Add image tokens and positional embeddings if configured.""" + if not self.config.image_features: + return + for img in batch["observation.images"]: + features = self.backbone(img)["feature_map"] + pos_embed = self.encoder_cam_feat_pos_embed(features).to(dtype=features.dtype) + features = self._process_image_features(features) + pos_embed = self._process_image_pos_embed(pos_embed) + tokens.append(features) + pos_embeds.append(pos_embed) + + def _process_image_features(self, features: Tensor) -> Tensor: + """Project and reshape image features for encoder input.""" + features = self.encoder_img_feat_input_proj(features) + return einops.rearrange(features, "b c h w -> (h w) b c") + + def _process_image_pos_embed(self, pos_embed: Tensor) -> Tensor: + """Reshape positional embeddings for image features.""" + return einops.rearrange(pos_embed, "b c h w -> (h w) b c") + + def _adjust_batch_dimensions(self, tensors: list[Tensor], batch_size: int) -> list[Tensor]: + """Ensure all tensors have correct batch dimension.""" + return [self._adjust_tensor(t, batch_size) for t in tensors] + + def _adjust_tensor(self, tensor: Tensor, batch_size: int) -> Tensor: + """Adjust batch dimension of a single tensor.""" + if tensor.shape[1] == batch_size: + return tensor + if tensor.shape[1] == 1: + return tensor.expand(-1, batch_size, -1) + raise ValueError(f"Unexpected batch size {tensor.shape[1]}, expected {batch_size}") + + def _get_1d_pos_embedding(self, index: int) -> Tensor: + """Retrieve 1D positional embedding by index.""" + return self.encoder_1d_feature_pos_embed.weight[index].unsqueeze(0).unsqueeze(1) + + def _create_task_pos_embedding(self, task_embeds: Tensor) -> Tensor: + """Create sinusoidal positional embeddings for task tokens.""" + seq_len = task_embeds.size(0) + pos_embed = create_sinusoidal_pos_embedding(seq_len, self.config.dim_model) + return pos_embed.unsqueeze(1).to(device=task_embeds.device) def _prepare_decoder_input(self, batch_size: int, dtype, device) -> Tensor: """Prepare the 
@@ -491,20 +762,20 @@ def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
         self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

-    def forward(self, x: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(self, tensor: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
         """
         Pass the input through encoder layers and apply normalization.

         Args:
-            x: Input tensor of shape (seq_length, batch_size, embedding_dim).
+            tensor: Input tensor of shape (seq_length, batch_size, embedding_dim).
             pos_embed: Optional positional embeddings.
             key_padding_mask: Optional mask for padded tokens.
         Returns:
             Encoded tensor.
         """
         for layer in self.layers:
-            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
-        return self.norm(x)
+            tensor = layer(tensor, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
+        return self.norm(tensor)


 class ACTEncoderLayer(nn.Module):
@@ -527,33 +798,33 @@ def __init__(self, config: ACTConfig):
         self.activation = get_activation_fn(config.feedforward_activation)
         self.pre_norm = config.pre_norm

-    def _apply_self_attention(self, x: Tensor, pos_embed: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+    def _apply_self_attention(self, tensor: Tensor, pos_embed: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
         """
         Applies self-attention to the input tensor.

         Args:
-            x: Input tensor.
+            tensor: Input tensor.
             pos_embed: Optional positional embeddings.
             key_padding_mask: Optional mask for padded tokens.
         Returns:
             Tensor after self-attention and dropout.
         """
-        norm_x = self.norm1(x) if self.pre_norm else x
+        norm_tensor = self.norm1(tensor) if self.pre_norm else tensor
         if pos_embed is not None:
-            norm_x = norm_x + pos_embed
-        attn_output = self.self_attn(norm_x, norm_x, value=norm_x, key_padding_mask=key_padding_mask)[0]
-        return x + self.dropout1(attn_output)
+            norm_tensor = norm_tensor + pos_embed
+        attn_output = self.self_attn(norm_tensor, norm_tensor, value=norm_tensor, key_padding_mask=key_padding_mask)[0]
+        return tensor + self.dropout1(attn_output)

-    def _apply_feed_forward(self, x: Tensor) -> Tensor:
+    def _apply_feed_forward(self, tensor: Tensor) -> Tensor:
         """
         Applies the feed-forward network.

         Args:
-            x: Input tensor.
+            tensor: Input tensor.
         Returns:
             Tensor after feed-forward processing.
         """
-        ff_output = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        ff_output = self.linear2(self.dropout(self.activation(self.linear1(tensor))))
         return ff_output

     def forward(self, input_tensor: Tensor, pos_embed: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
@@ -590,7 +861,7 @@ def __init__(self, config: ACTConfig):

     def forward(
         self,
-        x: Tensor,
+        tensor: Tensor,
         encoder_out: Tensor,
         decoder_pos_embed: Optional[Tensor] = None,
         encoder_pos_embed: Optional[Tensor] = None,
@@ -599,7 +870,7 @@
         Pass the input through decoder layers and apply normalization.

         Args:
-            x: Input tensor of shape (decoder_seq_length, batch_size, embedding_dim).
+            tensor: Input tensor of shape (decoder_seq_length, batch_size, embedding_dim).
             encoder_out: Encoder output tensor.
             decoder_pos_embed: Optional decoder positional embeddings.
             encoder_pos_embed: Optional encoder positional embeddings.
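For readers less familiar with the block structure, the renamed encoder helpers above implement a residual attention sub-block with optional pre-normalization, where the positional embedding is added to the normalized activations before attention. A rough standalone sketch of that wiring with made-up sizes (this is not the policy's actual module):

import torch
from torch import nn

dim_model, n_heads, seq_len, batch = 512, 8, 10, 2
norm1 = nn.LayerNorm(dim_model)
self_attn = nn.MultiheadAttention(dim_model, n_heads)  # expects (seq, batch, dim) by default

x = torch.randn(seq_len, batch, dim_model)
pos_embed = torch.randn(seq_len, 1, dim_model)

h = norm1(x)                        # pre-norm branch
h = h + pos_embed                   # positional information injected before attention
attn_out = self_attn(h, h, value=h)[0]
x = x + attn_out                    # residual connection, as in _apply_self_attention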
""" for layer in self.layers: - x = layer(x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed) - return self.norm(x) if self.norm is not None else x + tensor = layer(tensor, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed) + return self.norm(tensor) if self.norm is not None else tensor class ACTDecoderLayer(nn.Module): @@ -634,51 +905,51 @@ def __init__(self, config: ACTConfig): self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def _apply_self_attention(self, x: Tensor, decoder_pos_embed: Optional[Tensor]) -> Tensor: + def _apply_self_attention(self, tensor: Tensor, decoder_pos_embed: Optional[Tensor]) -> Tensor: """ Applies self-attention for the decoder. Args: - x: Input tensor. + tensor: Input tensor. decoder_pos_embed: Optional positional embeddings. Returns: Tensor after self-attention. """ - norm_x = self.norm1(x) if self.pre_norm else x + norm_tensor = self.norm1(tensor) if self.pre_norm else tensor if decoder_pos_embed is not None: - norm_x = norm_x + decoder_pos_embed - attn_output = self.self_attn(norm_x, norm_x, value=norm_x)[0] - return x + self.dropout1(attn_output) + norm_tensor = norm_tensor + decoder_pos_embed + attn_output = self.self_attn(norm_tensor, norm_tensor, value=norm_tensor)[0] + return tensor + self.dropout1(attn_output) def _apply_cross_attention( - self, x: Tensor, encoder_output: Tensor, decoder_pos_embed: Optional[Tensor], encoder_pos_embed: Optional[Tensor] + self, tensor: Tensor, encoder_output: Tensor, decoder_pos_embed: Optional[Tensor], encoder_pos_embed: Optional[Tensor] ) -> Tensor: """ Applies cross-attention between decoder input and encoder output. Args: - x: Decoder input tensor. + tensor: Decoder input tensor. encoder_output: Encoder output tensor. decoder_pos_embed: Optional decoder positional embeddings. encoder_pos_embed: Optional encoder positional embeddings. Returns: Tensor after cross-attention. """ - query = x if decoder_pos_embed is None else x + decoder_pos_embed + query = tensor if decoder_pos_embed is None else tensor + decoder_pos_embed key = encoder_output if encoder_pos_embed is None else encoder_output + encoder_pos_embed cross_attn_out = self.multihead_attn(query=query, key=key, value=encoder_output)[0] - return x + self.dropout2(cross_attn_out) + return tensor + self.dropout2(cross_attn_out) - def _apply_feed_forward(self, x: Tensor) -> Tensor: + def _apply_feed_forward(self, tensor: Tensor) -> Tensor: """ Applies the feed-forward network. Args: - x: Input tensor. + tensor: Input tensor. Returns: Tensor after feed-forward processing. """ - ff_out = self.linear2(self.dropout(self.activation(self.linear1(x)))) + ff_out = self.linear2(self.dropout(self.activation(self.linear1(tensor)))) return ff_out def forward( @@ -749,20 +1020,20 @@ def __init__(self, dimension: int): self._eps = 1e-6 self._temperature = 10000 - def forward(self, x: Tensor) -> Tensor: + def forward(self, tensor: Tensor) -> Tensor: """ Args: - x: A (B, C, H, W) tensor. + tensor: A (B, C, H, W) tensor. Returns: Tensor of shape (1, C, H, W) with sinusoidal positional embeddings. 
""" - not_mask = torch.ones_like(x[0, :1]) # (1, H, W) + not_mask = torch.ones_like(tensor[0, :1]) # (1, H, W) y_range = not_mask.cumsum(1, dtype=torch.float32) x_range = not_mask.cumsum(2, dtype=torch.float32) y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi inverse_frequency = self._temperature ** ( - 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=tensor.device) // 2) / self.dimension ) x_range = x_range.unsqueeze(-1) / inverse_frequency y_range = y_range.unsqueeze(-1) / inverse_frequency