
fixed dataset format and added timestamps
michel-aractingi committed Dec 2, 2024
1 parent 50922e1 commit 0478b15
Showing 1 changed file with 75 additions and 100 deletions.
lerobot/scripts/control_sim_robot.py
@@ -74,7 +74,6 @@
 
 import argparse
 import importlib
-import json
 import logging
 import time
 import traceback
@@ -92,23 +91,16 @@
     init_policy,
     log_control_info,
     predict_action,
-    stop_recording,
     sanity_check_dataset_name,
-    sanity_check_dataset_robot_compatibility
+    sanity_check_dataset_robot_compatibility,
+    stop_recording,
 )
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
 
 DEFAULT_FEATURES = {
-    "action": {
-        "dtype": "float32",
-        "shape": (2,),
-        "names": {
-            "axes": ["x", "y"],
-        },
-    },
     "next.reward": {
         "dtype": "float32",
         "shape": (1,),
@@ -120,7 +112,12 @@
         "names": None,
     },
     "seed": {
-        "dtype": "int",
+        "dtype": "int64",
         "shape": (1,),
         "names": None,
     },
+    "timestamp": {
+        "dtype": "float32",
+        "shape": (1,),
+        "names": None,
+    },
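
Taken together, the two hunks above leave the feature template reading roughly as below. This is a minimal sketch of the post-commit dict: the `next.success` entry is collapsed in the diff, so its dtype here is an assumption, and the fixed two-axis `action` entry is gone because `record()` now derives the action feature from the environment's action space.

# Sketch of DEFAULT_FEATURES after this commit; "next.success" dtype is assumed,
# since that entry is collapsed in the hunk above.
DEFAULT_FEATURES = {
    "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
    "next.success": {"dtype": "bool", "shape": (1,), "names": None},  # dtype assumed
    "seed": {"dtype": "int64", "shape": (1,), "names": None},
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
}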
@@ -170,42 +167,6 @@ def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
     return axis_directions * (real_positions - start_pos) * 2.0 * np.pi / 4096 + offsets
 
 
-def save_current_episode(dataset):
-    episode_index = dataset["num_episodes"]
-    ep_dict = dataset["current_episode"]
-    episodes_dir = dataset["episodes_dir"]
-    rec_info_path = dataset["rec_info_path"]
-
-    ep_dict["next.done"][-1] = True
-
-    for key in ep_dict:
-        if "observation" in key and "image" not in key:
-            ep_dict[key] = torch.stack(ep_dict[key])
-
-    ep_dict["action"] = torch.stack(ep_dict["action"])
-    ep_dict["next.reward"] = torch.tensor(ep_dict["next.reward"])
-    ep_dict["next.success"] = torch.tensor(ep_dict["next.success"])
-    ep_dict["seed"] = torch.tensor(ep_dict["seed"])
-    ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
-    ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
-    ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
-    ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
-
-    ep_path = episodes_dir / f"episode_{episode_index}.pth"
-    torch.save(ep_dict, ep_path)
-
-    rec_info = {
-        "last_episode_index": episode_index,
-    }
-    with open(rec_info_path, "w") as f:
-        json.dump(rec_info, f)
-
-    # force re-initialization of episode dictionary during add_frame
-    del dataset["current_episode"]
-
-    dataset["num_episodes"] += 1
-
-
 ########################################################################################
 # Control modes
 ########################################################################################
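
The hunk above drops the hand-rolled `save_current_episode` helper in favor of `LeRobotDataset`'s own episode writer. As a sanity check of the conversion in `real_positions_to_sim` at the top of that hunk, which maps 12-bit encoder ticks (4096 per revolution) to radians, a quarter turn works out as expected; the values below are illustrative:

import numpy as np

# Quarter turn: 1024 ticks out of 4096 should map to pi/2 radians,
# assuming axis_directions=1, start_pos=0 and offsets=0.
sim_pos = 1 * (np.array([1024.0]) - 0) * 2.0 * np.pi / 4096 + 0
assert np.isclose(sim_pos[0], np.pi / 2)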
@@ -227,10 +188,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
 def record(
     env,
     robot: Robot,
-    process_action_from_leader,
     root: Path,
     repo_id: str,
     task: str,
+    process_action_from_leader: function | None = None,
     fps: int | None = None,
     tags: list[str] | None = None,
     pretrained_policy_name_or_path: str = None,
@@ -244,10 +205,10 @@ def record(
     display_cameras: bool = False,
     play_sounds: bool = True,
     resume: bool = False,
-    local_files_only: bool = False
+    local_files_only: bool = False,
+    run_compute_stats: bool = True,
 ) -> LeRobotDataset:
-
-    # Load pretrained policy
+    # Load pretrained policy
     policy = None
     if pretrained_policy_name_or_path is not None:
         policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
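
For context, a call to the reworked `record()` might look like the sketch below. `env` and `robot` construction is elided, the argument values are illustrative, and parameters hidden in the collapsed part of the hunk (e.g. `num_episodes`) are assumed to exist:

dataset = record(
    env,
    robot,
    root=Path("data"),
    repo_id="user/sim_dataset",            # illustrative repo id
    task="Pick up the cube and lift it",   # required task description
    process_action_from_leader=None,       # optional since this commit
    fps=30,
    num_episodes=5,                        # assumed hidden parameter
    run_compute_stats=True,                # new in this commit
)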
@@ -267,11 +228,11 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     num_cameras = sum([1 if "image" in key else 0 for key in env.observation_space])
-    features = DEFAULT_FEATURES
 
+    # get image keys
     image_keys = [key for key in env.observation_space if "image" in key]
-    state_keys_dict = env_cfg.get("state_keys", {})
     # non_image_keys = [key for key in env.observation_space if "image" not in key]
+    state_keys_dict = env_cfg.state_keys
 
     if resume:
         dataset = LeRobotDataset(
@@ -281,42 +242,43 @@
         )
         dataset.start_image_writer(
             num_processes=num_image_writer_processes,
-            num_threads=num_image_writer_threads_per_camera * len(num_cameras),
+            num_threads=num_image_writer_threads_per_camera * num_cameras,
         )
         sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
     else:
+        features = DEFAULT_FEATURES
+        # add image keys to features
+        for key in image_keys:
+            shape = env.observation_space[key].shape
+            if not key.startswith("observation.image."):
+                key = "observation.image." + key
+            features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
+
+        for key, obs_key in state_keys_dict.items():
+            features[key] = {
+                "dtype": "float32",
+                "names": None,
+                "shape": env.observation_space[obs_key].shape,
+            }
+
+        features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
+
-        # Create empty dataset or load existing saved episodes
         sanity_check_dataset_name(repo_id, policy)
         dataset = LeRobotDataset.create(
             repo_id,
             fps,
             root=root,
+            robot=robot,
             features=features,
             use_videos=video,
+            tolerance_s=1e-1,
             image_writer_processes=num_image_writer_processes,
-            image_writer_threads=num_image_writer_threads_per_camera * len(num_cameras),
+            image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
         )
 
-        # add image keys to features
-        for key in image_keys:
-            shape = env.observation_space[key].shape
-            if not key.startswith("observation.image"):
-                key = "observation.image" + key
-            features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
-
-        for key in state_keys_dict:
-            features[key] = {
-                "dtype": "float32",
-                "names": None,
-                "shape": env.observation_space[state_keys_dict[key]],
-            }
-
+    recorded_episodes = 0
     while True:
-        if dataset["num_episodes"] >= num_episodes:
-            break
-
-        episode_index = dataset["num_episodes"]
-        log_say(f"Recording episode {episode_index}", play_sounds)
+        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
 
         if events is None:
             events = {"exit_early": False}
@@ -341,27 +303,31 @@
 
             observation, reward, terminated, _, info = env.step(action)
 
-            action = {"action": torch.from_numpy(action)}
             success = info.get("is_success", False)
+            env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)
 
-            if dataset is not None:
-                frame = {
-                    "action": torch.from_numpy(action),
-                    "next.reward": reward,
-                    "next.success": success,
-                    "seed": seed,
-                }
+            frame = {
+                "action": torch.from_numpy(action) * 180 / np.pi,
+                "next.reward": reward,
+                "next.success": success,
+                "seed": seed,
+                "timestamp": env_timestamp,
+            }
 
-                for key in image_keys:
-                    if not key.startswith("observation.image"):
-                        frame["observation.images." + key] = observation[key]
-                for key, obs_key in state_keys_dict.items():
-                    frame[key] = torch.from_numpy(observation[obs_key])
-                dataset.add_frame(frame)
+            for key in image_keys:
+                if not key.startswith("observation.image"):
+                    frame["observation.image." + key] = observation[key]
+                else:
+                    frame[key] = observation[key]
+
+            for key, obs_key in state_keys_dict.items():
+                frame[key] = torch.from_numpy(observation[obs_key])
+
+            dataset.add_frame(frame)
 
             if display_cameras and not is_headless():
                 for key in image_keys:
-                    cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+                    cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
                 cv2.waitKey(1)
 
             if fps is not None:
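
Note that the new `"action"` entry in the frame also rescales by `180 / np.pi`, so actions are stored in degrees rather than radians; for instance:

import numpy as np

action_rad = np.array([np.pi / 2, -np.pi])    # joint commands in radians
action_deg = action_rad * 180 / np.pi         # stored values: [90.0, -180.0] degrees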
@@ -384,8 +350,9 @@
                 continue
 
         dataset.save_episode(task=task)
+        recorded_episodes += 1
 
-        if events["stop_recording"]:
+        if events["stop_recording"] or recorded_episodes >= num_episodes:
             break
         else:
             logging.info("Waiting for a few seconds before starting next episode recording...")
@@ -394,8 +361,9 @@
 
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)
 
-    logging.info("Computing dataset statistics")
-    dataset.consolidate(run_compute_stats=True)
+    if run_compute_stats:
+        logging.info("Computing dataset statistics")
+    dataset.consolidate(run_compute_stats)
 
     if push_to_hub:
         dataset.push_to_hub(tags=tags)
@@ -477,6 +445,12 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
         default=60,
         help="Number of seconds for data recording for each episode.",
     )
+    parser_record.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        help="A description of the task performed during recording that can be used as a language instruction.",
+    )
     parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
     parser_record.add_argument(
         "--run-compute-stats",
@@ -496,11 +470,6 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
         nargs="*",
         help="Add tags to your dataset on the hub.",
     )
-    parser_record.add_argument(
-        "--task",
-        type=str,
-        help="A description of the task performed during recording that can be used as a language instruction.",
-    )
     parser_record.add_argument(
         "--num-image-writer-processes",
         type=int,
@@ -513,7 +482,7 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
         ),
     )
     parser_record.add_argument(
-        "--num-image-writers-per-camera",
+        "--num-image-writer-threads-per-camera",
         type=int,
         default=4,
         help=(
@@ -528,7 +497,12 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
         default=0,
         help="Visualize image observations with opencv.",
     )
-
+    parser_record.add_argument(
+        "--resume",
+        type=int,
+        default=0,
+        help="Resume recording on an existing dataset.",
+    )
     parser_replay = subparsers.add_parser("replay", parents=[base_parser])
     parser_replay.add_argument(
         "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
@@ -569,6 +543,7 @@ def env_constructor():
         return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)
 
     robot = None
+    process_leader_actions_fn = None
 
     if control_mode in ["teleoperate", "record"]:
         # make robot
