diff --git a/flybody/tasks/task_utils.py b/flybody/tasks/task_utils.py index f7e02894..d846360b 100755 --- a/flybody/tasks/task_utils.py +++ b/flybody/tasks/task_utils.py @@ -25,6 +25,36 @@ def observable_indices_in_tensor( return sorted_obs_dict +def wing_qpos_to_conventional(model_wing_qpos: np.ndarray, + body_pitch_angle: float = 47.5, + ) -> np.ndarray: + """Transform model wing joint qpos to conventional wing kinematics definition. + + Args: + model_wing_qpos: Wing MjData.qpos in radians, shape (B, 6). + Order of joints: yaw, roll, pitch, yaw, roll, pitch. + Left-right order is arbitrary. + body_pitch_angle: Body pitch angle for initial flight pose, relative to + ground, degrees. 0: horizontal body position. Default value from + https://doi.org/10.1126/science.1248955 + + Returns: + Wing angles transformed to conventional representation. + """ + if not isinstance(model_wing_qpos, np.ndarray): + model_wing_qpos = np.array(model_wing_qpos) + conventional = np.zeros_like(model_wing_qpos) + body_pitch_angle = np.deg2rad(body_pitch_angle) + # Yaw, doesn't require transformation. + conventional[..., [0, 3]] = model_wing_qpos[..., [0, 3]].copy() + # Roll. + conventional[..., [1, 4]] = - model_wing_qpos[..., [1, 4]] + # Pitch. + conventional[..., [2, 5]] = ( + np.pi / 2 - body_pitch_angle - model_wing_qpos[..., [2, 5]]) + return conventional + + def get_random_policy(action_spec: 'dm_env.specs.BoundedArray', minimum: float = -0.2, maximum: float = 0.2) -> Callable[[Any], np.ndarray]: