Dataset v2.0 #461

Merged
merged 130 commits on Nov 29, 2024
Changes from 118 commits
Commits
130 commits
ad115b6
WIP
aliberts Oct 3, 2024
17a1214
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 4, 2024
1016a98
Add upload folders
aliberts Oct 4, 2024
07e113c
Add info.json link
aliberts Oct 4, 2024
028c17f
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 4, 2024
21ba4b5
Add pixel channels
aliberts Oct 6, 2024
2d75b93
Update info.json format
aliberts Oct 8, 2024
096824b
Rework LeRobotDataset.__init__
aliberts Oct 9, 2024
3113038
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 9, 2024
b417ceb
Update LeRobotDataset.__get_item__
aliberts Oct 10, 2024
6d2bc11
Add doc, scrap video_frame_keys attribute
aliberts Oct 11, 2024
7f68088
Add huggingface-hub patch for offline snapshot_download with local_dir
aliberts Oct 11, 2024
3ea5312
Add padding keys and download_data option
aliberts Oct 11, 2024
8bd406e
Add suggestions from code review
aliberts Oct 11, 2024
cf63334
Add multitask support, refactor conversion script
aliberts Oct 13, 2024
cbc51e1
Extend v1 compatibility
aliberts Oct 14, 2024
f96773d
Fix safe_version
aliberts Oct 14, 2024
835ab5a
Cleanup, fix load_tasks
aliberts Oct 15, 2024
da78bbf
Update load_tasks doc
aliberts Oct 15, 2024
9433ac5
WIP add batch convert
aliberts Oct 15, 2024
1102640
Add fixes for batch convert
aliberts Oct 15, 2024
c146ba9
Add episode chunks logic, move_videos & lfs tracking fix
aliberts Oct 16, 2024
50a75ad
Write episodes as jsonlines
aliberts Oct 17, 2024
ad3f112
Add fixes for lfs tracking
aliberts Oct 17, 2024
3ee3739
Add batch conversion log
aliberts Oct 17, 2024
7242c57
Cleanup
aliberts Oct 17, 2024
d0d8193
Add unitreeh and umi
aliberts Oct 17, 2024
354f37a
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 17, 2024
be64d54
Update doc
aliberts Oct 17, 2024
beacb7e
Cleanup
aliberts Oct 17, 2024
3a9f964
Add copyrights
aliberts Oct 17, 2024
91e8ce7
Remove caret requirement
aliberts Oct 18, 2024
e7355ba
Fix episodes.jsonl
aliberts Oct 18, 2024
1a51505
Add download_metadata, move default paths
aliberts Oct 18, 2024
bce3dc3
Add load_metadata
aliberts Oct 18, 2024
ac3798b
Move default paths, use jsonlines for tasks
aliberts Oct 18, 2024
9316cf4
Add file paths
aliberts Oct 20, 2024
e46bdb9
Change card creation
aliberts Oct 20, 2024
3b925c3
Add ImageWriter
aliberts Oct 20, 2024
c1232a0
Add add_frame, empty dataset creation
aliberts Oct 20, 2024
9ebf8b8
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 20, 2024
299451a
Add add_episode & task logic
aliberts Oct 21, 2024
c4c0a43
add delete_episode, WIP on consolidate
aliberts Oct 21, 2024
e991a31
Improve consistency between __init__() and create(), WIP on consolidate
aliberts Oct 21, 2024
a805458
Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
aliberts Oct 22, 2024
ee52b8b
Add channels to intelrealsense
aliberts Oct 22, 2024
b46db7e
Fix tests
aliberts Oct 22, 2024
6c2cb6e
Remove populate dataset
aliberts Oct 22, 2024
237a484
Fix paths & add add_frame doc
aliberts Oct 22, 2024
c72dc23
Remove total_episodes from default parquet path
aliberts Oct 22, 2024
c3c0141
Update & fix conversion script
aliberts Oct 22, 2024
9dca233
Fix episode chunk
aliberts Oct 22, 2024
fb73cdb
Update dataset doc
aliberts Oct 22, 2024
a2a8538
add write_stats, changes names, add some typing
aliberts Oct 23, 2024
7ae8d05
Fix visualization
aliberts Oct 23, 2024
b8bdbc1
Fix check_delta_timestamps
aliberts Oct 23, 2024
07570f8
Fix _query_videos return shapes
aliberts Oct 23, 2024
1aba80d
Fix consolidate
aliberts Oct 23, 2024
0098bd2
Nits
aliberts Oct 23, 2024
0d77be9
Move ImageWriter creation inside the dataset
aliberts Oct 23, 2024
60865e8
Allow dataset creation without robot
aliberts Oct 23, 2024
450eae3
Add error msg
aliberts Oct 23, 2024
615894d
Add test_same_attributes_defined
aliberts Oct 24, 2024
8bcf81f
Add todo
aliberts Oct 24, 2024
18ffa42
Add json/jsonl io functions
aliberts Oct 24, 2024
e210d79
Add video_info, fix image_writer
aliberts Oct 25, 2024
df3d2ec
Speedup test
aliberts Oct 26, 2024
51e87f6
Fix image writer
aliberts Oct 28, 2024
4c22de2
Add sanity check
aliberts Oct 28, 2024
8d57093
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 29, 2024
fee5fa5
Remove image_writer arg
aliberts Oct 29, 2024
ee51f54
Remove dataset from image_transform tests
aliberts Oct 29, 2024
ff84024
Add dataset fixtures
aliberts Oct 31, 2024
e69f0c5
Add test_delta_timestamps.py
aliberts Oct 31, 2024
e1845d4
Update doc
aliberts Oct 31, 2024
c70b8d0
Update doc
aliberts Oct 31, 2024
1267c3e
Split fixtures into factories and files
aliberts Oct 31, 2024
ab23a4f
Add fixtures in test_datasets
aliberts Oct 31, 2024
443a9ee
Remove/comment obsolete tests
aliberts Oct 31, 2024
5ea7c78
Remove obsolete code
aliberts Oct 31, 2024
cd1509d
Mock snapshot_download
aliberts Nov 1, 2024
2650872
Add tasks and episodes factories
aliberts Nov 1, 2024
79d114c
Rename num_samples -> num_frames for consistency
aliberts Nov 1, 2024
293bdc7
Simplify, add test content, add todo
aliberts Nov 1, 2024
375abd3
Add img and img_tensor factories
aliberts Nov 2, 2024
6b2ec1e
Add test_image_writer, accept PIL images, improve ImageWriter perf in…
aliberts Nov 2, 2024
7a342db
Add more options to img factories
aliberts Nov 2, 2024
df2cb51
Add todo in skipped test
aliberts Nov 2, 2024
ac79e8c
Fix test_online_buffer.py
aliberts Nov 3, 2024
e4ba084
Add LeRobotDatasetMetadata
aliberts Nov 3, 2024
16103cb
Fix hanging
aliberts Nov 3, 2024
95a4b59
Fix vizualize
aliberts Nov 3, 2024
c2d6fb6
Fix werkzeug alert
aliberts Nov 3, 2024
f6c90ca
Remove end-to-end tests
aliberts Nov 3, 2024
56e4603
Deactivate policies backward compatibility test
aliberts Nov 3, 2024
fde29e0
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 3, 2024
a6762ec
Fix advanced example 2
aliberts Nov 3, 2024
74270c8
Remove reset_episode_index
aliberts Nov 3, 2024
7b159a6
Move calculate_episode_data_index
aliberts Nov 3, 2024
b69a132
Fix test_examples
aliberts Nov 3, 2024
757ea17
Fix test_examples
aliberts Nov 3, 2024
aed9f40
Refactor dataset features
aliberts Nov 5, 2024
f3630ad
Fix tests
aliberts Nov 5, 2024
4d15861
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 5, 2024
a91b7c6
Add extra info to dataset card, various fixes from Remi's review
aliberts Nov 18, 2024
8546358
Fix test_visualize_dataset_html
aliberts Nov 18, 2024
eda02fa
Skip test_visualize_local_dataset
aliberts Nov 18, 2024
c72ad49
Skip test_examples
aliberts Nov 18, 2024
acae4b4
Add comment on license
aliberts Nov 18, 2024
1f13bda
Improve dataset v2 (#498)
Cadene Nov 19, 2024
6203641
Use HWC for images
aliberts Nov 19, 2024
9ee8711
Update example 1
aliberts Nov 19, 2024
f43e5d0
Fix tests
aliberts Nov 19, 2024
c6ad495
Enhance dataset cards
aliberts Nov 20, 2024
37da50b
Fix conversion script
aliberts Nov 20, 2024
93d9bf8
Add open X datasets
aliberts Nov 20, 2024
36b9b60
Update example 1
aliberts Nov 20, 2024
f56d769
Remove todos
aliberts Nov 20, 2024
23f6c87
Apply suggestions from code review
aliberts Nov 25, 2024
3b5af7e
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 25, 2024
6ad84a6
Refactor pusht_zarr
aliberts Nov 25, 2024
49bdcc0
Remove comment
aliberts Nov 26, 2024
56c01a2
Activate end-to-end tests
aliberts Nov 26, 2024
691d39a
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 26, 2024
2945dca
Comment
aliberts Nov 26, 2024
2556960
Remove default root data dir, add fixes
aliberts Nov 28, 2024
d6b4429
Remove DATA_DIR references
aliberts Nov 28, 2024
82ff776
Remove remaining DATA_DIR reference
aliberts Nov 28, 2024
ea5009e
Remove commented code
aliberts Nov 28, 2024
0cb0af0
Remove unused arg
aliberts Nov 29, 2024

Files changed

74 changes: 37 additions & 37 deletions .github/workflows/test.yml
@@ -103,40 +103,40 @@ jobs:
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
&& rm -rf tests/outputs outputs


end-to-end:
name: End-to-end
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled

- name: Install apt dependencies
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev

- name: Install poetry
run: |
pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH

- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "poetry"

- name: Install poetry dependencies
run: |
poetry install --all-extras

- name: Test end-to-end
run: |
make test-end-to-end \
&& rm -rf outputs
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
# end-to-end:
# name: End-to-end
# runs-on: ubuntu-latest
# env:
# DATA_DIR: tests/data
# MUJOCO_GL: egl
# steps:
# - uses: actions/checkout@v4
# with:
# lfs: true # Ensure LFS files are pulled

# - name: Install apt dependencies
# # portaudio19-dev is needed to install pyaudio
# run: |
# sudo apt-get update && \
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev

# - name: Install poetry
# run: |
# pipx install poetry && poetry config virtualenvs.in-project true
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH

# - name: Set up Python 3.10
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "poetry"

# - name: Install poetry dependencies
# run: |
# poetry install --all-extras

# - name: Test end-to-end
# run: |
# make test-end-to-end \
# && rm -rf outputs
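Note: the commented-out job above read its datasets from the DATA_DIR environment variable, which later commits in this PR remove ("Remove default root data dir", "Remove DATA_DIR references"). Below is a minimal sketch, not part of the PR, of pointing a dataset at a local directory without that variable, assuming the root argument of LeRobotDataset serves that purpose in v2; the repo_id and path are illustrative only.

from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Assumed replacement for `DATA_DIR=tests/data`: pass the local location explicitly.
# The dataset is read from (or downloaded into) this directory instead of a global default.
local_root = Path("data") / "lerobot" / "pusht"
dataset = LeRobotDataset("lerobot/pusht", root=local_root)
print(dataset.num_frames)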
2 changes: 1 addition & 1 deletion benchmarks/video/run_video_benchmark.py
@@ -266,7 +266,7 @@ def benchmark_encoding_decoding(
)

ep_num_images = dataset.episode_data_index["to"][0].item()
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)
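The one-line change above reflects an API move that recurs throughout this PR: attributes such as camera_keys and stats now live on a LeRobotDatasetMetadata object exposed as dataset.meta. A minimal before/after sketch, using lerobot/pusht purely as an illustrative repo_id:

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")

# v1.x: metadata lived directly on the dataset, e.g. dataset.camera_keys, dataset.stats, dataset.fps.
# v2.0: the same information is grouped under dataset.meta.
camera_key = dataset.meta.camera_keys[0]
print(camera_key)                      # name of the first camera feature
print(dataset.meta.fps)                # frames per second used during data collection
print(dataset.meta.stats[camera_key])  # per-feature statistics (as used in examples/3_train_policy.py)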
123 changes: 83 additions & 40 deletions examples/1_load_lerobot_dataset.py
@@ -3,78 +3,120 @@
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.

Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Viewing a dataset's metadata and exploring its properties.
- Loading an existing dataset from the hub or a subset of it.
- Accessing frames by episode number.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.

The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""

from pathlib import Path
from pprint import pprint

import imageio
import torch
from huggingface_hub import HfApi

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

# We ported a number of existing datasets ourselves; use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)

# Let's take one for this example
repo_id = "lerobot/pusht"

# You can easily load a dataset from a Hugging Face repository
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)

# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot

# Let's take this one for this example
repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)

# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")

print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)

# You can also get a short summary by simply printing the object:
print(ds_meta)

# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])

# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")

# Or simply load the entire dataset:
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")

# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
print(dataset.hf_dataset)
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)

# And provides additional utilities for robotics and compatibility with Pytorch
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)

# Access frame indexes associated to first episode
# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# frame indices associated with the first episode:
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()

# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]

# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)

# Finally, we save the frames to a mp4 video for visualization.
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# Since we're using PyTorch, the shape follows the PyTorch channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature:
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c), which is a more universal format.

# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
# differences with the currently loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, 500 ms, 200 ms, 100 ms before, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that, when added
# to any timestamp, you still get a valid timestamp.

dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)

# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
@@ -84,8 +84,9 @@
batch_size=32,
shuffle=True,
)

for batch in dataloader:
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
print(f"{batch['observation.state'].shape=}") # (32,8,c)
print(f"{batch['action'].shape=}") # (32,64,c)
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
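As a side note on the delta_timestamps constraint stated in the example above (every value must be a multiple of 1/fps), here is a small standalone check. It is a sketch only; the fps value is an assumption rather than something read from a real dataset.

fps = 50  # assumed for this sketch; a real dataset reports it via dataset.meta.fps

delta_timestamps = {
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
    "action": [t / fps for t in range(64)],
}

for key, deltas in delta_timestamps.items():
    for t in deltas:
        # Each delta must land on a frame boundary, i.e. be an integer multiple of 1/fps,
        # so that (frame_timestamp + t) still matches a recorded timestamp.
        assert abs(t * fps - round(t * fps)) < 1e-6, f"{key}: {t} is not a multiple of 1/{fps}"

print("All delta_timestamps are valid multiples of 1/fps.")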
2 changes: 1 addition & 1 deletion examples/3_train_policy.py
@@ -40,7 +40,7 @@
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)

4 changes: 2 additions & 2 deletions examples/6_add_image_transforms.py
@@ -20,7 +20,7 @@
first_idx = dataset.episode_data_index["from"][0].item()

# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
frame = dataset[first_idx][dataset.meta.camera_keys[0]]


# Define the transformations
@@ -36,7 +36,7 @@
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)

# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]

# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
37 changes: 16 additions & 21 deletions examples/advanced/2_calculate_validation_loss.py
@@ -8,13 +8,14 @@
on the target environment, whether that be in simulation or the real world.
"""

# TODO(aliberts, rcadene): Update this script with the new v2 api
import math
from pathlib import Path

import torch
from huggingface_hub import snapshot_download

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

device = torch.device("cuda")
@@ -41,26 +41,20 @@
}

# Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
train_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
# - Load dataset metadata
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
# - Calculate train and val episodes
total_episodes = dataset_metadata.total_episodes
episodes = list(range(dataset_metadata.total_episodes))
num_train_episodes = math.floor(total_episodes * 90 / 100)
train_episodes = episodes[:num_train_episodes]
val_episodes = episodes[num_train_episodes:]
print(f"Number of episodes in full dataset: {total_episodes}")
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
# - Load train and val datasets
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")

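The truncated remainder of this script is where the validation loss is actually computed. For readers skimming the diff, here is a sketch of that final step; it assumes a policy object loaded earlier in the hidden part of the script whose forward() returns a dict containing a "loss" entry, as the training example elsewhere in the repo does.

# Sketch only: average the policy loss over the validation split defined above.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    num_workers=4,
    batch_size=64,
    shuffle=False,
)

total_loss, n_batches = 0.0, 0
with torch.no_grad():
    for batch in val_dataloader:
        # Move tensor entries to the device; leave any non-tensor entries (e.g. strings) untouched.
        batch = {k: (v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        output_dict = policy.forward(batch)  # assumed to return a dict containing "loss"
        total_loss += output_dict["loss"].item()
        n_batches += 1

print(f"Average validation loss (per batch): {total_loss / n_batches:.4f}")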