Dataset v2.0 #461

Merged
merged 130 commits into from
Nov 29, 2024
Changes from 12 commits
130 commits
ad115b6
WIP
aliberts Oct 3, 2024
17a1214
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 4, 2024
1016a98
Add upload folders
aliberts Oct 4, 2024
07e113c
Add info.json link
aliberts Oct 4, 2024
028c17f
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 4, 2024
21ba4b5
Add pixel channels
aliberts Oct 6, 2024
2d75b93
Update info.json format
aliberts Oct 8, 2024
096824b
Rework LeRobotDataset.__init__
aliberts Oct 9, 2024
3113038
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 9, 2024
b417ceb
Update LeRobotDataset.__get_item__
aliberts Oct 10, 2024
6d2bc11
Add doc, scrap video_frame_keys attribute
aliberts Oct 11, 2024
7f68088
Add huggingface-hub patch for offline snapshot_download with local_dir
aliberts Oct 11, 2024
3ea5312
Add padding keys and download_data option
aliberts Oct 11, 2024
8bd406e
Add suggestions from code review
aliberts Oct 11, 2024
cf63334
Add multitask support, refactor conversion script
aliberts Oct 13, 2024
cbc51e1
Extend v1 compatibility
aliberts Oct 14, 2024
f96773d
Fix safe_version
aliberts Oct 14, 2024
835ab5a
Cleanup, fix load_tasks
aliberts Oct 15, 2024
da78bbf
Update load_tasks doc
aliberts Oct 15, 2024
9433ac5
WIP add batch convert
aliberts Oct 15, 2024
1102640
Add fixes for batch convert
aliberts Oct 15, 2024
c146ba9
Add episode chunks logic, move_videos & lfs tracking fix
aliberts Oct 16, 2024
50a75ad
Write episodes as jsonlines
aliberts Oct 17, 2024
ad3f112
Add fixes for lfs tracking
aliberts Oct 17, 2024
3ee3739
Add batch conversion log
aliberts Oct 17, 2024
7242c57
Cleanup
aliberts Oct 17, 2024
d0d8193
Add unitreeh and umi
aliberts Oct 17, 2024
354f37a
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 17, 2024
be64d54
Update doc
aliberts Oct 17, 2024
beacb7e
Cleanup
aliberts Oct 17, 2024
3a9f964
Add copyrights
aliberts Oct 17, 2024
91e8ce7
Remove caret requirement
aliberts Oct 18, 2024
e7355ba
Fix episodes.jsonl
aliberts Oct 18, 2024
1a51505
Add download_metadata, move default paths
aliberts Oct 18, 2024
bce3dc3
Add load_metadata
aliberts Oct 18, 2024
ac3798b
Move default paths, use jsonlines for tasks
aliberts Oct 18, 2024
9316cf4
Add file paths
aliberts Oct 20, 2024
e46bdb9
Change card creation
aliberts Oct 20, 2024
3b925c3
Add ImageWriter
aliberts Oct 20, 2024
c1232a0
Add add_frame, empty dataset creation
aliberts Oct 20, 2024
9ebf8b8
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 20, 2024
299451a
Add add_episode & task logic
aliberts Oct 21, 2024
c4c0a43
add delete_episode, WIP on consolidate
aliberts Oct 21, 2024
e991a31
Improve consistency between __init__() and create(), WIP on consolidate
aliberts Oct 21, 2024
a805458
Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
aliberts Oct 22, 2024
ee52b8b
Add channels to intelrealsense
aliberts Oct 22, 2024
b46db7e
Fix tests
aliberts Oct 22, 2024
6c2cb6e
Remove populate dataset
aliberts Oct 22, 2024
237a484
Fix paths & add add_frame doc
aliberts Oct 22, 2024
c72dc23
Remove total_episodes from default parquet path
aliberts Oct 22, 2024
c3c0141
Update & fix conversion script
aliberts Oct 22, 2024
9dca233
Fix episode chunk
aliberts Oct 22, 2024
fb73cdb
Update dataset doc
aliberts Oct 22, 2024
a2a8538
add write_stats, changes names, add some typing
aliberts Oct 23, 2024
7ae8d05
Fix visualization
aliberts Oct 23, 2024
b8bdbc1
Fix check_delta_timestamps
aliberts Oct 23, 2024
07570f8
Fix _query_videos return shapes
aliberts Oct 23, 2024
1aba80d
Fix consolidate
aliberts Oct 23, 2024
0098bd2
Nits
aliberts Oct 23, 2024
0d77be9
Move ImageWriter creation inside the dataset
aliberts Oct 23, 2024
60865e8
Allow dataset creation without robot
aliberts Oct 23, 2024
450eae3
Add error msg
aliberts Oct 23, 2024
615894d
Add test_same_attributes_defined
aliberts Oct 24, 2024
8bcf81f
Add todo
aliberts Oct 24, 2024
18ffa42
Add json/jsonl io functions
aliberts Oct 24, 2024
e210d79
Add video_info, fix image_writer
aliberts Oct 25, 2024
df3d2ec
Speedup test
aliberts Oct 26, 2024
51e87f6
Fix image writer
aliberts Oct 28, 2024
4c22de2
Add sanity check
aliberts Oct 28, 2024
8d57093
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Oct 29, 2024
fee5fa5
Remove image_writer arg
aliberts Oct 29, 2024
ee51f54
Remove dataset from image_transform tests
aliberts Oct 29, 2024
ff84024
Add dataset fixtures
aliberts Oct 31, 2024
e69f0c5
Add test_delta_timestamps.py
aliberts Oct 31, 2024
e1845d4
Update doc
aliberts Oct 31, 2024
c70b8d0
Update doc
aliberts Oct 31, 2024
1267c3e
Split fixtures into factories and files
aliberts Oct 31, 2024
ab23a4f
Add fixtures in test_datasets
aliberts Oct 31, 2024
443a9ee
Remove/comment obsolete tests
aliberts Oct 31, 2024
5ea7c78
Remove obsolete code
aliberts Oct 31, 2024
cd1509d
Mock snapshot_download
aliberts Nov 1, 2024
2650872
Add tasks and episodes factories
aliberts Nov 1, 2024
79d114c
Rename num_samples -> num_frames for consistency
aliberts Nov 1, 2024
293bdc7
Simplify, add test content, add todo
aliberts Nov 1, 2024
375abd3
Add img and img_tensor factories
aliberts Nov 2, 2024
6b2ec1e
Add test_image_writer, accept PIL images, improve ImageWriter perf in…
aliberts Nov 2, 2024
7a342db
Add more options to img factories
aliberts Nov 2, 2024
df2cb51
Add todo in skipped test
aliberts Nov 2, 2024
ac79e8c
Fix test_online_buffer.py
aliberts Nov 3, 2024
e4ba084
Add LeRobotDatasetMetadata
aliberts Nov 3, 2024
16103cb
Fix hanging
aliberts Nov 3, 2024
95a4b59
Fix vizualize
aliberts Nov 3, 2024
c2d6fb6
Fix werkzeug alert
aliberts Nov 3, 2024
f6c90ca
Remove end-to-end tests
aliberts Nov 3, 2024
56e4603
Deactivate policies backward compatibility test
aliberts Nov 3, 2024
fde29e0
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 3, 2024
a6762ec
Fix advanced example 2
aliberts Nov 3, 2024
74270c8
Remove reset_episode_index
aliberts Nov 3, 2024
7b159a6
Move calculate_episode_data_index
aliberts Nov 3, 2024
b69a132
Fix test_examples
aliberts Nov 3, 2024
757ea17
Fix test_examples
aliberts Nov 3, 2024
aed9f40
Refactor dataset features
aliberts Nov 5, 2024
f3630ad
Fix tests
aliberts Nov 5, 2024
4d15861
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 5, 2024
a91b7c6
Add extra info to dataset card, various fixes from Remi's review
aliberts Nov 18, 2024
8546358
Fix test_visualize_dataset_html
aliberts Nov 18, 2024
eda02fa
Skip test_visualize_local_dataset
aliberts Nov 18, 2024
c72ad49
Skip test_examples
aliberts Nov 18, 2024
acae4b4
Add comment on license
aliberts Nov 18, 2024
1f13bda
Improve dataset v2 (#498)
Cadene Nov 19, 2024
6203641
Use HWC for images
aliberts Nov 19, 2024
9ee8711
Update example 1
aliberts Nov 19, 2024
f43e5d0
Fix tests
aliberts Nov 19, 2024
c6ad495
Enhance dataset cards
aliberts Nov 20, 2024
37da50b
Fix conversion script
aliberts Nov 20, 2024
93d9bf8
Add open X datasets
aliberts Nov 20, 2024
36b9b60
Update example 1
aliberts Nov 20, 2024
f56d769
Remove todos
aliberts Nov 20, 2024
23f6c87
Apply suggestions from code review
aliberts Nov 25, 2024
3b5af7e
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 25, 2024
6ad84a6
Refactor pusht_zarr
aliberts Nov 25, 2024
49bdcc0
Remove comment
aliberts Nov 26, 2024
56c01a2
Activate end-to-end tests
aliberts Nov 26, 2024
691d39a
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09…
aliberts Nov 26, 2024
2945dca
Comment
aliberts Nov 26, 2024
2556960
Remove default root data dir, add fixes
aliberts Nov 28, 2024
d6b4429
Remove DATA_DIR references
aliberts Nov 28, 2024
82ff776
Remove remaining DATA_DIR reference
aliberts Nov 28, 2024
ea5009e
Remove commented code
aliberts Nov 28, 2024
0cb0af0
Remove unused arg
aliberts Nov 29, 2024
539 changes: 539 additions & 0 deletions convert_dataset_16_to_20.py

Large diffs are not rendered by default.

337 changes: 262 additions & 75 deletions lerobot/common/datasets/lerobot_dataset.py

Large diffs are not rendered by default.

233 changes: 157 additions & 76 deletions lerobot/common/datasets/utils.py
Expand Up @@ -14,18 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import warnings
from functools import cache
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from typing import Dict

import datasets
import torch
from datasets import load_dataset, load_from_disk
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms

DATASET_CARD_TEMPLATE = """
Expand Down Expand Up @@ -96,7 +96,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):


@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
def get_hub_safe_version(repo_id: str, version: str) -> str:
num_version = float(version.strip("v"))
if num_version < 2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
)
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
Expand All @@ -116,104 +123,178 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
return version


def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
def load_hf_dataset(
local_dir: Path,
data_path: str,
total_episodes: int,
episodes: list[int] | None = None,
split: str = "train",
) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We don't support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
if episodes is None:
path = str(local_dir / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split=split)
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
files = [str(local_dir / fpath) for fpath in files]
hf_dataset = load_dataset("parquet", data_files=files, split=split)

hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
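
When `episodes` is given, the new `load_hf_dataset` builds one parquet path per episode from the `data_path` template before calling `load_dataset("parquet", data_files=...)`. A minimal sketch of that path expansion follows. The template string and the `episode_files` helper are assumptions for illustration only; the actual template comes from the dataset metadata and is not shown in this diff.

```python
from pathlib import Path

# Hypothetical template, for illustration only. The real value is read from
# the dataset's metadata and may differ.
DATA_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"


def episode_files(local_dir: Path, data_path: str, total_episodes: int, episodes: list[int]) -> list[str]:
    """Expand the template once per requested episode, then prefix local_dir."""
    relative = [
        data_path.format(episode_index=ep_idx, total_episodes=total_episodes)
        for ep_idx in episodes
    ]
    return [str(local_dir / fpath) for fpath in relative]


files = episode_files(Path("my_dataset"), DATA_PATH, total_episodes=50, episodes=[0, 7])
print(files)
```

Only the listed files are read, so a subset of episodes can be loaded without touching the rest of the `data` directory.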


def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode

Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
)

return load_file(path)


def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std

Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)

stats = load_file(path)
fpath = hf_hub_download(
repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
stats = json.load(f)

stats = flatten_dict(stats)
stats = {key: torch.tensor(value) for key, value in stats.items()}
return unflatten_dict(stats)
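
The new `load_stats` reads nested JSON, flattens it, converts each leaf, and rebuilds the nesting. The sketch below shows that round trip in plain Python. The `flatten` and `unflatten` helpers and the `/` separator are hypothetical stand-ins for the `flatten_dict` and `unflatten_dict` functions the diff calls (their bodies are not shown here), and `tuple` stands in for `torch.tensor` so the example has no third-party dependency.

```python
import json


def flatten(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Turn {"a": {"b": 1}} into {"a/b": 1}. Hypothetical helper."""
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten(value, new_key, sep))
        else:
            items[new_key] = value
    return items


def unflatten(d: dict, sep: str = "/") -> dict:
    """Inverse of flatten: rebuild the nested dict from joined keys."""
    out = {}
    for key, value in d.items():
        parts = key.split(sep)
        node = out
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return out


raw = json.loads('{"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}}')
flat = flatten(raw)
# Stand-in for `torch.tensor(value)`: convert every leaf in one pass.
converted = {key: tuple(value) for key, value in flat.items()}
stats = unflatten(converted)
print(stats["action"]["std"])
```

Flattening first means the leaf conversion is a single dict comprehension, regardless of how deeply the statistics are nested.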


def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
"""info contains structural information about the dataset. It should be the reference and
act as the 'source of truth' for what's inside the dataset.

Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
fpath = hf_hub_download(
repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)

with open(path) as f:
info = json.load(f)
return info

def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
"""tasks contains all the tasks of the dataset, indexed by their task_index.

def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos"
Example:
```json
{
"0": "Pick the Lego block and drop it in the box on the right."
}
```
"""
fpath = hf_hub_download(
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)


def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

cumulative_lengths = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lengths),
}
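
A worked example of the cumulative-sum logic in `get_episode_data_index`, written with plain lists instead of `torch.LongTensor` so it runs without dependencies. The `episode_ranges` name and its list-based signature are assumptions made for this sketch.

```python
from itertools import accumulate
from typing import Optional


def episode_ranges(episode_lengths: list[int], episodes: Optional[list[int]] = None) -> dict[str, list[int]]:
    """Return the half-open [from, to) frame range of each selected episode."""
    lengths = dict(enumerate(episode_lengths))
    if episodes is not None:
        # Keep only the requested episodes, in the requested order.
        lengths = {ep_idx: lengths[ep_idx] for ep_idx in episodes}
    cumulative = list(accumulate(lengths.values()))
    return {"from": [0] + cumulative[:-1], "to": cumulative}


all_ranges = episode_ranges([3, 5, 2])
subset_ranges = episode_ranges([3, 5, 2], episodes=[0, 2])
print(all_ranges)
print(subset_ranges)
```

With a subset, the ranges index into the frames that were actually loaded rather than into the full dataset, which matches loading only the selected parquet files.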

return path

def check_timestamps_sync(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check makes sure that each timestamp is separated from the next by 1/fps +/- tolerance, to
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
# timestamps[2] += tolerance_s # TODO delete
# timestamps[-2] += tolerance_s/2 # TODO delete
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s

# We mask differences between the timestamp at the end of an episode
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]

if not torch.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])

outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item(),
}
outside_tolerances.append(entry)

if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection.
\n{pformat(outside_tolerances)}"""
)
return False

return True
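
The masking idea in `check_timestamps_sync` can be sketched in plain Python: compute consecutive differences, skip the ones that straddle an episode boundary, and compare the rest to `1 / fps`. The `timestamps_in_sync` function and its list arguments are assumptions for illustration; the diffed function works on tensors and also reports the offending entries.

```python
def timestamps_in_sync(timestamps: list[float], episode_ends: list[int], fps: int, tolerance_s: float) -> bool:
    """episode_ends plays the role of episode_data_index["to"]."""
    diffs = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    # diffs[end - 1] compares the last frame of one episode with the first
    # frame of the next, so it is excluded from the check.
    boundary = {end - 1 for end in episode_ends[:-1]}
    return all(
        abs(diff - 1 / fps) <= tolerance_s
        for idx, diff in enumerate(diffs)
        if idx not in boundary
    )


# Two episodes recorded at 10 fps: 3 frames, then 2 frames.
clean = [0.0, 0.1, 0.2, 0.0, 0.1]
jittered = [0.0, 0.1, 0.25, 0.0, 0.1]
print(timestamps_in_sync(clean, [3, 5], fps=10, tolerance_s=1e-4))
print(timestamps_in_sync(jittered, [3, 5], fps=10, tolerance_s=1e-4))
```

Without the boundary mask, the jump back to `0.0` at the start of the second episode would always be reported as a violation.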


def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
abs_delta_ts = torch.abs(torch.tensor(delta_ts))
within_tolerance = (abs_delta_ts % (1 / fps)) <= tolerance_s
if not torch.all(within_tolerance):
outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance]

if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False

return True
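
One subtlety with remainder-based checks like the one above: with plain Python floats, `0.3 % 0.1` evaluates to just under `0.1` rather than `0`, so a valid multiple can appear to be far outside a small tolerance. Whether this shows up with the tensors used in `check_delta_timestamps` is not verified here. The sketch below is an alternative formulation, not the PR's code: it measures the distance to the nearest multiple of `1 / fps` instead of the remainder. The function name is hypothetical.

```python
def is_multiple_of_frame_period(delta_t: float, fps: int, tolerance_s: float) -> bool:
    """True if delta_t is within tolerance_s of some integer multiple of 1/fps."""
    frames = delta_t * fps
    # Distance to the nearest whole number of frames, converted back to seconds.
    return abs(frames - round(frames)) / fps <= tolerance_s


print(0.3 % 0.1)  # remainder lands near 0.1, not near 0
print(is_multiple_of_frame_period(0.3, fps=10, tolerance_s=1e-4))
print(is_multiple_of_frame_period(-0.1, fps=10, tolerance_s=1e-4))
print(is_multiple_of_frame_period(0.05, fps=10, tolerance_s=1e-4))
```

Measuring the distance to the nearest multiple treats a value slightly below a multiple the same way as a value slightly above it.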


def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()

return delta_indices
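
A small usage sketch of the timestamp-to-index conversion in `get_delta_indices`, in plain Python. The `delta_indices` function name is an assumption, and it uses `round()` as a defensive choice, whereas the diffed function uses `.long()`, which truncates toward zero.

```python
def delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Convert offsets in seconds to offsets in frames."""
    return {
        key: [round(delta_t * fps) for delta_t in delta_ts]
        for key, delta_ts in delta_timestamps.items()
    }


# One previous frame, the current frame, and two future frames at 10 fps.
indices = delta_indices({"action": [-0.1, 0.0, 0.1, 0.2]}, fps=10)
print(indices)
```

The resulting integer offsets can be added to a frame index to look up neighbouring frames, provided the offsets were first validated as multiples of `1 / fps`.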


def load_previous_and_future_frames(
Expand Down
39 changes: 1 addition & 38 deletions lerobot/common/datasets/video_utils.py
Expand Up @@ -27,45 +27,8 @@
from datasets.features.features import register_feature


def load_from_videos(
item: dict[str, torch.Tensor],
video_frame_keys: list[str],
videos_dir: Path,
tolerance_s: float,
backend: str = "pyav",
):
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
This probably happens because a memory reference to the video loader is created in the main process and a
subprocess fails to access it.
"""
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
data_dir = videos_dir.parent

for key in video_frame_keys:
if isinstance(item[key], list):
# load multiple frames at once (expected when delta_timestamps is not None)
timestamps = [frame["timestamp"] for frame in item[key]]
paths = [frame["path"] for frame in item[key]]
if len(set(paths)) > 1:
raise NotImplementedError("All video paths are expected to be the same for now.")
video_path = data_dir / paths[0]

frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames
else:
# load one frame
timestamps = [item[key]["timestamp"]]
video_path = data_dir / item[key]["path"]

frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames[0]

return item


def decode_video_frames_torchvision(
video_path: str,
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
Expand Down
2 changes: 1 addition & 1 deletion lerobot/scripts/push_dataset_to_hub.py
Expand Up @@ -260,7 +260,7 @@ def push_dataset_to_hub(
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
for key in lerobot_dataset.camera_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)

Expand Down
3 changes: 1 addition & 2 deletions lerobot/scripts/visualize_dataset_html.py
Expand Up @@ -171,8 +171,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys
]


Expand Down