Fix pusht keypoints + add BackwardCompatibilityError for dataset
Cadene committed Nov 19, 2024
1 parent e1e7edb commit 04bbf71
Showing 5 changed files with 95 additions and 63 deletions.
examples/port_datasets/pusht_zarr.py (60 additions, 47 deletions)
@@ -9,12 +9,42 @@


def create_empty_dataset(repo_id, mode):
features = {}
features = {
"observation.state": {
"dtype": "float32",
"shape": (2,),
"names": [
["x", "y"],
],
},
"action": {
"dtype": "float32",
"shape": (2,),
"names": [
["x", "y"],
],
},
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
}

if mode == "keypoints":
state_dim = 16
features["observation.environment_state"] = {
"dtype": "float32",
"shape": (16,),
"names": [
"keypoints",
],
}
else:
state_dim = 2
features["observation.image"] = {
"dtype": mode,
"shape": (3, 96, 96),
@@ -25,35 +55,6 @@ def create_empty_dataset(repo_id, mode):
],
}

features.update(
{
"observation.state": {
"dtype": "float32",
"shape": (state_dim,),
"names": [
["x", "y"],
],
},
"action": {
"dtype": "float32",
"shape": (2,),
"names": [
["x", "y"],
],
},
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
}
)

dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=10,
@@ -146,7 +147,7 @@ def calculate_reward(coverage, success_threshold):
return np.clip(coverage / success_threshold, 0, 1)


def populate_dataset(dataset, episode_data_index, episodes, image, state, action, reward, success):
def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
if episodes is None:
episodes = range(len(episode_data_index["from"]))

@@ -160,20 +161,22 @@ def populate_dataset(dataset, episode_data_index, episodes, image, state, action

frame = {
"action": torch.from_numpy(action[i]),
"timestamp": frame_idx / dataset.fps,
# Shift reward and success by +1 for every frame except the last one of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
}

frame["observation.state"] = torch.from_numpy(state[i])

if env_state is not None:
frame["observation.environment_state"] = torch.from_numpy(env_state[i])

if image is not None:
frame["observation.image"] = torch.from_numpy(image[i])

# TODO(rcadene): add_frame_to_buffer, add_episode_from_buffer
dataset.add_frame(frame)

dataset.add_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")

return dataset

@@ -205,7 +208,8 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
episode_data_index,
episodes,
image=None if mode == "keypoints" else image,
state=keypoints if mode == "keypoints" else agent_pos,
state=agent_pos,
env_state=keypoints if mode == "keypoints" else None,
action=action,
reward=reward,
success=success,
@@ -217,17 +221,26 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):


if __name__ == "__main__":
episodes = [0, 1]
# episodes = None

# for mode in ["video"]:
for mode in ["image"]:
# for mode in ["keypoints"]:
# for mode in ["video", "image", "keypoints"]:
repo_id = "cadene/pusht_v2"
# To try this script, replace the repo id with one under your own HuggingFace user (e.g. cadene/pusht)
repo_id = "lerobot/pusht"

episodes = None
# Uncomment if you want to try with a subset (episodes 0 and 1)
# episodes = [0, 1]

modes = ["video", "image", "keypoints"]
# Uncomment if you want to try with a specific mode
# modes = ["video"]
# modes = ["image"]
# modes = ["keypoints"]

for mode in ["video", "image", "keypoints"]:
if mode in ["image", "keypoints"]:
repo_id += f"_{mode}"

# download and load raw dataset, create LeRobotDataset, populate it, push to hub
port_pusht("data/lerobot-raw/pusht_raw", repo_id=repo_id, mode=mode, episodes=episodes)

# dataset = LeRobotDataset(repo_id="cadene/pusht_v2", local_files_only=True)
# dataset_old = LeRobotDataset(repo_id="lerobot/pusht")
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# breakpoint()
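As a quick sanity check on the new layout, here is a minimal sketch (assumed usage, not part of the commit) that loads a locally ported keypoints variant and confirms that the agent position and the keypoints now live in separate features:

```python
# Minimal sketch; assumes the keypoints variant was ported locally and that the
# resulting repo_id is "lerobot/pusht_keypoints" (as the loop above would name it).
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset(repo_id="lerobot/pusht_keypoints", local_files_only=True)
frame = dataset[0]

print(frame["observation.state"].shape)              # (2,)  agent x/y position
print(frame["observation.environment_state"].shape)  # (16,) T-block keypoints
print(frame["action"].shape)                          # (2,)  target x/y
```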
lerobot/common/datasets/lerobot_dataset.py (1 addition, 1 deletion)
@@ -700,7 +700,7 @@ def add_frame(self, frame: dict) -> None:
self.episode_buffer = self._create_episode_buffer()

frame_index = self.episode_buffer["size"]
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)

lerobot/common/datasets/utils.py (32 additions, 13 deletions)
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import textwrap
import warnings
from itertools import accumulate
from pathlib import Path
@@ -188,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
return int(split[0]), int(split[1])


class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format in v2.0 which is not backward compatible with v1.x.
Please use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
or open an issue on GitHub.
""")
super().__init__(message)


def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
warnings.warn(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -209,18 +230,16 @@ def check_version_compatibility(
)


def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2 and enforce_v2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
def get_hub_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)

warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
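As a usage sketch (assumed call site, not part of the diff), the new error surfaces like this when a v1.x dataset is checked against the v2.0 codebase:

```python
# Minimal sketch: check_version_compatibility and BackwardCompatibilityError come
# from the diff above; the version strings are illustrative.
from lerobot.common.datasets.utils import (
    BackwardCompatibilityError,
    check_version_compatibility,
)

try:
    check_version_compatibility("lerobot/pusht", version_to_check="v1.6", current_version="v2.0")
except BackwardCompatibilityError as err:
    print(err)  # prints the conversion instructions defined in BackwardCompatibilityError
```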
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py (1 addition, 1 deletion)
@@ -441,7 +441,7 @@ def convert_dataset(
arxiv: str | None = None,
test_branch: str | None = None,
):
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
v1 = get_hub_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
tests/fixtures/dataset_factories.py (1 addition, 1 deletion)
@@ -325,7 +325,7 @@ def _create_lerobot_dataset_metadata(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download

return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
