From 828526546a7cd6c659d401e0233b9606fe3dea51 Mon Sep 17 00:00:00 2001 From: mshukor Date: Mon, 9 Dec 2024 07:16:15 +0100 Subject: [PATCH] Refactoring and support PaliGemma (#557) Co-authored-by: mshukor Co-authored-by: mshukor Co-authored-by: mshukor Co-authored-by: mshukor --- .gitignore | 3 + lerobot/common/datasets/lerobot_dataset.py | 1 - lerobot/common/datasets/mock.py | 65 --- lerobot/common/datasets/utils.py | 1 - lerobot/common/policies/act/modeling_act.py | 3 +- .../policies/vla/configuration_qwen2_vl.py | 197 ------- .../common/policies/vla/configuration_vla.py | 61 +- lerobot/common/policies/vla/debug.py | 124 ---- .../common/policies/vla/modeling_language.py | 532 ------------------ .../common/policies/vla/modeling_vision.py | 246 -------- lerobot/common/policies/vla/modeling_vla.py | 444 ++++++--------- lerobot/configs/policy/vla.yaml | 44 +- lerobot/scripts/eval.py | 12 +- lerobot/scripts/train.py | 7 +- .../templates/visualize_dataset_template.html | 2 +- 15 files changed, 259 insertions(+), 1483 deletions(-) delete mode 100644 lerobot/common/datasets/mock.py delete mode 100644 lerobot/common/policies/vla/configuration_qwen2_vl.py delete mode 100644 lerobot/common/policies/vla/debug.py delete mode 100644 lerobot/common/policies/vla/modeling_language.py delete mode 100644 lerobot/common/policies/vla/modeling_vision.py diff --git a/.gitignore b/.gitignore index 0e203a394..67662163e 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,6 @@ dmypy.json # Cython debug symbols cython_debug/ + +# launcing examples with slurm +slurm/ diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index eb76f78d6..887177d1b 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -143,7 +143,6 @@ def __getitem__(self, idx): self.delta_timestamps, self.tolerance_s, ) - if self.video: item = load_from_videos( item, diff --git a/lerobot/common/datasets/mock.py b/lerobot/common/datasets/mock.py deleted file mode 100644 index 5a5595d23..000000000 --- a/lerobot/common/datasets/mock.py +++ /dev/null @@ -1,65 +0,0 @@ -import unittest -from pathlib import Path -from lerobot_dataset import MultiLeRobotDataset - -class TestMultiLeRobotDataset(unittest.TestCase): - def setUp(self): - # Define the datasets to use - self.dataset_repo_ids = [ - "lerobot/aloha_sim_insertion_human", - "lerobot/aloha_static_vinh_cup" - ] - self.dataset = MultiLeRobotDataset( - repo_ids=self.dataset_repo_ids, - # Replace with your local path or None for Hugging Face Hub - split="train", - image_transforms=None, # Pass your transforms if any - delta_timestamps=None, - ) - - def test_initialization(self): - # Check if datasets were initialized correctly - self.assertEqual(len(self.dataset.repo_ids), 2) - self.assertEqual(self.dataset.repo_ids, self.dataset_repo_ids) - - def test_num_samples(self): - # Check the total number of samples - self.assertGreater(len(self.dataset), 0) - - def test_num_episodes(self): - # Check the total number of episodes - self.assertGreater(self.dataset.num_episodes, 0) - - def test_fps(self): - # Check that FPS is correctly returned and is consistent - fps = self.dataset.fps - self.assertGreater(fps, 0) - - def test_video_property(self): - # Check if video loading is correctly handled - self.assertIsInstance(self.dataset.video, bool) - - def test_getitem(self): - # Test accessing a few samples to see if they are returned correctly - for i in range(5): - sample = self.dataset[i] - 
self.assertIsInstance(sample, dict) - self.assertIn("dataset_index", sample) # Check that dataset index is included - breakpoint() - - def test_camera_keys(self): - # Test that camera keys are returned correctly - camera_keys = self.dataset.camera_keys - self.assertIsInstance(camera_keys, list) - self.assertGreater(len(camera_keys), 0) - - def test_video_frame_keys(self): - # Test that video frame keys are returned correctly - video_frame_keys = self.dataset.video_frame_keys - if self.dataset.video: - self.assertIsInstance(video_frame_keys, list) - else: - self.assertEqual(len(video_frame_keys), 0) - -if __name__ == "__main__": - unittest.main() diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index d6aef15f5..8a831d546 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -139,7 +139,6 @@ def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datas else: safe_version = get_hf_dataset_safe_version(repo_id, version) hf_dataset = load_dataset(repo_id, revision=safe_version, split=split) - hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 3e5230f0e..4708d6aa8 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -22,7 +22,7 @@ import math from collections import deque from itertools import chain -from typing import Callable +from typing import Any, Callable import einops import numpy as np @@ -584,6 +584,7 @@ def forward( encoder_out: Tensor, decoder_pos_embed: Tensor | None = None, encoder_pos_embed: Tensor | None = None, + **kwargs: Any, ) -> Tensor: for layer in self.layers: x = layer( diff --git a/lerobot/common/policies/vla/configuration_qwen2_vl.py b/lerobot/common/policies/vla/configuration_qwen2_vl.py deleted file mode 100644 index 1b675234f..000000000 --- a/lerobot/common/policies/vla/configuration_qwen2_vl.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Qwen2VL model configuration""" - -import os -from typing import Union - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import ( - logging, # Using standard Python logging module instead of `transformers.utils.logging` -) - -logger = logging.get_logger(__name__) - - -def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: set | None = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get( - "rope_type", rope_scaling.get("type", None) - ) # BC: "rope_type" was originally "type" - required_keys = {"rope_type"} - received_keys = set(rope_scaling.keys()) - # _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - -# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. 
-ROPE_VALIDATION_FUNCTIONS = { - "default": _validate_default_rope_parameters, - # "linear": _validate_linear_scaling_rope_parameters, - # "dynamic": _validate_dynamic_scaling_rope_parameters, - # "yarn": _validate_yarn_parameters, - # "longrope": _validate_longrope_parameters, - # "llama3": _validate_llama3_parameters, -} - - -def rope_config_validation(config: PretrainedConfig, ignore_keys: set | None = None): - """ - Validate the RoPE config arguments, given a `PretrainedConfig` object - """ - rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` - if rope_scaling is None: - return - - # BC: "rope_type" was originally "type" - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) - if validation_fn is not None: - validation_fn(config, ignore_keys=ignore_keys) - else: - logger.warning( - f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" - ) - - -class Qwen2VLVisionConfig(PretrainedConfig): - model_type = "qwen2_vl" - - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - - if config_dict.get("model_type") == "qwen2_vl": - config_dict = config_dict["vision_config"] - - if ( - "model_type" in config_dict - and hasattr(cls, "model_type") - and config_dict["model_type"] != cls.model_type - ): - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) - - -class Qwen2VLConfig(PretrainedConfig): - r""" - A simplified version of the Qwen2VL model configuration class without the `transformers` dependencies. 
- """ - - model_type = "qwen2_vl" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=152064, - hidden_size=8192, - intermediate_size=29568, - num_hidden_layers=80, - num_decoder_layers=1, - num_attention_heads=64, - num_key_value_heads=8, - # dim_feedforward = 3200, - hidden_act="silu", - pad_token_id=0, - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - tie_word_embeddings=False, - rope_theta=1000000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=80, - attention_dropout=0.0, - vision_config=None, - rope_scaling={"type": "mrope", "mrope_section": [2, 2, 2]}, - pruned_heads=None, - **kwargs, - ): - # Initialize vision config - if isinstance(vision_config, dict): - self.vision_config = Qwen2VLVisionConfig(**vision_config) - elif vision_config is None: - self.vision_config = Qwen2VLVisionConfig() - - # Model hyperparameters - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - self.pad_token_id = pad_token_id - self.pruned_heads = pruned_heads or {} - self.rope_scaling = rope_scaling - self.num_decoder_layers = num_decoder_layers - - if self.rope_scaling is not None and "type" in self.rope_scaling: - if self.rope_scaling["type"] == "mrope": - self.rope_scaling["type"] = "default" - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self, ignore_keys={"mrope_section"}) - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - - # @classmethod - # def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - # # Custom loading logic from a pre-trained model or path - # logger.info(f"Loading pretrained config from {pretrained_model_name_or_path}...") - # # Add custom logic here to load a pretrained configuration - # return cls(**kwargs) diff --git a/lerobot/common/policies/vla/configuration_vla.py b/lerobot/common/policies/vla/configuration_vla.py index 92f7fcb22..0072f991e 100644 --- a/lerobot/common/policies/vla/configuration_vla.py +++ b/lerobot/common/policies/vla/configuration_vla.py @@ -116,16 +116,35 @@ class VLAConfig: ) prompt: str = "Please transfer the cube." - #"Please insert the tube into the socket." + # "Please insert the tube into the socket." # Architecture. 
+ # Action decoder + + action_decoder: dict = field( + default_factory=lambda: { + "name": "act", + "vlm_hidden_dim": 128, + "action_dim": 4, + "n_decoder_layers": 1, + "n_encoder_layers": 4, + "dim_model": 896, + "n_heads": 8, + "dim_feedforward": 3200, + "feedforward_activation": "relu", + "pre_norm": False, + "dropout": 0.1, + "temporal_ensemble_coeff": None, + } + ) + # Language + Main transformer vocab_size: int = 150528 hidden_size: int = 896 - num_decoder_layers: int = 1 - num_attention_heads: int = 8 - intermediate_size: int = 3200 + # n_decoder_layers: int = 1 + # n_heads: int = 8 + # dim_feedforward: int = 3200 hidden_act: str = "silu" pad_token_id: int = 0 initializer_range: float = 0.02 @@ -136,7 +155,7 @@ class VLAConfig: use_sliding_window: bool = False sliding_window = 4096 max_window_layers = 80 - attention_dropout = 0.0 + # dropout = 0.0 rope_scaling: dict = field( default_factory=lambda: { "type": "mrope", @@ -145,18 +164,6 @@ class VLAConfig: ) pruned_heads = None - # Vision encoder - # depth: int = 32 - # embed_dim: int = 1280 - # hidden_size: int = 3584 - # hidden_act: str = "quick_gelu" - # mlp_ratio: int = 4 - # num_heads: int = 16 - # in_channels: int = 3 - # patch_size: int = 14 - # spatial_merge_size: int = 2 - # temporal_patch_size: int = 2 - # attn_implementation: str = "eager" vision_config: dict = field( default_factory=lambda: { "depth": 32, @@ -173,6 +180,26 @@ class VLAConfig: } ) + # Vision-Language Model (Qwen-VL) + vlm_backbone: dict = field( + default_factory=lambda: { + "name": "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "feature_selection": "first_image", + } + ) + use_prompt_template: bool = True + num_img_tokens: int = 598 + + peft_method: str = "lora" + peft_config: dict = field( + default_factory=lambda: { + "r": 4, + "lora_alpha": 16, + "lora_dropout": 0.1, + "target_modules": ["q_proj", "v_proj"], + } + ) + def __post_init__(self): """Input validation (not exhaustive).""" # if not self.vision_backbone.startswith("resnet"): diff --git a/lerobot/common/policies/vla/debug.py b/lerobot/common/policies/vla/debug.py deleted file mode 100644 index 144fd9052..000000000 --- a/lerobot/common/policies/vla/debug.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -from configuration_vla import Qwen2VLConfig, Qwen2VLVisionConfig -from modeling_vla import Qwen2VLModel, VLA, VLAPolicy -from modeling_vision import Qwen2VisionTransformerPretrainedModel - -def test_vla_policy(): - # Define the model configuration - config = Qwen2VLConfig( - vocab_size=30522, # Token vocabulary size - hidden_size=768, # Hidden size for the model - num_hidden_layers=2, # Number of layers in the transformer - input_shapes={ - "observation.state": [128], # Observation state shape - }, - output_shapes={ - "action": [64], # Action output shape - }, - ) - - # Initialize the VLAPolicy - vla_policy = VLAPolicy(config) - - # Create a batch of random input data for testing - batch = { - "input_ids": torch.randint(0, config.vocab_size, (1, 10)), # Random tokenized input (batch_size=1, seq_len=10) - "attention_mask": torch.ones((1, 12), dtype=torch.long), # Attention mask - "observation.state": torch.randn(1, 128), # Random observation state (batch_size=1, state_dim=128) - "action": torch.randn(1, 64), # Random ground-truth action (batch_size=1, action_dim=64) - } - - # Perform a forward pass for training (to calculate loss) - output = vla_policy(batch) - print("Output during training (loss):", output) - - # Perform action selection (no loss calculated, just action prediction) - with 
torch.no_grad(): - predicted_action = vla_policy.select_action(batch) - print("Predicted Action:", predicted_action) - - - visual = Qwen2VisionTransformerPretrainedModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) - -# Run the test function -if __name__ == "__main__": - test_vla_policy() -''' -def test_model_forward_pass(): - # Define the model configuration - config = Qwen2VLConfig( - vocab_size=30522, - hidden_size=768, - num_hidden_layers=2, - input_shapes={ - "observation.state": [128], # Example observation state shape - }, - output_shapes={ - "action": [64], # Example action shape - }, - ) - # Initialize the VLA model - model = VLA(config) - - # Generate random input data - batch = { - "observation.state": torch.randn(1, 128), # Batch of size 1, observation state with 128 features - "action": torch.randn(1, 64), # Batch of size 1, action with 64 features - } - - input_ids = torch.randint(0, config.vocab_size, (1, 10)) # Random tokenized input, seq length = 10 - attention_mask = torch.ones((1, 12), dtype=torch.long) # Attention mask for the sequence and additional tokens - - # Perform forward pass with the model - with torch.no_grad(): # No gradient needed for testing - output = model( - batch=batch, - input_ids=input_ids, - attention_mask=attention_mask, - ) - - # Check if output has the correct shape - assert output.shape == (1, 64), f"Unexpected output shape: {output.shape}" - - print("VLA model forward pass successful with output shape:", output.shape) - -if __name__ == "__main__": - test_model_forward_pass() - -''' -''' - # Initialize the model - model = Qwen2VLModel(config) - - # Generate random input data - batch = { - "observation.state": torch.randn(1, 128), # Batch of size 1, observation state with 128 features - "action": torch.randn(1, 64), # Batch of size 1, action with 64 features - } - - input_ids = torch.randint(0, config.vocab_size, (1, 10)) # Random tokenized input, seq length = 10 - attention_mask = torch.ones((1, 12), dtype=torch.long) # Attention mask - - # Perform forward pass with the model - with torch.no_grad(): # No gradient needed for testing - output = model( - batch=batch, - input_ids=input_ids, - attention_mask=attention_mask, - return_dict = True, - output_hidden_states = True, - ) - breakpoint() - # Check if output has the correct shape - assert output.last_hidden_state.shape == (1, 12, config.hidden_size), \ - f"Unexpected output shape: {output.last_hidden_state.shape}" - - print("Model forward pass successful with output shape:", output.last_hidden_state.shape) - -if __name__ == "__main__": - test_model_forward_pass() -''' - - diff --git a/lerobot/common/policies/vla/modeling_language.py b/lerobot/common/policies/vla/modeling_language.py deleted file mode 100644 index 23c3f890d..000000000 --- a/lerobot/common/policies/vla/modeling_language.py +++ /dev/null @@ -1,532 +0,0 @@ -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.utils.checkpoint -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache -from transformers.modeling_outputs import ModelOutput -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.utils import logging - -from lerobot.common.policies.vla.configuration_vla import VLAConfig - -logger = logging.get_logger(__name__) - - -@dataclass -class Qwen2VLCausalLMOutputWithPast(ModelOutput): - """ - Base class for Qwen2VL causal language model 
(or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - - -class Qwen2VLRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[VLAConfig] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer( - "inv_freq", inv_freq, persistent=False - ) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if ( - seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len - ): # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids - # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
- - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. 
- dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2VLAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: VLAConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - self.rope_scaling = config.rope_scaling - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2VLRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - } # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == torch.float16: - attn_weights = torch.where( - torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights - ) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -QWEN2_VL_ATTENTION_CLASSES = { - "eager": Qwen2VLAttention, -} - - -class Qwen2VLDecoderLayer(nn.Module): - # TODO(rcadene, dana): update config type VLAConfig - def __init__(self, config: VLAConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - config._attn_implementation = "eager" - self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # will become mandatory in v4.46 - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/lerobot/common/policies/vla/modeling_vision.py b/lerobot/common/policies/vla/modeling_vision.py deleted file mode 100644 index 675e136c9..000000000 --- a/lerobot/common/policies/vla/modeling_vision.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.nn import LayerNorm -from transformers.activations import ACT2FN -from transformers.modeling_utils import PreTrainedModel - -from lerobot.common.policies.vla.configuration_vla import VLAConfig - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -class Qwen2VLPreTrainedModel(PreTrainedModel): - config_class = VLAConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - 
self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - return hidden_states - - -class PatchMerger(nn.Module): - def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: - super().__init__() - self.hidden_size = context_dim * (spatial_merge_size**2) - self.ln_q = LayerNorm(context_dim, eps=1e-6) - self.mlp = nn.Sequential( - nn.Linear(self.hidden_size, self.hidden_size), - nn.GELU(), - nn.Linear(self.hidden_size, dim), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) - return x - - -class VisionMlp(nn.Module): - def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.act = ACT2FN[hidden_act] - self.fc2 = nn.Linear(hidden_dim, dim) - - def forward(self, x) -> torch.Tensor: - return self.fc2(self.act(self.fc1(x))) - - -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = ( - self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - ) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -QWEN2_VL_VISION_ATTENTION_CLASSES = {"eager": VisionAttention} - - -class Qwen2VLVisionBlock(nn.Module): - def __init__(self, config, attn_implementation: str = "sdpa") -> None: - super().__init__() - self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) - self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) - mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) - - self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.embed_dim, num_heads=config.num_heads - ) - self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) - - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: - hidden_states = hidden_states + 
self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb - ) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states - - -class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): - config_class = VLAConfig - _no_split_modules = ["Qwen2VLVisionBlock"] - - def __init__(self, config) -> None: - super().__init__(config) - self.spatial_merge_size = config.spatial_merge_size - - self.patch_embed = PatchEmbed( - patch_size=config.patch_size, - temporal_patch_size=config.temporal_patch_size, - in_channels=config.in_channels, - embed_dim=config.embed_dim, - ) - - head_dim = config.embed_dim // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) - - self.blocks = nn.ModuleList( - [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] - ) - self.merger = PatchMerger( - dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size - ) - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32 - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - for blk in self.blocks: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) - - return self.merger(hidden_states) diff --git a/lerobot/common/policies/vla/modeling_vla.py b/lerobot/common/policies/vla/modeling_vla.py index 45d615c6f..93831edfe 100644 --- a/lerobot/common/policies/vla/modeling_vla.py +++ b/lerobot/common/policies/vla/modeling_vla.py @@ -1,16 +1,18 @@ -import inspect from collections import deque -from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from huggingface_hub import PyTorchModelHubMixin +from omegaconf import OmegaConf +from peft import LoraConfig, TaskType, get_peft_model from torch import Tensor, nn +from torch.profiler import record_function +from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration +from lerobot.common.policies.act.modeling_act import ACTDecoder from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.vla.configuration_vla import VLAConfig 
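For orientation, a minimal sketch (not part of this patch) of the VLM-backbone-plus-LoRA pattern these new imports support; the checkpoint id and the final `get_peft_model` wrapping are assumptions, with LoRA values mirroring the defaults in configuration_vla.py:

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Load a frozen VLM backbone in half precision, as the refactored VLA does below.
vlm = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-pt-224",  # assumed PaliGemma checkpoint
    device_map="auto",
    torch_dtype=torch.float16,
)
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")

# LoRA adapter configuration matching the `peft_config` defaults in this PR.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=4,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
vlm = get_peft_model(vlm, lora_config)  # base weights frozen; only LoRA adapters remain trainable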
-from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration -from peft import get_peft_model, LoraConfig, TaskType + class VLAPolicy( nn.Module, @@ -51,22 +53,8 @@ def __init__( self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) - # Configure LoRA settings - ''' - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) - r=8, # The rank of the low-rank adaptation - lora_alpha=32, # Scaling factor - lora_dropout=0.1, # Dropout applied to LoRA layers - target_modules=["q_proj", "v_proj"] # The attention components where LoRA is applied - ) - ''' - #self.lora_config = lora_config - - #self.language_model = get_peft_model(self.language_model, lora_config) - #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = VLA(config) - self.processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")# Updated Qwen2VL without loss and lm_head + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = VLA(config, device=self.device) self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.reset() @@ -78,299 +66,219 @@ def reset(self): @torch.no_grad() def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: self.eval() - - from torch.profiler import profile, record_function, ProfilerActivity with record_function("normalize_inputs"): batch = self.normalize_inputs(batch) - + if len(self.expected_image_keys) > 0: - batch = dict(batch) - #batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4).to(self.device + batch = dict(batch) batch["observation.images"] = [img for k in self.expected_image_keys for img in batch[k]] - - #torch.cat([batch[k] for k in self.expected_image_keys], dim=0).to(self.device) - batch["prompt"] = self.config.prompt - batch_size = len(batch["observation.images"]) - #batch = self.normalize_targets(batch) - with record_function("processor"): - processed_inputs = self.processor( - text=[batch["prompt"]]*batch_size, images=list(batch["observation.images"]), - return_tensors="pt", padding=True, do_rescale=False, - #image_mean = [0.485, 0.456, 0.406], - #image_std = [0.229, 0.224, 0.225] - ) - - with record_function("processed_inputs to cuda"): - for k,v in processed_inputs.items(): - processed_inputs[k] = processed_inputs[k].to(device=batch["observation.state"].device) # Forward pass through VLA with record_function("model"): - predicted_actions = self.model(processed_inputs, batch_size) - + predicted_actions = self.model(batch) + with record_function("unnormalize_outputs"): if len(self._action_queue) == 0: actions = self.unnormalize_outputs({"action": predicted_actions})["action"] self._action_queue.extend(actions.transpose(0, 1)) - + return self._action_queue.popleft() def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - batch = self.normalize_inputs(batch) - if len(self.expected_image_keys) > 0: - batch = dict(batch) - #batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4).to(self.device + batch = dict(batch) batch["observation.images"] = [img for k in self.expected_image_keys for img in batch[k]] - - #torch.cat([batch[k] for k in self.expected_image_keys], dim=0).to(self.device) - batch["prompt"] = self.config.prompt - batch_size = len(batch["observation.images"]) - batch = 
self.normalize_targets(batch) - processed_inputs = self.processor( - text=[batch["prompt"]]*batch_size, images=list(batch["observation.images"]), - return_tensors="pt", padding=True, do_rescale=False, - #mage_mean = [0.485, 0.456, 0.406], - #image_std = [0.229, 0.224, 0.225] - ) - for k,v in processed_inputs.items(): - processed_inputs[k] = processed_inputs[k].to(device=batch["observation.state"].device) - ''' - # Pass inputs through Llava and VLA - llava_output = self.language_model( - **processed_inputs, - return_dict=True, - output_hidden_states=True - ) - breakpoint() - - last_hidden_state = llava_output.hidden_states[-1].to(dtype=torch.float16).to(self.device) - num_img_feats = 298 - seq_len = llava_output.image_hidden_states.shape[0] // batch_size - - - image_features = llava_output.image_hidden_states.view(batch_size, seq_len, -1) - image_hidden_states = image_features[:,:num_img_feats, :].to(dtype=torch.float16).to(self.device) - final_features = torch.cat((image_hidden_states, last_hidden_state), dim=1).to(dtype=torch.float16).to(self.device) - - #hidden_states = hidden_states[:, -4:, :] - #hidden_states.to(dtype=torch.float16).to(self.device) - breakpoint() - # Forward pass through VLA - ''' - predicted_actions = self.model(processed_inputs, batch_size) - + predicted_actions = self.model(batch) + loss_dict = {} if "action" in batch: true_actions = batch["action"] - l1_loss = (F.l1_loss(predicted_actions, true_actions, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)).mean() + l1_loss = ( + F.l1_loss(predicted_actions, true_actions, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() loss_dict["l1_loss"] = l1_loss.item() loss_dict["loss"] = l1_loss return loss_dict -class ActionDecoderLayer(nn.Module): - def __init__(self, config: VLAConfig): - super().__init__() - self.self_attn = nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, dropout=config.attention_dropout - ) - self.cross_attn = nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, dropout=config.attention_dropout - ) - - # Feed forward layers. - self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.dropout = nn.Dropout(config.attention_dropout) - self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size) - - self.norm1 = nn.LayerNorm(config.hidden_size) - self.norm2 = nn.LayerNorm(config.hidden_size) - self.norm3 = nn.LayerNorm(config.hidden_size) - self.dropout1 = nn.Dropout(config.attention_dropout) - self.dropout2 = nn.Dropout(config.attention_dropout) - self.dropout3 = nn.Dropout(config.attention_dropout) - - self.activation = nn.ReLU() - self.pre_norm = False # Assumed pre-norm architecture; can adjust based on config - - def maybe_add_pos_embed(self, tensor: torch.Tensor, pos_embed: torch.Tensor | None) -> torch.Tensor: - return tensor if pos_embed is None else tensor + pos_embed - - def forward( - self, - x: torch.Tensor, - encoder_out: torch.Tensor, - decoder_pos_embed: torch.Tensor | None = None, - encoder_pos_embed: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - x: (Decoder Sequence, Batch, Hidden Size) tensor of input tokens. - encoder_out: (Encoder Sequence, Batch, Hidden Size) output features from the last layer of the encoder we are cross-attending with. - decoder_pos_embed: (Sequence, 1, Hidden Size) positional embedding for decoder queries. - encoder_pos_embed: (Sequence, 1, Hidden Size) positional embedding for encoder keys. 
- Returns: - (Sequence, Batch, Hidden Size) tensor of decoder output features. - """ - skip = x - if self.pre_norm: - x = self.norm1(x) - q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - - # Self-attention - x = self.self_attn(q, k, value=x)[0] # select just the output, not attention weights - x = skip + self.dropout1(x) - - if self.pre_norm: - skip = x - x = self.norm2(x) - else: - x = self.norm1(x) - skip = x - - # Cross-attention with encoder outputs - x = self.cross_attn( - query=self.maybe_add_pos_embed(x, decoder_pos_embed), - key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), - value=encoder_out, - )[0] # select just the output, not attention weights - x = skip + self.dropout2(x) - - if self.pre_norm: - skip = x - x = self.norm3(x) - else: - x = self.norm2(x) - skip = x - - # Feed-forward network - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = skip + self.dropout3(x) - - if not self.pre_norm: - x = self.norm3(x) - - return x - - -class ActionDecoder(nn.Module): - def __init__(self, config: VLAConfig): - """Runs multiple decoder layers followed by normalization.""" - super().__init__() - self.layers = nn.ModuleList([ActionDecoderLayer(config) for _ in range(config.num_decoder_layers)]) - self.norm = nn.LayerNorm(config.hidden_size) - - def forward( - self, - x: torch.Tensor, - encoder_out: torch.Tensor, - decoder_pos_embed: torch.Tensor | None = None, - encoder_pos_embed: torch.Tensor | None = None, - ) -> torch.Tensor: - for layer in self.layers: - x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed - ) - x = self.norm(x) - return x class VLA(nn.Module): - def __init__(self, config: VLAConfig): + def __init__(self, config: VLAConfig, device: torch.device = "cpu"): super().__init__() - - # Initialize the Qwen2VLForConditionalGeneration and ActionDecoder - #qwen2_vl_config = make_qwen2_vl_config(config) self.chunk_size = config.chunk_size - self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.hidden_size) - self.action_decoder = ActionDecoder(config) # Use the updated ActionDecoder - self.action_head = nn.Linear(config.hidden_size, config.output_shapes["action"][0]) - self.vision_language_model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - #torch_dtype=torch.float16, - device_map = 'cuda') - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) 
- r=4, # The rank of the low-rank adaptation - lora_alpha=16, # Scaling factor - lora_dropout=0.1, # Dropout applied to LoRA layers - target_modules=["q_proj", "v_proj"] # The attention components where LoRA is applied - ) - self.lora_config = lora_config - for param in self.vision_language_model.parameters(): - param.requires_grad = False + self.action_decoder_name = config.action_decoder.get("name", "act") + if self.action_decoder_name == "act": + action_decoder_config = OmegaConf.create(config.action_decoder) + self.action_decoder = ACTDecoder(action_decoder_config) + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"]) + else: + raise NotImplementedError(f"{self.action_decoder_name} not supported.") + + self.action_head = nn.Linear(config.action_decoder["dim_model"], config.output_shapes["action"][0]) + self.vlm_backbone_name = config.vlm_backbone["name"] + self.vlm_backbone_feature_selection = config.vlm_backbone.get("feature_selection", "last_token") + if "llava-onevision" in self.vlm_backbone_name: + self.vision_language_model = LlavaOnevisionForConditionalGeneration.from_pretrained( + self.vlm_backbone_name, + device_map="auto", + torch_dtype=torch.float16, + # attn_implementation="flash_attention_2" + ) + self.processor = AutoProcessor.from_pretrained(self.vlm_backbone_name) + elif "paligemma" in self.vlm_backbone_name: + from transformers import PaliGemmaForConditionalGeneration + + self.vision_language_model = PaliGemmaForConditionalGeneration.from_pretrained( + self.vlm_backbone_name, + device_map="auto", + torch_dtype=torch.float16, + # attn_implementation="flash_attention_2" + ) + self.processor = AutoProcessor.from_pretrained(self.vlm_backbone_name) + else: + raise NotImplementedError(f"{self.vlm_backbone_name} not supported.") + self.use_prompt_template = config.use_prompt_template + self.num_img_tokens = config.num_img_tokens # e.g. 598 to match the number of hidden states in ACT + + self.peft_method = config.peft_method + if "lora" in self.peft_method: + peft_config = config.peft_config + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) 
+ r=peft_config["r"], # The rank of the low-rank adaptation + lora_alpha=peft_config["lora_alpha"], # Scaling factor + lora_dropout=peft_config["lora_dropout"], # Dropout applied to LoRA layers + target_modules=peft_config["target_modules"], # The components where LoRA is applied + ) + self.lora_config = lora_config + # Apply LoRA and ensure only LoRA parameters are trainable + self.vision_language_model = get_peft_model(self.vision_language_model, lora_config) + for name, param in self.vision_language_model.named_parameters(): + if ( + "lm_head" in name or "lora" in name + ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer + param.requires_grad = True + else: + param.requires_grad = False - # Apply LoRA and ensure only LoRA parameters are trainable - self.vision_language_model = get_peft_model(self.vision_language_model, lora_config) - for name, param in self.vision_language_model.named_parameters(): - if "lm_head" in name: # Adjust "lm_head" to the specific name of the head layer - param.requires_grad = True # Verify trainable parameters trainable_params = [] for name, param in self.vision_language_model.named_parameters(): if param.requires_grad: trainable_params.append(name) - print(f"Trainable parameter: {name}") + print(f"VLM trainable parameter: {name}") + + def apply_prompt_template(self, text: str, add_generation_prompt: bool = True) -> str: + if "llava-onevision" in self.vlm_backbone_name: + if self.use_prompt_template: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + {"type": "image"}, + ], + }, + ] + prompt = self.processor.apply_chat_template( + conversation, add_generation_prompt=add_generation_prompt + ) + else: + prompt = f"{text}" + elif "paligemma" in self.vlm_backbone_name: + prompt = f"{text}" + else: + prompt = text + + return prompt + + def get_vlm_features(self, processed_inputs) -> torch.Tensor: + vlm_output = self.vision_language_model( + **processed_inputs, return_dict=True, output_hidden_states=True + ) + + if any(k in self.vlm_backbone_name for k in ["llava-onevision", "paligemma"]): + batch_size = processed_inputs["input_ids"].shape[0] + last_hidden_state = vlm_output.hidden_states[-1] + seq_len = vlm_output.image_hidden_states.shape[0] // batch_size + image_features = vlm_output.image_hidden_states.view(batch_size, seq_len, -1) + + if self.vlm_backbone_feature_selection == "first_image": + hidden_states = image_features[:, : self.num_img_tokens, :] + elif self.vlm_backbone_feature_selection == "last_token": + hidden_states = last_hidden_state[:, -1:, :] + elif self.vlm_backbone_feature_selection == "all_generated": + hidden_states = last_hidden_state + elif self.vlm_backbone_feature_selection == "all": + hidden_states = torch.cat((image_features, last_hidden_state), dim=1) + elif self.vlm_backbone_feature_selection == "first_image_all": + hidden_states = torch.cat( + (image_features[:, : self.num_img_tokens, :], last_hidden_state), dim=1 + ) + else: + raise NotImplementedError(" not supportedd") + else: + raise NotImplementedError(f"{self.vlm_backbone_name} not implemented.") + + return hidden_states + + def get_action_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_states.shape + hidden_states = hidden_states.transpose(0, 1) + if self.action_decoder_name == "act": + # Generate positional embeddings for the decoder + decoder_pos_embeddings = self.decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) + # Decode the action with 
positional embeddings and encoder output + x = torch.zeros( + (self.chunk_size, batch_size, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + action_logits = self.action_decoder( + x=x, encoder_out=hidden_states, decoder_pos_embed=decoder_pos_embeddings + ) + # Final action logits through the action head + action_logits = self.action_head(action_logits) + action_logits = action_logits.transpose(0, 1) + else: + raise NotImplementedError(f"{self.action_decoder_name} not supported.") - #self.device = self.vision_language_model.device - #self.half() + return action_logits - def forward(self, processed_inputs, batch_size): + def forward(self, batch): """ - # Forward pass to compute action logits using hidden states from Qwen2VL (Llava). - + Forward pass to compute action logits. + Args: - hidden_states: Tensor of shape [batch_size, seq_len, hidden_size] from Llava model. - + batch: model input. + Returns: action_logits: Tensor of predicted actions. """ - llava_output = self.vision_language_model( - **processed_inputs, - return_dict=True, - output_hidden_states=True - ) - - - last_hidden_state = llava_output.hidden_states[-1]#.to(dtype=torch.float16).to(self.device) - num_img_feats = 598 - seq_len = llava_output.image_hidden_states.shape[0] // batch_size - - - image_features = llava_output.image_hidden_states.view(batch_size, seq_len, -1) - image_hidden_states = image_features[:,:num_img_feats, :]#.to(dtype=torch.float16).to(self.device) - hidden_states = torch.cat((image_hidden_states, last_hidden_state), dim=1)#.to(dtype=torch.float16).to(self.device) - - batch_size = hidden_states.size(0) # Ensure batch size is extracted - seq_len = hidden_states.size(1) # Sequence length of hidden states - hidden_size = hidden_states.size(2) # Hidden size - - # Ensure encoder_out has the correct shape [chunk_size, batch_size, seq_len, hidden_size] - # Repeat the encoder output for chunk size across the batch dimension - #encoder_out = hidden_states.unsqueeze(0).repeat(self.chunk_size, 1, 1, 1) # [chunk_size, batch_size, seq_len, hidden_size] - #encoder_out = encoder_out.view(self.chunk_size * seq_len, batch_size, hidden_size) - # - # Repeat the decoder input (hidden states) as well, maintaining batch and hidden size - #repeated_hidden_states = hidden_states.unsqueeze(0).repeat(self.chunk_size//seq_len, 1, 1, 1) # [chunk_size, batch_size, seq_len, hidden_size] - - #repeated_hidden_states = repeated_hidden_states.view(self.chunk_size, batch_size, hidden_size) - hidden_states = hidden_states.transpose(0,1) - # Generate positional embeddings for the decoder - decoder_pos_embeddings = self.decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) - - # Decode the action with positional embeddings and encoder output - x = torch.zeros((self.chunk_size, batch_size, hidden_size), dtype = hidden_states.dtype, device=hidden_states.device) - action_logits = self.action_decoder(x=x, - encoder_out=hidden_states, - decoder_pos_embed = decoder_pos_embeddings - ) - - # Final action logits through the action head - action_logits = self.action_head(action_logits) + prompt = self.apply_prompt_template(batch["prompt"], add_generation_prompt=True) + + batch_size = len(batch["observation.images"]) + with record_function("processor"): + processed_inputs = self.processor( + text=[prompt] * batch_size, + images=list(batch["observation.images"]), + return_tensors="pt", + padding=True, + do_rescale=False, + ) + + with record_function("processed_inputs to cuda"): + for k in processed_inputs.keys(): + 
processed_inputs[k] = processed_inputs[k].to(device=batch["observation.state"].device) + + hidden_states = self.get_vlm_features(processed_inputs) + + action_logits = self.get_action_logits(hidden_states) - action_logits = action_logits.transpose(0, 1) return action_logits diff --git a/lerobot/configs/policy/vla.yaml b/lerobot/configs/policy/vla.yaml index 331bd5a25..638add4cf 100644 --- a/lerobot/configs/policy/vla.yaml +++ b/lerobot/configs/policy/vla.yaml @@ -14,7 +14,7 @@ training: online_steps: 0 eval_freq: 20000 save_freq: 2000 - log_freq: 30 + log_freq: 100 save_checkpoint: true batch_size: 8 @@ -39,7 +39,7 @@ policy: n_obs_steps: 1 chunk_size: 100 n_action_steps: 100 - + input_shapes: observation.images.top: [3, 480, 640] # Video inputs (from video frames) observation.state: [14] # State input dimension @@ -63,29 +63,37 @@ policy: prompt: "Please transfer the cube" #"Please insert the tube into the socket." # Vision-Language Model (Qwen-VL) - vlm_backbone: qwen_vl - pretrained_backbone_weights: Qwen/Qwen-VL-2 # Pretrained weights for Qwen-VL + vlm_backbone: + name: llava-hf/llava-onevision-qwen2-0.5b-ov-hf + feature_selection: all_generated + use_prompt_template: true + num_img_tokens: 598 replace_final_stride_with_dilation: false - + # Combining state, video, and text combined_dim: 128 # Dimension after combining state and VLM outputs # Action decoder action_decoder: + name: act vlm_hidden_dim: 128 # Input dim for action decoder (combined dimension) action_dim: 4 # Output dimension (action space) + n_decoder_layers: 1 + n_encoder_layers: 4 + dim_model: 896 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + pre_norm: false + # Training and loss computation + dropout: 0.1 + temporal_ensemble_coeff: null - # Transformer layers for action decoding - dim_model: 512 - n_heads: 8 - dim_feedforward: 3200 - feedforward_activation: relu - n_encoder_layers: 4 - n_decoder_layers: 1 - latent_dim: 32 - - # Inference - temporal_ensemble_coeff: null - # Training and loss computation - dropout: 0.1 + # Lora config + peft_method: lora + peft_config: + r: 4 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "v_proj"] diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index dcdd066eb..ce9a010d6 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -60,6 +60,7 @@ from huggingface_hub.errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError from torch import Tensor, nn +from torch.profiler import ProfilerActivity, profile, record_function from tqdm import trange from lerobot.common.datasets.factory import make_dataset @@ -77,7 +78,6 @@ inside_slurm, set_global_seed, ) -from torch.profiler import profile, record_function, ProfilerActivity def rollout( @@ -138,7 +138,7 @@ def rollout( # Keep track of which environments are done. 
done = np.array([False] * env.num_envs) max_steps = env.call("_max_episode_steps")[0] - + def trace_handler(prof): prof.export_chrome_trace(f"outputs/trace_schedule_{prof.step_num}.json") @@ -149,7 +149,7 @@ def trace_handler(prof): warmup=1, active=3, ), - on_trace_ready=trace_handler + on_trace_ready=trace_handler, ) as prof: progbar = trange( max_steps, @@ -164,7 +164,6 @@ def trace_handler(prof): all_observations.append(deepcopy(observation)) observation = {key: observation[key].to(device, non_blocking=True) for key in observation} - with torch.inference_mode(): with record_function("select_action"): action = policy.select_action(observation) @@ -201,7 +200,7 @@ def trace_handler(prof): progbar.update() prof.step() - if step==5: + if step == 5: break # Track the final observation. @@ -287,7 +286,6 @@ def render_frame(env: gym.vector.VectorEnv): # we dont want progress bar when we use slurm, since it clutters the logs - progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) for batch_ix in progbar: # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout @@ -301,7 +299,7 @@ def render_frame(env: gym.vector.VectorEnv): seeds = range( start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) ) - + rollout_data = rollout( env, policy, diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2853093a7..7c7a4943a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -101,11 +101,8 @@ def make_optimizer_and_scheduler(cfg, policy): elif cfg.policy.name == "vla": # Optimizer for VLAPolicy with LoRA optimizer_params = [ - { - "params": [ - p for n, p in policy.model.named_parameters() if p.requires_grad - ], + "params": [p for n, p in policy.model.named_parameters() if p.requires_grad], "lr": cfg.training.lr, # Different LR for action decoder (if required) } ] @@ -148,7 +145,6 @@ def update_policy( error_if_nonfinite=False, ) - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, # although it still skips optimizer.step() if the gradients contain infs or NaNs. with lock if lock is not None else nullcontext(): @@ -671,6 +667,7 @@ def train_cli(cfg: dict): job_name=hydra.core.hydra_config.HydraConfig.get().job.name, ) + def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): from hydra import compose, initialize diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 4f0bd343e..658d6ba6c 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -250,7 +250,7 @@

if(!canPlayVideos){ this.videoCodecError = true; } - + // process CSV data this.videos = document.querySelectorAll('video'); this.video = this.videos[0];