Skip to content

Commit 270d80c

Browse files
committed
Add script for computing metrics between simulated and real trajectories
1 parent 78220bb commit 270d80c

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

scripts/compute_real2sim_metrics.py

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Script for computing error metrics between simulated and real trajectories."""
2+
3+
import argparse
4+
import csv
5+
6+
from typing import Dict, Union
7+
8+
import numpy as np
9+
10+
from sim2sim.util import (
11+
average_displacement_error_translation_only,
12+
final_displacement_error_translation_only,
13+
orientation_considered_average_displacement_error,
14+
orientation_considered_final_displacement_error,
15+
trajectory_IoU,
16+
)
17+
18+
19+
def create_metric_dict(
20+
name: str,
21+
trajectory_IoU_margin_01: float,
22+
trajectory_IoU_margin_1: float,
23+
orientation_considered_final_error: float,
24+
orientation_considered_average_error: float,
25+
orientation_considered_average_error_10points: float,
26+
final_translation_error: float,
27+
average_translation_error: float,
28+
) -> Dict[str, Union[str, float]]:
29+
return {
30+
"name": name,
31+
"trajectory_IoU_margin_01": trajectory_IoU_margin_01,
32+
"trajectory_IoU_margin_1": trajectory_IoU_margin_1,
33+
"orientation_considered_final_error": orientation_considered_final_error,
34+
"orientation_considered_average_error": orientation_considered_average_error,
35+
"orientation_considered_average_error_10points": orientation_considered_average_error_10points,
36+
"final_translation_error": final_translation_error,
37+
"average_translation_error": average_translation_error,
38+
}
39+
40+
41+
def main():
42+
parser = argparse.ArgumentParser()
43+
parser.add_argument(
44+
"--name",
45+
required=True,
46+
type=str,
47+
help="The name of the experiment.",
48+
)
49+
parser.add_argument(
50+
"--real_poses_path",
51+
required=True,
52+
type=str,
53+
help="The path to the numpy file containing the real poses. The poses should be "
54+
+ "of shape (M, 7) where M is the number of objects.",
55+
)
56+
parser.add_argument(
57+
"--sim_states_path",
58+
required=True,
59+
type=str,
60+
help="The path to the numpy file containing the sim states. The states should "
61+
+ "be of shape (M, 13) where M is the number of objects.",
62+
)
63+
parser.add_argument(
64+
"--num_trajectory_iou_samples",
65+
default=0,
66+
type=int,
67+
help="The number of samples to use for computing trajectory IoU.",
68+
)
69+
parser.add_argument(
70+
"--out_path",
71+
required=True,
72+
type=str,
73+
help="The path to the output csv file.",
74+
)
75+
args = parser.parse_args()
76+
name = args.name
77+
real_poses_path = args.real_poses_path
78+
sim_states_path = args.sim_states_path
79+
num_trajectory_iou_samples = args.num_trajectory_iou_samples
80+
out_path = args.out_path
81+
82+
real_poses = np.load(real_poses_path) # Shape (M,7)
83+
sim_states = np.load(sim_states_path) # Shape (M,13)
84+
85+
# Set velocities of real data to NaN
86+
real_states = np.concatenate(
87+
(real_poses, np.full((len(real_poses), 6), np.nan)), axis=1
88+
)
89+
90+
# Compute metric for each manipuland separately
91+
trajectory_IoU_margins_01 = [
92+
trajectory_IoU(
93+
real,
94+
sim,
95+
margin=0.1,
96+
num_samples=num_trajectory_iou_samples,
97+
)
98+
for real, sim in zip(real_states, sim_states)
99+
] # Shape (M,)
100+
trajectory_IoU_margins_1 = [
101+
trajectory_IoU(
102+
real,
103+
sim,
104+
margin=1.0,
105+
num_samples=num_trajectory_iou_samples,
106+
)
107+
for real, sim in zip(real_states, sim_states)
108+
] # Shape (M,)
109+
orientation_considered_final_errors = [
110+
orientation_considered_final_displacement_error(real, sim)
111+
for real, sim in zip(real_states, sim_states)
112+
] # Shape (M,)
113+
orientation_considered_average_errors = [
114+
orientation_considered_average_displacement_error(real, sim)
115+
for real, sim in zip(real_states, sim_states)
116+
] # Shape (M,)
117+
orientation_considered_average_errors_10points = [
118+
orientation_considered_average_displacement_error(real, sim, 10)
119+
for real, sim in zip(real_states, sim_states)
120+
] # Shape (M,)
121+
final_displacement_errors_translation_only = [
122+
final_displacement_error_translation_only(real, sim)
123+
for real, sim in zip(real_states, sim_states)
124+
] # Shape (M,)
125+
average_displacement_errors_translation_only = [
126+
average_displacement_error_translation_only(real, sim)
127+
for real, sim in zip(real_states, sim_states)
128+
] # Shape (M,)
129+
130+
# Average metrics over manipulands
131+
metrics = create_metric_dict(
132+
name=name,
133+
trajectory_IoU_margin_01=np.mean(trajectory_IoU_margins_01),
134+
trajectory_IoU_margin_1=np.mean(trajectory_IoU_margins_1),
135+
orientation_considered_final_error=np.mean(orientation_considered_final_errors),
136+
orientation_considered_average_error=np.mean(
137+
orientation_considered_average_errors
138+
),
139+
orientation_considered_average_error_10points=np.mean(
140+
orientation_considered_average_errors_10points
141+
),
142+
final_translation_error=np.mean(final_displacement_errors_translation_only),
143+
average_translation_error=np.mean(average_displacement_errors_translation_only),
144+
)
145+
146+
with open(out_path, "w", newline="") as f:
147+
writer = csv.writer(f)
148+
column_names = metrics.keys()
149+
writer.writerow(column_names)
150+
writer.writerow([metrics[name] for name in column_names])
151+
152+
153+
if __name__ == "__main__":
154+
main()

0 commit comments

Comments
 (0)