Refactoring and support for PaliGemma (#557)
Co-authored-by: mshukor <[email protected]>
5 people authored Dec 9, 2024
1 parent 4fe2f8e commit 8285265
Showing 15 changed files with 259 additions and 1,483 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -153,3 +153,6 @@ dmypy.json
 
 # Cython debug symbols
 cython_debug/
+
+# launching examples with slurm
+slurm/
1 change: 0 additions & 1 deletion lerobot/common/datasets/lerobot_dataset.py
@@ -143,7 +143,6 @@ def __getitem__(self, idx):
             self.delta_timestamps,
             self.tolerance_s,
         )
-
         if self.video:
             item = load_from_videos(
                 item,
65 changes: 0 additions & 65 deletions lerobot/common/datasets/mock.py

This file was deleted.

1 change: 0 additions & 1 deletion lerobot/common/datasets/utils.py
@@ -139,7 +139,6 @@ def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datas
     else:
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
-
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
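
The hunk above touches code whose key idea is revision pinning: resolve a compatible version, then pass it to `load_dataset` so later pushes to the Hub cannot silently change the data. A minimal sketch of that pattern, assuming only the Hugging Face `datasets` library (the repo id and version tag are illustrative, and the repo's `get_hf_dataset_safe_version` helper is omitted):

    from datasets import load_dataset

    def load_pinned(repo_id: str, version: str = "v1.6"):
        # Pin load_dataset to an explicit revision (a tag or commit on the
        # Hub) so the training data stays reproducible.
        return load_dataset(repo_id, revision=version, split="train")

    # e.g. load_pinned("lerobot/aloha_sim_insertion_human")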
3 changes: 2 additions & 1 deletion lerobot/common/policies/act/modeling_act.py
@@ -22,7 +22,7 @@
 import math
 from collections import deque
 from itertools import chain
-from typing import Callable
+from typing import Any, Callable
 
 import einops
 import numpy as np
@@ -584,6 +584,7 @@ def forward(
         encoder_out: Tensor,
         decoder_pos_embed: Tensor | None = None,
         encoder_pos_embed: Tensor | None = None,
+        **kwargs: Any,
     ) -> Tensor:
         for layer in self.layers:
             x = layer(
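
The `**kwargs: Any` addition lets this decoder share a calling convention with decoders that take extra arguments. A minimal sketch of the pattern, with illustrative names rather than the repo's actual classes:

    from typing import Any

    from torch import Tensor, nn

    class TolerantDecoder(nn.Module):
        def __init__(self, dim: int, n_layers: int):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

        def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
            # Extra keyword arguments (e.g. positional embeddings that another
            # decoder variant needs) are accepted and ignored, so one caller
            # can drive several decoders without special-casing this one.
            for layer in self.layers:
                x = layer(x)
            return x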
197 changes: 0 additions & 197 deletions lerobot/common/policies/vla/configuration_qwen2_vl.py

This file was deleted.

61 changes: 44 additions & 17 deletions lerobot/common/policies/vla/configuration_vla.py
@@ -116,16 +116,35 @@ class VLAConfig:
     )
 
     prompt: str = "Please transfer the cube."
-    #"Please insert the tube into the socket."
+    # "Please insert the tube into the socket."
 
     # Architecture.
 
+    # Action decoder
+
+    action_decoder: dict = field(
+        default_factory=lambda: {
+            "name": "act",
+            "vlm_hidden_dim": 128,
+            "action_dim": 4,
+            "n_decoder_layers": 1,
+            "n_encoder_layers": 4,
+            "dim_model": 896,
+            "n_heads": 8,
+            "dim_feedforward": 3200,
+            "feedforward_activation": "relu",
+            "pre_norm": False,
+            "dropout": 0.1,
+            "temporal_ensemble_coeff": None,
+        }
+    )
+
     # Language + Main transformer
     vocab_size: int = 150528
     hidden_size: int = 896
-    # n_decoder_layers: int = 1
-    # n_heads: int = 8
-    # dim_feedforward: int = 3200
+    num_decoder_layers: int = 1
+    num_attention_heads: int = 8
+    intermediate_size: int = 3200
     hidden_act: str = "silu"
     pad_token_id: int = 0
     initializer_range: float = 0.02
@@ -136,7 +155,7 @@ class VLAConfig:
     use_sliding_window: bool = False
     sliding_window = 4096
     max_window_layers = 80
-    attention_dropout = 0.0
+    # dropout = 0.0
     rope_scaling: dict = field(
         default_factory=lambda: {
             "type": "mrope",
@@ -145,18 +164,6 @@ class VLAConfig:
     )
     pruned_heads = None
 
-    # Vision encoder
-    # depth: int = 32
-    # embed_dim: int = 1280
-    # hidden_size: int = 3584
-    # hidden_act: str = "quick_gelu"
-    # mlp_ratio: int = 4
-    # num_heads: int = 16
-    # in_channels: int = 3
-    # patch_size: int = 14
-    # spatial_merge_size: int = 2
-    # temporal_patch_size: int = 2
-    # attn_implementation: str = "eager"
     vision_config: dict = field(
         default_factory=lambda: {
             "depth": 32,
@@ -173,6 +180,26 @@ class VLAConfig:
         }
     )
 
+    # Vision-Language Model (Qwen-VL)
+    vlm_backbone: dict = field(
+        default_factory=lambda: {
+            "name": "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+            "feature_selection": "first_image",
+        }
+    )
+    use_prompt_template: bool = True
+    num_img_tokens: int = 598
+
+    peft_method: str = "lora"
+    peft_config: dict = field(
+        default_factory=lambda: {
+            "r": 4,
+            "lora_alpha": 16,
+            "lora_dropout": 0.1,
+            "target_modules": ["q_proj", "v_proj"],
+        }
+    )
+
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         # if not self.vision_backbone.startswith("resnet"):
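
The keys in the new `peft_config` match the arguments of `LoraConfig` from the Hugging Face `peft` library. A minimal sketch of how such a dict could be applied to the backbone named in `vlm_backbone` above — a guess at the wiring, not the repo's actual code, and the loader class is an assumption for this checkpoint:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForVision2Seq

    peft_config = {
        "r": 4,
        "lora_alpha": 16,
        "lora_dropout": 0.1,
        "target_modules": ["q_proj", "v_proj"],
    }

    # Wrap the backbone so only the low-rank adapters injected into the
    # q/v attention projections are trainable; the base weights stay frozen.
    backbone = AutoModelForVision2Seq.from_pretrained(
        "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
    )
    model = get_peft_model(backbone, LoraConfig(**peft_config))
    model.print_trainable_parameters()

Freezing the 0.5B VLM and training small rank-4 adapters keeps the fine-tuning memory footprint far below full fine-tuning, which is the usual motivation for a LoRA setup like this one.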
