diff --git a/.cache/calibration/aloha_default/left_follower.json b/.cache/calibration/aloha_default/left_follower.json new file mode 100644 index 000000000..8521f4d4f --- /dev/null +++ b/.cache/calibration/aloha_default/left_follower.json @@ -0,0 +1 @@ +{"homing_offset": [2048, 3072, 3072, -1024, -1024, 2048, -2048, 2048, -2048], "drive_mode": [1, 1, 1, 0, 0, 1, 0, 1, 0], "start_pos": [2015, 3058, 3061, 1071, 1071, 2035, 2152, 2029, 2499], "end_pos": [-1008, -1963, -1966, 2141, 2143, -971, 3043, -1077, 3144], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], "motor_names": ["waist", "shoulder", "shoulder_shadow", "elbow", "elbow_shadow", "forearm_roll", "wrist_angle", "wrist_rotate", "gripper"]} diff --git a/.cache/calibration/aloha_default/left_leader.json b/.cache/calibration/aloha_default/left_leader.json new file mode 100644 index 000000000..9599dfbb2 --- /dev/null +++ b/.cache/calibration/aloha_default/left_leader.json @@ -0,0 +1 @@ +{"homing_offset": [2048, 3072, 3072, -1024, -1024, 2048, -2048, 2048, -1024], "drive_mode": [1, 1, 1, 0, 0, 1, 0, 1, 0], "start_pos": [2035, 3024, 3019, 979, 981, 1982, 2166, 2124, 1968], "end_pos": [-990, -2017, -2015, 2078, 2076, -1030, 3117, -1016, 2556], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], "motor_names": ["waist", "shoulder", "shoulder_shadow", "elbow", "elbow_shadow", "forearm_roll", "wrist_angle", "wrist_rotate", "gripper"]} diff --git a/.cache/calibration/aloha_default/right_follower.json b/.cache/calibration/aloha_default/right_follower.json new file mode 100644 index 000000000..f460ef4b9 --- /dev/null +++ b/.cache/calibration/aloha_default/right_follower.json @@ -0,0 +1 @@ +{"homing_offset": [2048, 3072, 3072, -1024, -1024, 2048, -2048, 2048, -2048], "drive_mode": [1, 1, 1, 0, 0, 1, 0, 1, 0], "start_pos": [2056, 2895, 2896, 1191, 1190, 2018, 2051, 2056, 2509], "end_pos": [-1040, -2004, -2006, 2126, 2127, -1010, 3050, -1117, 3143], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], "motor_names": ["waist", "shoulder", "shoulder_shadow", "elbow", "elbow_shadow", "forearm_roll", "wrist_angle", "wrist_rotate", "gripper"]} diff --git a/.cache/calibration/aloha_default/right_leader.json b/.cache/calibration/aloha_default/right_leader.json new file mode 100644 index 000000000..2c41c40c0 --- /dev/null +++ b/.cache/calibration/aloha_default/right_leader.json @@ -0,0 +1 @@ +{"homing_offset": [2048, 3072, 3072, -1024, -1024, 2048, -2048, 2048, -2048], "drive_mode": [1, 1, 1, 0, 0, 1, 0, 1, 0], "start_pos": [2068, 3034, 3030, 1038, 1041, 1991, 1948, 2090, 1985], "end_pos": [-1025, -2014, -2015, 2058, 2060, -955, 3091, -940, 2576], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], "motor_names": ["waist", "shoulder", "shoulder_shadow", "elbow", "elbow_shadow", "forearm_roll", "wrist_angle", "wrist_rotate", "gripper"]} diff --git a/.gitattributes b/.gitattributes index f12e709c4..f3e1f1164 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ +.cache/calibration/aloha_default/*.json -filter -diff -merge text *.memmap filter=lfs diff=lfs merge=lfs -text *.stl filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile index 7ad3bf6e8..34f5361a8 100644 --- a/docker/lerobot-cpu/Dockerfile +++ b/docker/lerobot-cpu/Dockerfile @@ -22,7 +22,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc COPY . /lerobot WORKDIR /lerobot RUN pip install --upgrade --no-cache-dir pip -RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" \ +RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \ --extra-index-url https://download.pytorch.org/whl/cpu # Set EGL as the rendering backend for MuJoCo diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile index 4d6f772bf..92640cf4b 100644 --- a/docker/lerobot-gpu/Dockerfile +++ b/docker/lerobot-gpu/Dockerfile @@ -24,7 +24,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc COPY . /lerobot WORKDIR /lerobot RUN pip install --upgrade --no-cache-dir pip -RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" +RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" # Set EGL as the rendering backend for MuJoCo ENV MUJOCO_GL="egl" diff --git a/examples/7_get_started_with_real_robot.md b/examples/7_get_started_with_real_robot.md index f738ec29a..a15c5c5ae 100644 --- a/examples/7_get_started_with_real_robot.md +++ b/examples/7_get_started_with_real_robot.md @@ -11,7 +11,7 @@ This tutorial will guide you through the process of setting up and training a ne By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934). -Although this tutorial is general and can be easily adapted to various types of robots by changing the configuration, it is specifically based on the [Koch v1.1](https://github.com/jess-moss/koch-v1-1), an affordable robot. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot. +This tutorial is specifically made for the affordable [Koch v1.1](https://github.com/jess-moss/koch-v1-1) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](aloha-2.github.io) by changing some configurations. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot. During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously. @@ -29,16 +29,16 @@ For a visual walkthrough of the assembly process, you can refer to [this video t ## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1 -First, install the additional dependencies required for Koch v1.1 by running one of the following commands. +First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands. Using `pip`: ```bash -pip install -e ".[koch]" +pip install -e ".[dynamixel]" ``` Or using `poetry`: ```bash -poetry install --sync --extras "koch" +poetry install --sync --extras "dynamixel" ``` You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V. @@ -147,6 +147,7 @@ follower_arm = DynamixelMotorsBus( Next, update the port values in the YAML configuration file for the Koch robot at [`lerobot/configs/robot/koch.yaml`](../lerobot/configs/robot/koch.yaml) with the ports you've identified: ```yaml [...] +robot_type: koch leader_arms: main: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus @@ -174,6 +175,8 @@ follower_arms: [...] ``` +Don't forget to set `robot_type: aloha` if you follow this tutorial with [Aloha bimanual robot](aloha-2.github.io) instead of Koch v1.1 + This configuration file is used to instantiate your robot across all scripts. We'll cover how this works later on. **Connect and Configure your Motors** @@ -298,32 +301,37 @@ Alternatively, you can unplug the power cord, which will automatically disable t */!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use. -### b. Teleoperate your Koch v1.1 with KochRobot +### b. Teleoperate your Koch v1.1 with ManipulatorRobot -**Instantiate the KochRobot** +**Instantiate the ManipulatorRobot** -Before you can teleoperate your robot, you need to instantiate the [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) using the previously defined `leader_arm` and `follower_arm`. +Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_arm` and `follower_arm`. -For the Koch robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_arm}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_arm, "right": right_leader_arm},`. Same thing for the follower arms. +For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_arm}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_arm, "right": right_leader_arm},`. Same thing for the follower arms. -You also need to provide a path to a calibration file, such as `calibration_path=".cache/calibration/koch.pkl"`. More on this in the next section. +You also need to provide a path to a calibration directory, such as `calibration_dir=".cache/calibration/koch"`. More on this in the next section. -Run the following code to instantiate your Koch robot: +Run the following code to instantiate your manipulator robot: ```python -from lerobot.common.robot_devices.robots.koch import KochRobot +from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot -robot = KochRobot( +robot = ManipulatorRobot( + robot_type="koch", leader_arms={"main": leader_arm}, follower_arms={"main": follower_arm}, - calibration_path=".cache/calibration/koch.pkl", + calibration_dir=".cache/calibration/koch", ) ``` -**Calibrate and Connect the KochRobot** +The `robot_type="koch"` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger. + +For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `robot_type="aloha"` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected. If you need to run manual calibration, simply update `calibration_dir` to `.cache/calibration/aloha`. + +**Calibrate and Connect the ManipulatorRobot** -Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another. +Next, you'll need to calibrate your Koch robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another. -When you connect your robot for the first time, the [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions. +When you connect your robot for the first time, the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions. Here are the positions you'll move the follower arm to: @@ -354,27 +362,26 @@ The output will look like this: ``` Connecting main follower arm Connecting main leader arm -Missing calibration file '.cache/calibration/koch.pkl'. Starting calibration procedure. - -Running calibration of main follower... +Missing calibration file '.cache/calibration/koch/main_follower.json' +Running calibration of koch main follower... Move arm to zero position [...] Move arm to rotated position [...] Move arm to rest position [...] +Calibration is done! Saving calibration file '.cache/calibration/koch/main_follower.json' -Running calibration of main leader... - +Missing calibration file '.cache/calibration/koch/main_leader.json' +Running calibration of koch main leader... Move arm to zero position [...] Move arm to rotated position [...] Move arm to rest position [...] - -Calibration is done! Saving calibration file '.cache/calibration/koch.pkl' +Calibration is done! Saving calibration file '.cache/calibration/koch/main_leader.json' ``` *Verifying Calibration* @@ -414,7 +421,7 @@ for _ in tqdm.tqdm(range(seconds*frequency)): *Using `teleop_step` for Teleoperation* -Alternatively, you can teleoperate the robot using the `teleop_step` method from [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py). +Alternatively, you can teleoperate the robot using the `teleop_step` method from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py). Run this code to teleoperate: ```python @@ -607,10 +614,10 @@ Additionaly, you can set up your robot to work with your cameras. Modify the following Python code with the appropriate camera names and configurations: ```python -robot = KochRobot( +robot = ManipulatorRobot( leader_arms={"main": leader_arm}, follower_arms={"main": follower_arm}, - calibration_path=".cache/calibration/koch.pkl", + calibration_dir=".cache/calibration/koch", cameras={ "laptop": OpenCVCamera(0, fps=30, width=640, height=480), "phone": OpenCVCamera(1, fps=30, width=640, height=480), @@ -925,7 +932,7 @@ huggingface-cli upload ${HF_USER}/act_koch_test_${CKPT} \ ## 5. Evaluate your policy -Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`KochRobot`](../lerobot/common/robot_devices/robots/koch.py) and the policy. +Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) and the policy. Try this code for running inference for 60 seconds at 30 fps: ```python diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 65998e8b3..aeae31008 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -27,6 +27,7 @@ print(lerobot.available_real_world_datasets) print(lerobot.available_policies) print(lerobot.available_policies_per_env) + print(lerobot.available_robots) ``` When implementing a new dataset loadable with LeRobotDataset follow these steps: @@ -182,7 +183,7 @@ itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets) ) -# lists all available policies from `lerobot/common/policies` by their class attribute: `name`. +# lists all available policies from `lerobot/common/policies` available_policies = [ "act", "diffusion", @@ -190,6 +191,13 @@ "vqbet", ] +# lists all available robots from `lerobot/common/robot_devices/robots` +available_robots = [ + "koch", + "koch_bimanual", + "aloha", +] + # keys and values refer to yaml files available_policies_per_env = { "aloha": ["act"], diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index 5556dccea..c26de6c50 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -17,9 +17,12 @@ import numpy as np from PIL import Image -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, + busy_wait, +) from lerobot.common.utils.utils import capture_timestamp_utc -from lerobot.scripts.control_robot import busy_wait # Use 1 thread to avoid blocking the main thread. Especially useful during data collection # when other threads are used to save the images. diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py index 687db1746..491963fed 100644 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -1,4 +1,6 @@ import enum +import logging +import math import time import traceback from copy import deepcopy @@ -27,11 +29,28 @@ MAX_ID_RANGE = 252 +# The following bounds define the lower and upper joints range (after calibration). +# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees +# which corresponds to a half rotation on the left and half rotation on the right. +# Some joints might require higher range, so we allow up to [-270, 270] degrees until +# an error is raised. +LOWER_BOUND_DEGREE = -270 +UPPER_BOUND_DEGREE = 270 +# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper), +# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully +# closed, and 100% is fully open. To account for slight calibration issue, we allow up to +# [-10, 110] until an error is raised. +LOWER_BOUND_LINEAR = -10 +UPPER_BOUND_LINEAR = 110 + +HALF_TURN_DEGREE = 180 + # https://emanual.robotis.com/docs/en/dxl/x/xl330-m077 # https://emanual.robotis.com/docs/en/dxl/x/xl330-m288 # https://emanual.robotis.com/docs/en/dxl/x/xl430-w250 # https://emanual.robotis.com/docs/en/dxl/x/xm430-w350 # https://emanual.robotis.com/docs/en/dxl/x/xm540-w270 +# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150 # data_name: (address, size_byte) X_SERIES_CONTROL_TABLE = { @@ -109,6 +128,7 @@ "xl430-w250": X_SERIES_CONTROL_TABLE, "xm430-w350": X_SERIES_CONTROL_TABLE, "xm540-w270": X_SERIES_CONTROL_TABLE, + "xc430-w150": X_SERIES_CONTROL_TABLE, } MODEL_RESOLUTION = { @@ -118,6 +138,7 @@ "xl430-w250": 4096, "xm430-w350": 4096, "xm540-w270": 4096, + "xc430-w150": 4096, } MODEL_BAUDRATE_TABLE = { @@ -127,20 +148,18 @@ "xl430-w250": X_SERIES_BAUDRATE_TABLE, "xm430-w350": X_SERIES_BAUDRATE_TABLE, "xm540-w270": X_SERIES_BAUDRATE_TABLE, + "xc430-w150": X_SERIES_BAUDRATE_TABLE, } NUM_READ_RETRY = 10 NUM_WRITE_RETRY = 10 -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]): - """This function convert the degree range to the step range for indicating motors rotation. - It assums a motor achieves a full rotation by going from -180 degree position to +180. +def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: + """This function converts the degree range to the step range for indicating motors rotation. + It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. """ - if isinstance(degrees, float): - degrees = np.array(degrees) - resolutions = [MODEL_RESOLUTION[model] for model in models] steps = degrees / 180 * np.array(resolutions) / 2 steps = steps.astype(int) @@ -250,20 +269,24 @@ class TorqueMode(enum.Enum): DISABLED = 0 -class OperatingMode(enum.Enum): - VELOCITY = 1 - POSITION = 3 - EXTENDED_POSITION = 4 - CURRENT_CONTROLLED_POSITION = 5 - PWM = 16 - UNKNOWN = -1 - - class DriveMode(enum.Enum): NON_INVERTED = 0 INVERTED = 1 +class CalibrationMode(enum.Enum): + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + DEGREE = 0 + # Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100] + LINEAR = 1 + + +class JointOutOfRangeError(Exception): + def __init__(self, message="Joint is out of range"): + self.message = message + super().__init__(self.message) + + class DynamixelMotorsBus: # TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2 """ @@ -531,9 +554,22 @@ def motor_models(self) -> list[str]: def motor_indices(self) -> list[int]: return [idx for idx, _ in self.motors.values()] - def set_calibration(self, calibration: dict[str, tuple[int, bool]]): + def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration + def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. + + For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. + """ + try: + values = self.apply_calibration(values, motor_names) + except JointOutOfRangeError as e: + print(e) + self.autocorrect_calibration(values, motor_names) + values = self.apply_calibration(values, motor_names) + return values + def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -551,53 +587,197 @@ def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | if motor_names is None: motor_names = self.motor_names - # Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[ - values = values.astype(np.int32) + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range + values = values.astype(np.float32) for i, name in enumerate(motor_names): - homing_offset, drive_mode = self.calibration[name] + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Update direction of rotation of the motor to match between leader and follower. + # In fact, the motor of the leader for a given joint can be assembled in an + # opposite direction in term of rotation than the motor of the follower on the same joint. + if drive_mode: + values[i] *= -1 + + # Convert from range [-2**31, 2**31] to + # nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048]) + values[i] += homing_offset + + # Convert from range [-resolution//2, resolution//2] to + # universal float32 centered degree range [-180, 180] + # (e.g. 2048 / (4096 // 2) * 180 = 180) + values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE + + if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), " + f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, " + f"but present value is {values[i]} degree. " + "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Rescale the present position to a nominal range [0, 100] %, + # useful for joints with linear motions like Aloha gripper + values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100 + + if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [0, 100] % (a full linear translation), " + f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, " + f"but present value is {values[i]} %. " + "This might be due to a cable connection issue creating an artificial jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + return values - # Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint - # can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint. - if drive_mode: - values[i] *= -1 + def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function automatically detects issues with values of motors after calibration, and correct for these issues. - # Convert from range [-2**31, 2**31[ to nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[) - values[i] += homing_offset + Some motors might have values outside of expected maximum bounds after calibration. + For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given + a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position. - # Convert from range ]-resolution, resolution[ to the universal float32 centered degree range ]-180, 180[ + Known issues: + #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors. + #2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha). + #3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration + or by human error during manual calibration. + + Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn. + Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`, + that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue. + + Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. + """ + if motor_names is None: + motor_names = self.motor_names + + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range values = values.astype(np.float32) + for i, name in enumerate(motor_names): - _, model = self.motors[name] - resolution = self.model_resolution[model] - values[i] = values[i] / (resolution // 2) * 180 + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Update direction of rotation of the motor to match between leader and follower. + # In fact, the motor of the leader for a given joint can be assembled in an + # opposite direction in term of rotation than the motor of the follower on the same joint. + if drive_mode: + values[i] *= -1 + + # Convert from initial range to range [-180, 180] degrees + calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + + # Solve this inequality to find the factor to shift the range into [-180, 180] degrees + # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE + # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE + # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution + low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution + upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Convert from initial range to range [0, 100] in % + calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 + in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + + # Solve this inequality to find the factor to shift the range into [0, 100] % + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 + # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100 + # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution + low_factor = (start_pos - values[i]) / resolution + upp_factor = (end_pos - values[i]) / resolution + + if not in_range: + # Get first integer between the two bounds + if low_factor < upp_factor: + factor = math.ceil(low_factor) + + if factor > upp_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + else: + factor = math.ceil(upp_factor) + + if factor > low_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + + logging.warning( + f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " + f"from '{out_of_range_str}' to '{in_range_str}'." + ) - return values + # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. + self.calibration["homing_offset"][calib_idx] += resolution * factor def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names - # Convert from the universal float32 centered degree range ]-180, 180[ to resolution range ]-resolution, resolution[ for i, name in enumerate(motor_names): - _, model = self.motors[name] - resolution = self.model_resolution[model] - values[i] = values[i] / 180 * (resolution // 2) + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] - values = np.round(values).astype(np.int32) + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] - # Convert from nominal range ]-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[ - for i, name in enumerate(motor_names): - homing_offset, drive_mode = self.calibration[name] - values[i] -= homing_offset + # Convert from nominal 0-centered degree range [-180, 180] to + # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) + values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) + + # Substract the homing offsets to come back to actual motor range of values + # which can be arbitrary. + values[i] -= homing_offset + + # Remove drive mode, which is the rotation direction of the motor, to come back to + # actual motor rotation direction which can be arbitrary. + if drive_mode: + values[i] *= -1 + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] - # Update direction of rotation of the motor that was matching between leader and follower to their original direction. - # In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation - # than the motor of the follower on the same joint. - if drive_mode: - values[i] *= -1 + # Convert from nominal lnear range of [0, 100] % to + # actual motor range of values which can be arbitrary. + values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos + values = np.round(values).astype(np.int32) return values def _read_with_motor_ids(self, motor_models, motor_ids, data_name): @@ -683,19 +863,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None): values = values.astype(np.int32) if data_name in CALIBRATION_REQUIRED and self.calibration is not None: - values = self.apply_calibration(values, motor_names) - - # We expect our motors to stay in a nominal range of [-180, 180] degrees - # which corresponds to a half turn rotation. - # However, some motors can turn a bit more, hence we extend the nominal range to [-270, 270] - # which is less than a full 360 degree rotation. - if not np.all((values > -270) & (values < 270)): - raise ValueError( - f"Wrong motor position range detected. " - f"Expected to be in [-270, +270] but in [{values.min()}, {values.max()}]. " - "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " - "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" - ) + values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) diff --git a/lerobot/common/robot_devices/robots/koch.py b/lerobot/common/robot_devices/robots/manipulator.py similarity index 53% rename from lerobot/common/robot_devices/robots/koch.py rename to lerobot/common/robot_devices/robots/manipulator.py index 9d5858656..c4b2c431c 100644 --- a/lerobot/common/robot_devices/robots/koch.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -1,6 +1,7 @@ +import json import logging -import pickle import time +import warnings from dataclasses import dataclass, field, replace from pathlib import Path from typing import Sequence @@ -10,11 +11,12 @@ from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.motors.dynamixel import ( - OperatingMode, + CalibrationMode, TorqueMode, convert_degrees_to_steps, ) from lerobot.common.robot_devices.motors.utils import MotorsBus +from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError ######################################################################## @@ -25,7 +27,8 @@ "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" ) -# In nominal degree range ]-180, +180[ +# The following positions are provided in nominal degree range ]-180, +180[ +# For more info on these constants, see comments in the code where they get used. ZERO_POSITION_DEGREE = 0 ROTATED_POSITION_DEGREE = 90 @@ -45,27 +48,13 @@ def apply_drive_mode(position, drive_mode): return position -def reset_torque_mode(arm: MotorsBus): - # To be configured, all servos must be in "torque disable" mode - arm.write("Torque_Enable", TorqueMode.DISABLED.value) +def compute_nearest_rounded_position(position, models): + delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models) + nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn + return nearest_pos.astype(position.dtype) - # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't - # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, - # you could end up with a servo with a position 0 or 4095 at a crucial point See [ - # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] - if len(all_motors_except_gripper) > 0: - arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper) - # Use 'position control current based' for gripper to be limited by the limit of the current. - # For the follower gripper, it means it can grasp an object without forcing too much even tho, - # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). - # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger - # to make it move, and it will move back to its original target position when we release the force. - arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper") - - -def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str): +def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): """This function ensures that a neural network trained on data collected on a given robot can work on another robot. For instance before calibration, setting a same goal position for each motor of two different robots will get two very different positions. But after calibration, @@ -84,38 +73,27 @@ def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str): Example of usage: ```python - run_arm_calibration(arm, "left", "follower") + run_arm_calibration(arm, "koch", "left", "follower") ``` """ - reset_torque_mode(arm) + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run calibration, the torque must be disabled on all motors.") - print(f"\nRunning calibration of {name} {arm_type}...") + print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="zero")) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) input("Press Enter to continue...") - # We arbitrarely choosed our zero target position to be a straight horizontal position with gripper upwards and closed. + # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will - # corresponds to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position. - zero_position = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models) - - def _compute_nearest_rounded_position(position, models): - # TODO(rcadene): Rework this function since some motors cant physically rotate a quarter turn - # (e.g. the gripper of Aloha arms can only rotate ~50 degree) - quarter_turn_degree = 90 - quarter_turn = convert_degrees_to_steps(quarter_turn_degree, models) - nearest_pos = np.round(position.astype(float) / quarter_turn) * quarter_turn - return nearest_pos.astype(position.dtype) + # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position. + zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models) # Compute homing offset so that `present_position + homing_offset ~= target_position`. - position = arm.read("Present_Position") - position = _compute_nearest_rounded_position(position, arm.motor_models) - homing_offset = zero_position - position - - print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rotated")) - input("Press Enter to continue...") + zero_pos = arm.read("Present_Position") + zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models) + homing_offset = zero_target_pos - zero_nearest_pos # The rotated target position corresponds to a rotation of a quarter turn from the zero position. # This allows to identify the rotation direction of each motor. @@ -124,44 +102,83 @@ def _compute_nearest_rounded_position(position, models): # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view # of the previous motor in the kinetic chain. - rotated_position = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + print("\nMove arm to rotated target position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + input("Press Enter to continue...") + + rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). - position = arm.read("Present_Position") - position += homing_offset - position = _compute_nearest_rounded_position(position, arm.motor_models) - drive_mode = (position != rotated_position).astype(np.int32) + rotated_pos = arm.read("Present_Position") + drive_mode = (rotated_pos < zero_pos).astype(np.int32) # Re-compute homing offset to take into account drive mode - position = arm.read("Present_Position") - position = apply_drive_mode(position, drive_mode) - position = _compute_nearest_rounded_position(position, arm.motor_models) - homing_offset = rotated_position - position + rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) + rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) + homing_offset = rotated_target_pos - rotated_nearest_pos print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rest")) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) input("Press Enter to continue...") print() - return homing_offset, drive_mode + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names) + + # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml? + if robot_type == "aloha" and "gripper" in arm.motor_names: + # Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100] + calib_idx = arm.motor_names.index("gripper") + calib_mode[calib_idx] = CalibrationMode.LINEAR.name + + calib_data = { + "homing_offset": homing_offset.tolist(), + "drive_mode": drive_mode.tolist(), + "start_pos": zero_pos.tolist(), + "end_pos": rotated_pos.tolist(), + "calib_mode": calib_mode, + "motor_names": arm.motor_names, + } + return calib_data + + +def ensure_safe_goal_position( + goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] +): + # Cap relative action target magnitude for safety. + diff = goal_pos - present_pos + max_relative_target = torch.tensor(max_relative_target) + safe_diff = torch.minimum(diff, max_relative_target) + safe_diff = torch.maximum(safe_diff, -max_relative_target) + safe_goal_pos = present_pos + safe_diff + + if not torch.allclose(goal_pos, safe_goal_pos): + logging.warning( + "Relative goal position magnitude had to be clamped to be safe.\n" + f" requested relative goal position target: {diff}\n" + f" clamped relative goal position target: {safe_diff}" + ) + + return safe_goal_pos ######################################################################## -# Alexander Koch robot arm +# Manipulator robot ######################################################################## @dataclass -class KochRobotConfig: +class ManipulatorRobotConfig: """ Example of usage: ```python - KochRobotConfig() + ManipulatorRobotConfig() ``` """ # Define all components of the robot + robot_type: str | None = None leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {}) follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {}) cameras: dict[str, Camera] = field(default_factory=lambda: {}) @@ -191,14 +208,15 @@ def __setattr__(self, prop: str, val): super().__setattr__(prop, val) -class KochRobot: +class ManipulatorRobot: # TODO(rcadene): Implement force feedback - """This class allows to control any Koch robot of various number of motors. + """This class allows to control any manipulator robot of various number of motors. - A few versions are available: - - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, which was developed - by Alexander Koch from [Tau Robotics](https://tau-robotics.com): [Github for sourcing and assembly]( - - [Koch v1.1])https://github.com/jess-moss/koch-v1-1), which was developed by Jess Moss. + Non exaustive list of robots: + - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed + by Alexander Koch from [Tau Robotics](https://tau-robotics.com) + - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss + - [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics Example of highest frequency teleoperation without camera: ```python @@ -231,7 +249,9 @@ class KochRobot: }, ), } - robot = KochRobot( + robot = ManipulatorRobot( + robot_type="koch", + calibration_dir=".cache/calibration/koch", leader_arms=leader_arms, follower_arms=follower_arms, ) @@ -246,7 +266,9 @@ class KochRobot: Example of highest frequency data collection without camera: ```python # Assumes leader and follower arms have been instantiated already (see first example) - robot = KochRobot( + robot = ManipulatorRobot( + robot_type="koch", + calibration_dir=".cache/calibration/koch", leader_arms=leader_arms, follower_arms=follower_arms, ) @@ -267,7 +289,9 @@ class KochRobot: } # Assumes leader and follower arms have been instantiated already (see first example) - robot = KochRobot( + robot = ManipulatorRobot( + robot_type="koch", + calibration_dir=".cache/calibration/koch", leader_arms=leader_arms, follower_arms=follower_arms, cameras=cameras, @@ -280,7 +304,9 @@ class KochRobot: Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency): ```python # Assumes leader and follower arms + cameras have been instantiated already (see previous example) - robot = KochRobot( + robot = ManipulatorRobot( + robot_type="koch", + calibration_dir=".cache/calibration/koch", leader_arms=leader_arms, follower_arms=follower_arms, cameras=cameras, @@ -306,16 +332,17 @@ class KochRobot: def __init__( self, - config: KochRobotConfig | None = None, - calibration_path: Path = ".cache/calibration/koch.pkl", + config: ManipulatorRobotConfig | None = None, + calibration_dir: Path = ".cache/calibration/koch", **kwargs, ): if config is None: - config = KochRobotConfig() + config = ManipulatorRobotConfig() # Overwrite config arguments using kwargs self.config = replace(config, **kwargs) - self.calibration_path = Path(calibration_path) + self.calibration_dir = Path(calibration_dir) + self.robot_type = self.config.robot_type self.leader_arms = self.config.leader_arms self.follower_arms = self.config.follower_arms self.cameras = self.config.cameras @@ -325,12 +352,12 @@ def __init__( def connect(self): if self.is_connected: raise RobotDeviceAlreadyConnectedError( - "KochRobot is already connected. Do not run `robot.connect()` twice." + "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." ) if not self.leader_arms and not self.follower_arms and not self.cameras: raise ValueError( - "KochRobot doesn't have any device to connect. See example of usage in docstring of the class." + "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." ) # Connect the arms @@ -340,38 +367,22 @@ def connect(self): print(f"Connecting {name} leader arm.") self.leader_arms[name].connect() - # Reset the arms and load or run calibration - if self.calibration_path.exists(): - # Reset all arms before setting calibration - for name in self.follower_arms: - reset_torque_mode(self.follower_arms[name]) - for name in self.leader_arms: - reset_torque_mode(self.leader_arms[name]) - - with open(self.calibration_path, "rb") as f: - calibration = pickle.load(f) - else: - print(f"Missing calibration file '{self.calibration_path}'. Starting calibration precedure.") - # Run calibration process which begins by reseting all arms - calibration = self.run_calibration() - - print(f"Calibration is done! Saving calibration file '{self.calibration_path}'") - self.calibration_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.calibration_path, "wb") as f: - pickle.dump(calibration, f) - - # Set calibration + # We assume that at connection time, arms are in a rest position, and torque can + # be safely disabled to run calibration and/or set robot preset configurations. for name in self.follower_arms: - self.follower_arms[name].set_calibration(calibration[f"follower_{name}"]) + self.follower_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) for name in self.leader_arms: - self.leader_arms[name].set_calibration(calibration[f"leader_{name}"]) + self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) - # Set better PID values to close the gap between recored states and actions - # TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor - for name in self.follower_arms: - self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex") - self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex") - self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex") + self.activate_calibration() + + # Set robot preset (e.g. torque in leader gripper for Koch v1.1) + if self.robot_type == "koch": + self.set_koch_robot_preset() + elif self.robot_type == "aloha": + self.set_aloha_robot_preset() + else: + warnings.warn(f"No preset found for robot type: {self.robot_type}", stacklevel=1) # Enable torque on all motors of the follower arms for name in self.follower_arms: @@ -391,31 +402,121 @@ def connect(self): self.is_connected = True - def run_calibration(self): - calibration = {} + def activate_calibration(self): + """After calibration all motors function in human interpretable ranges. + Rotations are expressed in degrees in nominal range of [-180, 180], + and linear motions (like gripper of Aloha) in nominal range of [0, 100]. + """ + + def load_or_run_calibration_(name, arm, arm_type): + arm_id = get_arm_id(name, arm_type) + arm_calib_path = self.calibration_dir / f"{arm_id}.json" + + if arm_calib_path.exists(): + with open(arm_calib_path) as f: + calibration = json.load(f) + else: + print(f"Missing calibration file '{arm_calib_path}'") + calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) + + print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") + arm_calib_path.parent.mkdir(parents=True, exist_ok=True) + with open(arm_calib_path, "w") as f: + json.dump(calibration, f) + + return calibration + + for name, arm in self.follower_arms.items(): + calibration = load_or_run_calibration_(name, arm, "follower") + arm.set_calibration(calibration) + for name, arm in self.leader_arms.items(): + calibration = load_or_run_calibration_(name, arm, "leader") + arm.set_calibration(calibration) + + def set_koch_robot_preset(self): + def set_operating_mode_(arm): + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run set robot preset, the torque must be disabled on all motors.") + + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't + # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, + # you could end up with a servo with a position 0 or 4095 at a crucial point See [ + # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] + all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] + if len(all_motors_except_gripper) > 0: + # 4 corresponds to Extended Position on Koch motors + arm.write("Operating_Mode", 4, all_motors_except_gripper) + + # Use 'position control current based' for gripper to be limited by the limit of the current. + # For the follower gripper, it means it can grasp an object without forcing too much even tho, + # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger + # to make it move, and it will move back to its original target position when we release the force. + # 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288" + arm.write("Operating_Mode", 5, "gripper") for name in self.follower_arms: - homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], name, "follower") + set_operating_mode_(self.follower_arms[name]) + + # Set better PID values to close the gap between recorded states and actions + # TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor + self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex") + self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex") + self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex") - calibration[f"follower_{name}"] = {} - for idx, motor_name in enumerate(self.follower_arms[name].motor_names): - calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx]) + if self.config.gripper_open_degree is not None: + for name in self.leader_arms: + set_operating_mode_(self.leader_arms[name]) + + # Enable torque on the gripper of the leader arms, and move it to 45 degrees, + # so that we can use it as a trigger to close the gripper of the follower arms. + self.leader_arms[name].write("Torque_Enable", 1, "gripper") + self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + + def set_aloha_robot_preset(self): + def set_shadow_(arm): + # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. + # As a result, if only one of them is required to move to a certain position, + # the other will follow. This is to avoid breaking the motors. + if "shoulder_shadow" in arm.motor_names: + shoulder_idx = arm.read("ID", "shoulder") + arm.write("Secondary_ID", shoulder_idx, "shoulder_shadow") + + if "elbow_shadow" in arm.motor_names: + elbow_idx = arm.read("ID", "elbow") + arm.write("Secondary_ID", elbow_idx, "elbow_shadow") + + for name in self.follower_arms: + set_shadow_(self.follower_arms[name]) for name in self.leader_arms: - homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], name, "leader") + set_shadow_(self.leader_arms[name]) + + for name in self.follower_arms: + # Set a velocity limit of 131 as advised by Trossen Robotics + self.follower_arms[name].write("Velocity_Limit", 131) + + # Use 'position control current based' for follower gripper to be limited by the limit of the current. + # It can grasp an object without forcing too much even tho, + # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # 5 corresponds to Current Controlled Position on Aloha gripper follower "xm430-w350" + self.follower_arms[name].write("Operating_Mode", 5, "gripper") - calibration[f"leader_{name}"] = {} - for idx, motor_name in enumerate(self.leader_arms[name].motor_names): - calibration[f"leader_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx]) + # Note: We can't enable torque on the leader gripper since "xc430-w150" doesn't have + # a Current Controlled Position mode. - return calibration + if self.config.gripper_open_degree is not None: + warnings.warn( + f"`gripper_open_degree` is set to {self.config.gripper_open_degree}, but None is expected for Aloha instead", + stacklevel=1, + ) def teleop_step( self, record_data=False ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: if not self.is_connected: raise RobotDeviceNotConnectedError( - "KochRobot is not connected. You need to run `robot.connect()`." + "ManipulatorRobot is not connected. You need to run `robot.connect()`." ) # Prepare to assign the position of the leader to the follower @@ -423,16 +524,27 @@ def teleop_step( for name in self.leader_arms: before_lread_t = time.perf_counter() leader_pos[name] = self.leader_arms[name].read("Present_Position") + leader_pos[name] = torch.from_numpy(leader_pos[name]) self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t + # Send goal position to the follower follower_goal_pos = {} - for name in self.leader_arms: - follower_goal_pos[name] = leader_pos[name] - - # Send action for name in self.follower_arms: before_fwrite_t = time.perf_counter() - self.send_action(torch.tensor(follower_goal_pos[name]), [name]) + goal_pos = leader_pos[name] + + # Cap goal position when too far away from present position. + # Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.follower_arms[name].read("Present_Position") + present_pos = torch.from_numpy(present_pos) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + + # Used when record_data=True + follower_goal_pos[name] = goal_pos + + goal_pos = goal_pos.numpy().astype(np.int32) + self.follower_arms[name].write("Goal_Position", goal_pos) self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t # Early exit when recording data is not requested @@ -445,6 +557,7 @@ def teleop_step( for name in self.follower_arms: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") + follower_pos[name] = torch.from_numpy(follower_pos[name]) self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t # Create state by concatenating follower current position @@ -452,29 +565,30 @@ def teleop_step( for name in self.follower_arms: if name in follower_pos: state.append(follower_pos[name]) - state = np.concatenate(state) + state = torch.cat(state) # Create action by concatenating follower goal position action = [] for name in self.follower_arms: if name in follower_goal_pos: action.append(follower_goal_pos[name]) - action = np.concatenate(action) + action = torch.cat(action) # Capture images from cameras images = {} for name in self.cameras: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - # Populate output dictionnaries and format to pytorch + # Populate output dictionnaries obs_dict, action_dict = {}, {} - obs_dict["observation.state"] = torch.from_numpy(state) - action_dict["action"] = torch.from_numpy(action) + obs_dict["observation.state"] = state + action_dict["action"] = action for name in self.cameras: - obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name]) + obs_dict[f"observation.images.{name}"] = images[name] return obs_dict, action_dict @@ -482,7 +596,7 @@ def capture_observation(self): """The returned observations do not have a batch dimension.""" if not self.is_connected: raise RobotDeviceNotConnectedError( - "KochRobot is not connected. You need to run `robot.connect()`." + "ManipulatorRobot is not connected. You need to run `robot.connect()`." ) # Read follower position @@ -490,6 +604,7 @@ def capture_observation(self): for name in self.follower_arms: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") + follower_pos[name] = torch.from_numpy(follower_pos[name]) self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t # Create state by concatenating follower current position @@ -497,82 +612,68 @@ def capture_observation(self): for name in self.follower_arms: if name in follower_pos: state.append(follower_pos[name]) - state = np.concatenate(state) + state = torch.cat(state) # Capture images from cameras images = {} for name in self.cameras: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t # Populate output dictionnaries and format to pytorch obs_dict = {} - obs_dict["observation.state"] = torch.from_numpy(state) + obs_dict["observation.state"] = state for name in self.cameras: - obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name]) + obs_dict[f"observation.images.{name}"] = images[name] return obs_dict - def send_action(self, action: torch.Tensor, follower_names: list[str] | None = None): + def send_action(self, action: torch.Tensor) -> torch.Tensor: """Command the follower arms to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter - `max_relative_target`. + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. Args: - action: tensor containing the concatenated joint positions for the follower arms. - follower_names: Pass follower arm names to only control a subset of all the follower arms. + action: tensor containing the concatenated goal positions for the follower arms. """ if not self.is_connected: raise RobotDeviceNotConnectedError( - "KochRobot is not connected. You need to run `robot.connect()`." - ) - - if follower_names is None: - follower_names = list(self.follower_arms) - elif not set(follower_names).issubset(self.follower_arms): - raise ValueError( - f"You provided {follower_names=} but only the following arms are registered: " - f"{list(self.follower_arms)}" + "ManipulatorRobot is not connected. You need to run `robot.connect()`." ) from_idx = 0 to_idx = 0 - follower_goal_pos = {} - for name in follower_names: + action_sent = [] + for name in self.follower_arms: + # Get goal position of each follower arm by splitting the action vector to_idx += len(self.follower_arms[name].motor_names) - this_action = action[from_idx:to_idx] + goal_pos = action[from_idx:to_idx] + from_idx = to_idx + # Cap goal position when too far away from present position. + # Slower fps expected due to reading from the follower. if self.config.max_relative_target is not None: - if not isinstance(self.config.max_relative_target, list): - max_relative_target = [self.config.max_relative_target for _ in range(from_idx, to_idx)] - max_relative_target = torch.tensor(self.config.max_relative_target) - # Cap relative action target magnitude for safety. - current_pos = torch.tensor(self.follower_arms[name].read("Present_Position")) - diff = this_action - current_pos - safe_diff = torch.minimum(diff, max_relative_target) - safe_diff = torch.maximum(safe_diff, -max_relative_target) - safe_action = current_pos + safe_diff - if not torch.allclose(safe_action, this_action): - logging.warning( - "Relative action magnitude had to be clamped to be safe.\n" - f" requested relative action target: {diff}\n" - f" clamped relative action target: {safe_diff}" - ) - follower_goal_pos[name] = safe_action.numpy() - else: - follower_goal_pos[name] = this_action.numpy() + present_pos = self.follower_arms[name].read("Present_Position") + present_pos = torch.from_numpy(present_pos) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) - from_idx = to_idx + # Save tensor to concat and return + action_sent.append(goal_pos) - for name in self.follower_arms: - self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32)) + # Send goal position to each follower + goal_pos = goal_pos.numpy().astype(np.int32) + self.follower_arms[name].write("Goal_Position", goal_pos) + + return torch.cat(action_sent) def disconnect(self): if not self.is_connected: raise RobotDeviceNotConnectedError( - "KochRobot is not connected. You need to run `robot.connect()` before disconnecting." + "ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting." ) for name in self.follower_arms: diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py index 0262b307e..122155f78 100644 --- a/lerobot/common/robot_devices/robots/utils.py +++ b/lerobot/common/robot_devices/robots/utils.py @@ -1,6 +1,13 @@ from typing import Protocol +def get_arm_id(name, arm_type): + """Returns the string identifier of a robot arm. For instance, for a bimanual manipulator + like Aloha, it could be left_follower, right_follower, left_leader, or right_leader. + """ + return f"{name}_{arm_type}" + + class Robot(Protocol): def init_teleop(self): ... def run_calibration(self): ... diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 792916730..79724af96 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -1,3 +1,15 @@ +import time + + +def busy_wait(seconds): + # Significantly more accurate than `time.sleep`, and mandatory for our use case, + # but it consumes CPU cycles. + # TODO(rcadene): find an alternative: from python 11, time.sleep is precise + end_time = time.perf_counter() + seconds + while time.perf_counter() < end_time: + pass + + class RobotDeviceNotConnectedError(Exception): """Exception raised when the robot device is not connected.""" diff --git a/lerobot/configs/robot/aloha.yaml b/lerobot/configs/robot/aloha.yaml new file mode 100644 index 000000000..d8366acdc --- /dev/null +++ b/lerobot/configs/robot/aloha.yaml @@ -0,0 +1,107 @@ +_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot +robot_type: aloha +# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been +# properly assembled, no manual calibration step is expected. If you need to run manual calibration, +# simply update this path to ".cache/calibration/aloha" +calibration_dir: .cache/calibration/aloha_default + +# /!\ FOR SAFETY, READ THIS /!\ +# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. +# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as +# the number of motors in your follower arms. +# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default. +# When you feel more confident with teleoperation or running the policy, you can extend +# this safety limit and even removing it by setting it to `null`. +# Also, everything is expected to work safely out-of-the-box, but we highly advise to +# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml), +# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully +max_relative_target: 5 + +leader_arms: + left: + _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus + port: /dev/ttyDXL_leader_left + motors: # window_x + # name: (index, model) + waist: [1, xm430-w350] + shoulder: [2, xm430-w350] + shoulder_shadow: [3, xm430-w350] + elbow: [4, xm430-w350] + elbow_shadow: [5, xm430-w350] + forearm_roll: [6, xm430-w350] + wrist_angle: [7, xm430-w350] + wrist_rotate: [8, xl430-w250] + gripper: [9, xc430-w150] + right: + _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus + port: /dev/ttyDXL_leader_right + motors: # window_x + # name: (index, model) + waist: [1, xm430-w350] + shoulder: [2, xm430-w350] + shoulder_shadow: [3, xm430-w350] + elbow: [4, xm430-w350] + elbow_shadow: [5, xm430-w350] + forearm_roll: [6, xm430-w350] + wrist_angle: [7, xm430-w350] + wrist_rotate: [8, xl430-w250] + gripper: [9, xc430-w150] + +follower_arms: + left: + _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus + port: /dev/ttyDXL_follower_left + motors: + # name: [index, model] + waist: [1, xm540-w270] + shoulder: [2, xm540-w270] + shoulder_shadow: [3, xm540-w270] + elbow: [4, xm540-w270] + elbow_shadow: [5, xm540-w270] + forearm_roll: [6, xm540-w270] + wrist_angle: [7, xm540-w270] + wrist_rotate: [8, xm430-w350] + gripper: [9, xm430-w350] + right: + _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus + port: /dev/ttyDXL_follower_right + motors: + # name: [index, model] + waist: [1, xm540-w270] + shoulder: [2, xm540-w270] + shoulder_shadow: [3, xm540-w270] + elbow: [4, xm540-w270] + elbow_shadow: [5, xm540-w270] + forearm_roll: [6, xm540-w270] + wrist_angle: [7, xm540-w270] + wrist_rotate: [8, xm430-w350] + gripper: [9, xm430-w350] + +# Troubleshooting: If one of your IntelRealSense cameras freeze during +# data recording due to bandwidth limit, you might need to plug the camera +# on another USB hub or PCIe card. +cameras: + cam_high: + _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera + camera_index: 10 + fps: 30 + width: 640 + height: 480 + cam_low: + _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera + camera_index: 22 + fps: 30 + width: 640 + height: 480 + cam_left_wrist: + _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera + camera_index: 16 + fps: 30 + width: 640 + height: 480 + cam_right_wrist: + _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera + camera_index: 4 + fps: 30 + width: 640 + height: 480 diff --git a/lerobot/configs/robot/koch.yaml b/lerobot/configs/robot/koch.yaml index d40d5ff38..40969dc73 100644 --- a/lerobot/configs/robot/koch.yaml +++ b/lerobot/configs/robot/koch.yaml @@ -1,5 +1,12 @@ -_target_: lerobot.common.robot_devices.robots.koch.KochRobot -calibration_path: .cache/calibration/koch.pkl +_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot +robot_type: koch +calibration_dir: .cache/calibration/koch + +# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. +# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as +# the number of motors in your follower arms. +max_relative_target: null + leader_arms: main: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus @@ -12,6 +19,7 @@ leader_arms: wrist_flex: [4, "xl330-m077"] wrist_roll: [5, "xl330-m077"] gripper: [6, "xl330-m077"] + follower_arms: main: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus @@ -24,6 +32,7 @@ follower_arms: wrist_flex: [4, "xl330-m288"] wrist_roll: [5, "xl330-m288"] gripper: [6, "xl330-m288"] + cameras: laptop: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera @@ -37,10 +46,8 @@ cameras: fps: 30 width: 640 height: 480 -# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. -# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as -# the number of motors in your follower arms. -max_relative_target: null + +# ~ Koch specific settings ~ # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible # to squeeze the gripper and have it spring back to an open position on its own. gripper_open_degree: 35.156 diff --git a/lerobot/configs/robot/koch_bimanual.yaml b/lerobot/configs/robot/koch_bimanual.yaml index 4a803d265..7f8138675 100644 --- a/lerobot/configs/robot/koch_bimanual.yaml +++ b/lerobot/configs/robot/koch_bimanual.yaml @@ -1,5 +1,12 @@ -_target_: lerobot.common.robot_devices.robots.koch.KochRobot -calibration_path: .cache/calibration/koch_bimanual.pkl +_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot +robot_type: koch +calibration_dir: .cache/calibration/koch_bimanual + +# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. +# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as +# the number of motors in your follower arms. +max_relative_target: null + leader_arms: left: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus @@ -23,6 +30,7 @@ leader_arms: wrist_flex: [4, "xl330-m077"] wrist_roll: [5, "xl330-m077"] gripper: [6, "xl330-m077"] + follower_arms: left: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus @@ -46,6 +54,7 @@ follower_arms: wrist_flex: [4, "xl330-m288"] wrist_roll: [5, "xl330-m288"] gripper: [6, "xl330-m288"] + cameras: laptop: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera @@ -59,10 +68,8 @@ cameras: fps: 30 width: 640 height: 480 -# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. -# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as -# the number of motors in your follower arms. -max_relative_target: null + +# ~ Koch specific settings ~ # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible # to squeeze the gripper and have it spring back to an open position on its own. gripper_open_degree: 35.156 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 56321e768..9daf2c051 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -127,7 +127,8 @@ from lerobot.common.datasets.video_utils import encode_video_frames from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id +from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed from lerobot.scripts.eval import get_pretrained_policy_path from lerobot.scripts.push_dataset_to_hub import ( @@ -169,15 +170,6 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir): img.save(str(path), quality=100) -def busy_wait(seconds): - # Significantly more accurate than `time.sleep`, and mendatory for our use case, - # but it consumes CPU cycles. - # TODO(rcadene): find an alternative: from python 11, time.sleep is precise - end_time = time.perf_counter() + seconds - while time.perf_counter() < end_time: - pass - - def none_or_int(value): if value == "None": return None @@ -249,10 +241,38 @@ def is_headless(): ######################################################################################## -def calibrate(robot: Robot): - if robot.calibration_path.exists(): - print(f"Removing '{robot.calibration_path}'") - robot.calibration_path.unlink() +def calibrate(robot: Robot, arms: list[str] | None): + available_arms = [] + for name in robot.follower_arms: + arm_id = get_arm_id(name, "follower") + available_arms.append(arm_id) + for name in robot.leader_arms: + arm_id = get_arm_id(name, "leader") + available_arms.append(arm_id) + + unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms] + + available_arms_str = " ".join(available_arms) + unknown_arms_str = " ".join(unknown_arms) + + if arms is None or len(arms) == 0: + raise ValueError( + "No arm provided. Use `--arms` as argument with one or more available arms.\n" + f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`" + ) + + if len(unknown_arms) > 0: + raise ValueError( + f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`." + ) + + for arm_id in arms: + arm_calib_path = robot.calibration_dir / f"{arm_id}.json" + if arm_calib_path.exists(): + print(f"Removing '{arm_calib_path}'") + arm_calib_path.unlink() + else: + print(f"Calibration file not found '{arm_calib_path}'") if robot.is_connected: robot.disconnect() @@ -260,6 +280,8 @@ def calibrate(robot: Robot): # Calling `connect` automatically runs calibration # when the calibration file is missing robot.connect() + robot.disconnect() + print("Calibration is done! You can now teleoperate and record datasets!") def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None): @@ -486,8 +508,11 @@ def on_press(key): action = action.to("cpu") # Order the robot to move - robot.send_action(action) - action = {"action": action} + action_sent = robot.send_action(action) + + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = {"action": action_sent} for key in action: if key not in ep_dict: @@ -712,6 +737,12 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo ) parser_calib = subparsers.add_parser("calibrate", parents=[base_parser]) + parser_calib.add_argument( + "--arms", + type=int, + nargs="*", + help="List of arms to calibrate (e.g. `--arms left_follower right_follower left_leader`)", + ) parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser]) parser_teleop.add_argument( diff --git a/media/aloha/follower_rest.webp b/media/aloha/follower_rest.webp new file mode 100644 index 000000000..03698acd6 Binary files /dev/null and b/media/aloha/follower_rest.webp differ diff --git a/media/aloha/follower_rotated.webp b/media/aloha/follower_rotated.webp new file mode 100644 index 000000000..914958bbc Binary files /dev/null and b/media/aloha/follower_rotated.webp differ diff --git a/media/aloha/follower_zero.webp b/media/aloha/follower_zero.webp new file mode 100644 index 000000000..c14c516cc Binary files /dev/null and b/media/aloha/follower_zero.webp differ diff --git a/media/aloha/leader_rest.webp b/media/aloha/leader_rest.webp new file mode 100644 index 000000000..821fdf7b3 Binary files /dev/null and b/media/aloha/leader_rest.webp differ diff --git a/media/aloha/leader_rotated.webp b/media/aloha/leader_rotated.webp new file mode 100644 index 000000000..ed4a3faa7 Binary files /dev/null and b/media/aloha/leader_rotated.webp differ diff --git a/media/aloha/leader_zero.webp b/media/aloha/leader_zero.webp new file mode 100644 index 000000000..b67cfa773 Binary files /dev/null and b/media/aloha/leader_zero.webp differ diff --git a/poetry.lock b/poetry.lock index 56a318ef8..8b8350fc3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -2406,6 +2406,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"}, ] @@ -4551,7 +4552,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", aloha = ["gym-aloha"] dev = ["debugpy", "pre-commit"] dora = ["gym-dora"] -koch = ["dynamixel-sdk", "pynput"] +dynamixel = ["dynamixel-sdk", "pynput"] pusht = ["gym-pusht"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] @@ -4561,4 +4562,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a340f2ed23db2f3c371c494cbc9a33392e122ed6713e6098277a87b3fb805f2b" +content-hash = "781e1ca86ed53f76d1b28066a91cc591630886f3a908a691a5aa26146793a02c" diff --git a/pyproject.toml b/pyproject.toml index 999a72030..85affc7cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dev = ["pre-commit", "debugpy"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] video_benchmark = ["scikit-image", "pandas"] -koch = ["dynamixel-sdk", "pynput"] +dynamixel = ["dynamixel-sdk", "pynput"] [tool.ruff] line-length = 110 diff --git a/tests/conftest.py b/tests/conftest.py index eaf1b476c..52006f331 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,28 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + import pytest from lerobot.common.utils.utils import init_hydra_config -from .utils import DEVICE, KOCH_ROBOT_CONFIG_PATH +from .utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE def pytest_collection_finish(): print(f"\nTesting with {DEVICE=}") -@pytest.fixture(scope="session") -def is_koch_available(): +@pytest.fixture +def is_robot_available(robot_type): try: from lerobot.common.robot_devices.robots.factory import make_robot - robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH) + config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type) + robot_cfg = init_hydra_config(config_path) robot = make_robot(robot_cfg) robot.connect() del robot return True - except Exception as e: - print("A koch robot is not available.") - print(e) + except Exception: + traceback.print_exc() + print(f"\nA {robot_type} robot is not available.") return False diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 9780a50ea..0d5d94425 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -1,9 +1,19 @@ +""" +Tests meant to be used locally and launched manually. + +Example usage: +```bash +pytest -sx tests/test_cameras.py::test_camera +``` +""" + import numpy as np import pytest +from lerobot import available_robots from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera, save_images_from_cameras from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from tests.utils import require_koch +from tests.utils import require_robot CAMERA_INDEX = 2 # Maximum absolute difference between two consecutive images recored by a camera. @@ -15,8 +25,9 @@ def compute_max_pixel_difference(first_image, second_image): return np.abs(first_image.astype(float) - second_image.astype(float)).max() -@require_koch -def test_camera(request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_camera(request, robot_type): """Test assumes that `camera.read()` returns the same image when called multiple times in a row. So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving. @@ -120,6 +131,7 @@ def test_camera(request): del camera -@require_koch -def test_save_images_from_cameras(tmpdir, request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_save_images_from_cameras(tmpdir, request, robot_type): save_images_from_cameras(tmpdir, record_time_s=1) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 5dae28e4d..406edeb4f 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -1,53 +1,53 @@ from pathlib import Path +import pytest + +from lerobot import available_robots from lerobot.common.policies.factory import make_policy -from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.utils.utils import init_hydra_config from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate -from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, KOCH_ROBOT_CONFIG_PATH, require_koch - - -def make_robot_(overrides=None): - robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH, overrides) - robot = make_robot(robot_cfg) - return robot +from tests.test_robots import make_robot +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_robot -@require_koch -# `require_koch` uses `request` to access `is_koch_available` fixture -def test_teleoperate(request): - robot = make_robot_() +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_teleoperate(request, robot_type): + robot = make_robot(robot_type) teleoperate(robot, teleop_time_s=1) teleoperate(robot, fps=30, teleop_time_s=1) teleoperate(robot, fps=60, teleop_time_s=1) del robot -@require_koch -def test_calibrate(request): - robot = make_robot_() +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_calibrate(request, robot_type): + robot = make_robot(robot_type) calibrate(robot) del robot -@require_koch -def test_record_without_cameras(tmpdir, request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_record_without_cameras(tmpdir, request, robot_type): root = Path(tmpdir) repo_id = "lerobot/debug" - robot = make_robot_(overrides=["~cameras"]) + robot = make_robot(robot_type, overrides=["~cameras"]) record(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2) -@require_koch -def test_record_and_replay_and_policy(tmpdir, request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_record_and_replay_and_policy(tmpdir, request, robot_type): env_name = "koch_real" policy_name = "act_koch_real" root = Path(tmpdir) repo_id = "lerobot/debug" - robot = make_robot_() + robot = make_robot(robot_type) dataset = record( robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2 ) diff --git a/tests/test_motors.py b/tests/test_motors.py index db9ca1f96..48c2e8d8d 100644 --- a/tests/test_motors.py +++ b/tests/test_motors.py @@ -1,3 +1,13 @@ +""" +Tests meant to be used locally and launched manually. + +Example usage: +```bash +pytest -sx tests/test_motors.py::test_find_port +pytest -sx tests/test_motors.py::test_motors_bus +``` +""" + # TODO(rcadene): measure fps in nightly? # TODO(rcadene): test logs # TODO(rcadene): test calibration @@ -5,34 +15,41 @@ import time -import hydra import numpy as np import pytest +from lerobot import available_robots +from lerobot.common.robot_devices.motors.utils import MotorsBus +from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.utils.utils import init_hydra_config -from tests.utils import KOCH_ROBOT_CONFIG_PATH, require_koch +from tests.utils import ROBOT_CONFIG_PATH_TEMPLATE, require_robot -def make_motors_bus(): - robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH) - # Instantiating a common motors structure. - # Here the one from Alexander Koch follower arm. - motors_bus = hydra.utils.instantiate(robot_cfg.leader_arms.main) +def make_motors_bus(robot_type: str) -> MotorsBus: + # Instantiate a robot and return one of its leader arms + config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type) + robot_cfg = init_hydra_config(config_path) + robot = make_robot(robot_cfg) + first_bus_name = list(robot.leader_arms.keys())[0] + motors_bus = robot.leader_arms[first_bus_name] return motors_bus -@require_koch -def test_find_port(request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_find_port(request, robot_type): from lerobot.common.robot_devices.motors.dynamixel import find_port find_port() -@require_koch -def test_configure_motors_all_ids_1(request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_configure_motors_all_ids_1(request, robot_type): + input("Are you sure you want to re-configure the motors? Press enter to continue...") # This test expect the configuration was already correct. - motors_bus = make_motors_bus() + motors_bus = make_motors_bus(robot_type) motors_bus.connect() motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors)) motors_bus.set_bus_baudrate(9_600) @@ -40,15 +57,16 @@ def test_configure_motors_all_ids_1(request): del motors_bus # Test configure - motors_bus = make_motors_bus() + motors_bus = make_motors_bus(robot_type) motors_bus.connect() assert motors_bus.are_motors_configured() del motors_bus -@require_koch -def test_motors_bus(request): - motors_bus = make_motors_bus() +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_motors_bus(request, robot_type): + motors_bus = make_motors_bus(robot_type) # Test reading and writting before connecting raises an error with pytest.raises(RobotDeviceNotConnectedError): @@ -62,7 +80,7 @@ def test_motors_bus(request): del motors_bus # Test connecting - motors_bus = make_motors_bus() + motors_bus = make_motors_bus(robot_type) motors_bus.connect() # Test connecting twice raises an error diff --git a/tests/test_robots.py b/tests/test_robots.py index 6827c7e00..4ce3805ee 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -1,54 +1,52 @@ -import pickle +""" +Tests meant to be used locally and launched manually. + +Example usage: +```bash +pytest -sx tests/test_robots.py::test_robot +``` +""" + from pathlib import Path import pytest import torch -from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot import available_robots +from lerobot.common.robot_devices.robots.factory import make_robot as make_robot_from_cfg +from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from tests.utils import require_koch +from lerobot.common.utils.utils import init_hydra_config +from tests.utils import ROBOT_CONFIG_PATH_TEMPLATE, require_robot + + +def make_robot(robot_type: str, overrides: list[str] | None = None) -> Robot: + config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type) + robot_cfg = init_hydra_config(config_path, overrides) + robot = make_robot_from_cfg(robot_cfg) + return robot -@require_koch -def test_robot(tmpdir, request): +@pytest.mark.parametrize("robot_type", available_robots) +@require_robot +def test_robot(tmpdir, request, robot_type): # TODO(rcadene): measure fps in nightly? # TODO(rcadene): test logs # TODO(rcadene): add compatibility with other robots - from lerobot.common.robot_devices.robots.koch import KochRobot + from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot # Save calibration preset - calibration = { - "follower_main": { - "shoulder_pan": (-2048, False), - "shoulder_lift": (2048, True), - "elbow_flex": (-1024, False), - "wrist_flex": (2048, True), - "wrist_roll": (2048, True), - "gripper": (2048, True), - }, - "leader_main": { - "shoulder_pan": (-2048, False), - "shoulder_lift": (1024, True), - "elbow_flex": (2048, True), - "wrist_flex": (-2048, False), - "wrist_roll": (2048, True), - "gripper": (2048, True), - }, - } tmpdir = Path(tmpdir) - calibration_path = tmpdir / "calibration.pkl" - calibration_path.parent.mkdir(parents=True, exist_ok=True) - with open(calibration_path, "wb") as f: - pickle.dump(calibration, f) + calibration_dir = tmpdir / robot_type # Test connecting without devices raises an error - robot = KochRobot() + robot = ManipulatorRobot() with pytest.raises(ValueError): robot.connect() del robot # Test using robot before connecting raises an error - robot = KochRobot() + robot = ManipulatorRobot() with pytest.raises(RobotDeviceNotConnectedError): robot.teleop_step() with pytest.raises(RobotDeviceNotConnectedError): @@ -64,9 +62,7 @@ def test_robot(tmpdir, request): del robot # Test connecting - robot = make_robot("koch") - # TODO(rcadene): proper monkey patch - robot.calibration_path = calibration_path + robot = make_robot(robot_type, overrides=[f"calibration_dir={calibration_dir}"]) robot.connect() # run the manual calibration precedure assert robot.is_connected @@ -78,8 +74,8 @@ def test_robot(tmpdir, request): del robot # Test teleop can run - robot = make_robot("koch") - robot.calibration_path = calibration_path + robot = make_robot(robot_type, overrides=[f"calibration_dir={calibration_dir}"]) + robot.calibration_dir = calibration_dir robot.connect() robot.teleop_step() diff --git a/tests/utils.py b/tests/utils.py index f79ad2495..db214aeac 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,11 +21,12 @@ from lerobot.common.utils.import_utils import is_package_available +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + # Pass this as the first argument to init_hydra_config. DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" -KOCH_ROBOT_CONFIG_PATH = "lerobot/configs/robot/koch.yaml" -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +ROBOT_CONFIG_PATH_TEMPLATE = "lerobot/configs/robot/{robot}.yaml" def require_x86_64_kernel(func): @@ -150,21 +151,35 @@ def wrapper(*args, **kwargs): return decorator -def require_koch(func): +def require_robot(func): """ - Decorator that skips the test if an alexander koch robot is not available + Decorator that skips the test if a robot is not available + + The decorated function must have two arguments `request` and `robot_type`. + + Example of usage: + ```python + @pytest.mark.parametrize( + "robot_type", ["koch", "aloha"] + ) + @require_robot + def test_require_robot(request, robot_type): + pass + ``` """ @wraps(func) def wrapper(*args, **kwargs): - # Access the pytest request context to get the is_koch_available fixture + # Access the pytest request context to get the is_robot_available fixture request = kwargs.get("request") + robot_type = kwargs.get("robot_type") + if request is None: raise ValueError("The 'request' fixture must be passed to the test function as a parameter.") - # The function `is_koch_available` is defined in `tests/conftest.py` - if not request.getfixturevalue("is_koch_available"): - pytest.skip("An alexander koch robot is not available.") + # The function `is_robot_available` is defined in `tests/conftest.py` + if not request.getfixturevalue("is_robot_available"): + pytest.skip(f"A {robot_type} robot is not available.") return func(*args, **kwargs) return wrapper