@@ -25,6 +25,36 @@ def observable_indices_in_tensor(
2525 return sorted_obs_dict
2626
2727
28+ def wing_qpos_to_conventional (model_wing_qpos : np .ndarray ,
29+ body_pitch_angle : float = 47.5 ,
30+ ) -> np .ndarray :
31+ """Transform model wing joint qpos to conventional wing kinematics definition.
32+
33+ Args:
34+ model_wing_qpos: Wing MjData.qpos in radians, shape (B, 6).
35+ Order of joints: yaw, roll, pitch, yaw, roll, pitch.
36+ Left-right order is arbitrary.
37+ body_pitch_angle: Body pitch angle for initial flight pose, relative to
38+ ground, degrees. 0: horizontal body position. Default value from
39+ https://doi.org/10.1126/science.1248955
40+
41+ Returns:
42+ Wing angles transformed to conventional representation.
43+ """
44+ if not isinstance (model_wing_qpos , np .ndarray ):
45+ model_wing_qpos = np .array (model_wing_qpos )
46+ conventional = np .zeros_like (model_wing_qpos )
47+ body_pitch_angle = np .deg2rad (body_pitch_angle )
48+ # Yaw, doesn't require transformation.
49+ conventional [..., [0 , 3 ]] = model_wing_qpos [..., [0 , 3 ]].copy ()
50+ # Roll.
51+ conventional [..., [1 , 4 ]] = - model_wing_qpos [..., [1 , 4 ]]
52+ # Pitch.
53+ conventional [..., [2 , 5 ]] = (
54+ np .pi / 2 - body_pitch_angle - model_wing_qpos [..., [2 , 5 ]])
55+ return conventional
56+
57+
2858def get_random_policy (action_spec : 'dm_env.specs.BoundedArray' ,
2959 minimum : float = - 0.2 ,
3060 maximum : float = 0.2 ) -> Callable [[Any ], np .ndarray ]:
0 commit comments