Skip to content

Commit

Permalink
Fix `sanity_check_dataset_robot_compatibility` to compare dataset features against `get_features_from_robot` instead of the removed per-key info fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Nov 18, 2024
1 parent 7ba6318 commit e1e7edb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 24 deletions.
23 changes: 8 additions & 15 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,6 @@ def write_video_info(self) -> None:
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
# TODO(rcadene): What should we do here?
if "videos" not in self.info:
self.info["videos"] = {}

for key in self.video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
Expand All @@ -278,8 +274,6 @@ def create(
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
# tags: list[str] | None = None,
# license_type: str | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
Expand All @@ -301,14 +295,11 @@ def create(
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}

# TODO(rcadene): implement sanity check for features

obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
# obj.tags = tags
# obj.license_type = license_type
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
Expand Down Expand Up @@ -439,6 +430,7 @@ def __init__(

# Unused attributes
self.image_writer = None
self.episode_buffer = None

self.root.mkdir(exist_ok=True, parents=True)

Expand All @@ -464,9 +456,6 @@ def __init__(
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None

# Create an empty buffer to extend the dataset if required
self.episode_buffer = self._create_episode_buffer()

def push_to_hub(
self,
tags: list | None = None,
Expand Down Expand Up @@ -704,9 +693,12 @@ def add_frame(self, frame: dict) -> None:
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
# TODO(rcadene): Add sanity check for the input, check it's numpy or torch,
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.

if self.episode_buffer is None:
self.episode_buffer = self._create_episode_buffer()

frame_index = self.episode_buffer["size"]
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
Expand Down Expand Up @@ -930,7 +922,8 @@ def create(
obj.tolerance_s = tolerance_s
obj.image_writer = None

obj.start_image_writer(image_writer_processes, image_writer_threads)
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)

# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
Expand Down
12 changes: 3 additions & 9 deletions lerobot/common/robot_devices/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
Expand Down Expand Up @@ -333,17 +334,10 @@ def sanity_check_dataset_name(repo_id, policy):


def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
# TODO(rcadene): fix that before merging
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) # noqa

fields = [
("robot_type", dataset.meta.info["robot_type"], robot_type),
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
("fps", dataset.meta.info["fps"], fps),
("keys", dataset.meta.info["keys"], keys),
("image_keys", dataset.meta.info["image_keys"], image_keys),
("video_keys", dataset.meta.info["video_keys"], video_keys),
("shapes", dataset.meta.info["shapes"], shapes),
("names", dataset.meta.info["names"], names),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]

mismatches = []
Expand Down

0 comments on commit e1e7edb

Please sign in to comment.