Skip to content

Commit afbd42d

Browse files
committed
Add manage_dataset
1 parent 8e7d697 commit afbd42d

File tree

2 files changed

+118
-2
lines changed

2 files changed

+118
-2
lines changed

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,13 @@ def encode_episode_videos(self, episode_index: int) -> dict:
873873

874874
return video_paths
875875

876-
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
876+
def consolidate(
877+
self,
878+
run_compute_stats: bool = True,
879+
keep_image_files: bool = False,
880+
batch_size: int = 8,
881+
num_workers: int = 8,
882+
) -> None:
877883
self.hf_dataset = self.load_hf_dataset()
878884
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
879885
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -896,7 +902,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F
896902
if run_compute_stats:
897903
self.stop_image_writer()
898904
# TODO(aliberts): refactor stats in save_episodes
899-
self.meta.stats = compute_stats(self)
905+
self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
900906
serialized_stats = serialize_dict(self.meta.stats)
901907
write_json(serialized_stats, self.root / STATS_PATH)
902908
self.consolidated = True

lerobot/scripts/manage_dataset.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
"""
Utilities to manage a dataset.

Examples of usage:

- Consolidate a dataset, by encoding images into videos and computing statistics:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
    --repo-id $USER/koch_test
```

- Consolidate a dataset which is not uploaded on the hub yet:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
    --repo-id $USER/koch_test \
    --local-files-only 1
```

- Upload a dataset on the hub:
```bash
python lerobot/scripts/manage_dataset.py push_to_hub \
    --repo-id $USER/koch_test
```
"""

import argparse
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with one subcommand per management action.

    The parsed namespace always contains ``mode``, ``repo_id``, ``root`` and
    ``local_files_only`` (shared options), plus the options specific to the
    selected subcommand.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Options shared by all subcommands, attached via `parents=`.
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
    )
    base_parser.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    base_parser.add_argument(
        "--local-files-only",
        type=int,
        default=0,
        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
    )

    parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
    parser_conso.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader for computing the dataset statistics.",
    )
    parser_conso.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of processes of Dataloader for computing the dataset statistics.",
    )

    parser_push = subparsers.add_parser("push_to_hub", parents=[base_parser])
    parser_push.add_argument(
        "--tags",
        type=str,
        nargs="*",
        default=None,
        help="Optional additional tags to categorize the dataset on the Hugging Face Hub. Use space-separated values (e.g. 'so100 indoor'). The tag 'LeRobot' will always be added.",
    )
    parser_push.add_argument(
        "--license",
        type=str,
        default="apache-2.0",
        # Fixed: the help text previously claimed the default was "mit",
        # contradicting the actual default above.
        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
    )
    parser_push.add_argument(
        "--private",
        type=int,
        default=0,
        help="Create a private dataset repository on the Hugging Face Hub.",
    )

    return parser


def main() -> None:
    """Parse CLI arguments, load the dataset and dispatch to the chosen action."""
    args = _build_parser().parse_args()
    kwargs = vars(args)

    # Shared options are consumed here; the remaining kwargs are forwarded
    # verbatim to the dataset method matching the selected mode.
    mode = kwargs.pop("mode")
    repo_id = kwargs.pop("repo_id")
    root = kwargs.pop("root")
    # argparse stores 0/1 ints for these flags; normalize to real booleans.
    local_files_only = bool(kwargs.pop("local_files_only"))

    dataset = LeRobotDataset(
        repo_id=repo_id,
        root=root,
        local_files_only=local_files_only,
    )

    if mode == "consolidate":
        # Remaining kwargs: batch_size, num_workers.
        dataset.consolidate(**kwargs)

    elif mode == "push_to_hub":
        private = bool(kwargs.pop("private"))
        # Remaining kwargs: tags, license.
        dataset.push_to_hub(private=private, **kwargs)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)