Skip to content

Commit eb2e117

Browse files
authored
Merge pull request #76 from AllenNeuralDynamics/wip/bundle-adjustment
Wip/bundle adjustment
2 parents f8e1ce5 + 9b827b3 commit eb2e117

19 files changed

+1022
-404
lines changed

parallax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66

7-
__version__ = "0.37.18"
7+
__version__ = "0.37.19"
88

99
# allow multiple OpenMP instances
1010
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

parallax/__main__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,28 @@ def setup_logging():
4949
action="store_true",
5050
help="Dummy mode for testing without hardware",
5151
)
52+
53+
parser.add_argument(
54+
"-b",
55+
"--bundle_adjustment",
56+
action="store_true",
57+
help="Enable bundle adjustment feature",
58+
)
5259
args = parser.parse_args()
5360

5461
# Print a message if running in dummy mode (no hardware interaction)
5562
if args.dummy:
5663
print("\nRunning in dummy mode; hardware devices not accessible.")
64+
# Print a message if bundle adjustment is enabled
65+
if args.bundle_adjustment:
66+
print("\nBundle adjustment feature enabled.")
5767

5868
# Set up logging as configured in the setup_logging function
5969
setup_logging()
6070

6171
# Initialize the Qt application
6272
app = QApplication([])
63-
model = Model(version="V2") # Initialize the data model with version "V2"
73+
model = Model(version="V2", bundle_adjustment=args.bundle_adjustment) # Initialize the data model with version "V2"
6474
main_window = MainWindowV2(model, dummy=args.dummy) # main window
6575

6676
# Show the main window on screen

parallax/axis_filter.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import cv2
1111
import numpy as np
1212
from PyQt5.QtCore import QObject, QThread, pyqtSignal
13+
from .calibration_camera import CalibrationCamera
1314

1415
# Set logger name
1516
logger = logging.getLogger(__name__)
@@ -20,12 +21,16 @@ class AxisFilter(QObject):
2021

2122
name = "None"
2223
frame_processed = pyqtSignal(object)
24+
found_coords = pyqtSignal(np.ndarray, np.ndarray, np.ndarray, np.ndarray, tuple, tuple)
2325

2426
class Worker(QObject):
2527
"""Worker class for processing frames in a separate thread."""
2628

2729
finished = pyqtSignal()
2830
frame_processed = pyqtSignal(object)
31+
found_coords = pyqtSignal(
32+
np.ndarray, np.ndarray, np.ndarray, np.ndarray, tuple, tuple
33+
)
2934

3035
def __init__(self, name, model):
3136
"""Initialize the worker object."""
@@ -37,6 +42,7 @@ def __init__(self, name, model):
3742
self.frame = None
3843
self.reticle_coords = self.model.get_coords_axis(self.name)
3944
self.pos_x = None
45+
self.calibrationCamera = CalibrationCamera(self.name)
4046

4147
def update_frame(self, frame):
4248
"""Update the frame to be processed.
@@ -90,9 +96,13 @@ def sort_reticle_points(self):
9096

9197
def clicked_position(self, input_pt):
9298
"""Get clicked position."""
99+
if not self.running:
100+
return
101+
93102
if self.reticle_coords is None:
94103
return
95104

105+
logger.debug(f"clicked_position {input_pt}")
96106
# Coordinates of points
97107
pt1, pt2 = self.reticle_coords[0][0], self.reticle_coords[0][-1]
98108
pt3, pt4 = self.reticle_coords[1][0], self.reticle_coords[1][-1]
@@ -104,12 +114,23 @@ def clicked_position(self, input_pt):
104114

105115
# sort the reticle points and register to the model
106116
self.sort_reticle_points()
117+
ret, mtx, dist, rvecs, tvecs = self.calibrationCamera.calibrate_camera(
118+
self.reticle_coords[0], self.reticle_coords[1]
119+
)
120+
if ret:
121+
self.found_coords.emit(
122+
self.reticle_coords[0], self.reticle_coords[1], mtx, dist, rvecs, tvecs
123+
)
124+
125+
# Register the camera intrinsic parameters and coords to the model
107126
self.model.add_coords_axis(self.name, self.reticle_coords)
127+
self.model.add_camera_intrinsic(self.name, mtx, dist, rvecs, tvecs)
108128
self.model.add_pos_x(self.name, self.pos_x)
109129

110130
def reset_pos_x(self):
111131
self.pos_x = None
112132
self.model.reset_pos_x()
133+
logger.debug("reset pos_x")
113134

114135
def stop_running(self):
115136
"""Stop the worker from running."""
@@ -153,14 +174,15 @@ def init_thread(self):
153174
self.worker.moveToThread(self.thread)
154175

155176
self.thread.started.connect(self.worker.run)
156-
self.worker.finished.connect(self.worker.deleteLater)
177+
self.thread.finished.connect(self.thread.deleteLater)
157178
self.thread.destroyed.connect(self.onThreadDestroyed)
158179
self.threadDeleted = False
159180

160181
#self.worker.frame_processed.connect(self.frame_processed)
161182
self.worker.frame_processed.connect(self.frame_processed.emit)
183+
self.worker.found_coords.connect(self.found_coords)
162184
self.worker.finished.connect(self.thread.quit)
163-
self.worker.finished.connect(self.thread.deleteLater)
185+
self.worker.finished.connect(self.worker.deleteLater)
164186
self.worker.destroyed.connect(self.onWorkerDestroyed)
165187
logger.debug(f"init camera name: {self.name}")
166188

@@ -184,8 +206,8 @@ def stop(self):
184206
"""Stop the filter by stopping the worker."""
185207
logger.debug(f" {self.name} Stopping thread")
186208
if self.worker is not None:
187-
self.worker.stop_running()
188209
self.worker.reset_pos_x()
210+
self.worker.stop_running()
189211

190212
def onWorkerDestroyed(self):
191213
"""Cleanup after worker finishes."""

parallax/bundle_adjustment.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import numpy as np
2+
import pandas as pd
3+
import cv2
4+
import logging
5+
from scipy.optimize import leastsq
6+
7+
# Set logger name
8+
logger = logging.getLogger(__name__)
9+
logger.setLevel(logging.DEBUG)
10+
11+
class BALProblem:
12+
def __init__(self, model, file_path):
13+
self.list_cameras = None
14+
self.observations = None
15+
self.points = None
16+
self.cameras_params = None
17+
self.local_pts = None
18+
19+
self.model = model
20+
self.df = None
21+
self.file_path = file_path
22+
self._parse_csv()
23+
self._set_camera_params()
24+
25+
def _parse_csv(self):
26+
self.df = pd.read_csv(self.file_path)
27+
#self._remove_duplicates()
28+
self._average_3D_points()
29+
30+
self._set_camera_list()
31+
self._set_points()
32+
self._set_observations()
33+
34+
def _set_camera_list(self):
35+
cameras = pd.concat([self.df['cam0'], self.df['cam1']]).unique()
36+
self.list_cameras = [str(camera) for camera in cameras]
37+
38+
def _set_points(self):
39+
unique_df = self.df.drop_duplicates(subset=['m_global_x', 'm_global_y', 'm_global_z'])
40+
self.points = np.array(unique_df[['m_global_x', 'm_global_y', 'm_global_z']].values)
41+
self.local_pts = np.array(unique_df[['local_x', 'local_y', 'local_z']].values)
42+
43+
def _set_observations(self):
44+
# Initialize the list to store observations
45+
self.observations = []
46+
47+
# Create a mapping from camera IDs to indices
48+
camera_id_to_index = {str(camera_id): idx for idx, camera_id in enumerate(self.list_cameras)}
49+
50+
# Iterate through the DataFrame to collect observations
51+
for _, row in self.df.iterrows():
52+
cam0, pt0 = str(row['cam0']), row['pt0']
53+
cam1, pt1 = str(row['cam1']), row['pt1']
54+
55+
# Find the point index corresponding to the average global coordinates
56+
m_global_x, m_global_y, m_global_z = row['m_global_x'], row['m_global_y'], row['m_global_z']
57+
point_index = np.where((self.points[:, 0] == m_global_x) &
58+
(self.points[:, 1] == m_global_y) &
59+
(self.points[:, 2] == m_global_z))[0][0]
60+
61+
# Add observations for cam0
62+
if pd.notna(pt0):
63+
pt0_coords = np.array(list(map(float, pt0.strip('()').split(','))))
64+
camera_index = camera_id_to_index[cam0]
65+
self.observations.append([camera_index, point_index, pt0_coords[0], pt0_coords[1]])
66+
67+
# Add observations for cam1
68+
if pd.notna(pt1):
69+
pt1_coords = np.array(list(map(float, pt1.strip('()').split(','))))
70+
camera_index = camera_id_to_index[cam1]
71+
self.observations.append([camera_index, point_index, pt1_coords[0], pt1_coords[1]])
72+
73+
# Convert the observations list to a numpy array
74+
self.observations = np.array(self.observations)
75+
76+
def _average_3D_points(self):
77+
# Group by 'ts_local_coords' and calculate the mean for 'global_x', 'global_y', and 'global_z'
78+
grouped = self.df.groupby('ts_local_coords')[['global_x', 'global_y', 'global_z']].mean()
79+
grouped = grouped.rename(columns={'global_x': 'm_global_x', 'global_y': 'm_global_y', 'global_z': 'm_global_z'})
80+
81+
# Merge the averaged columns back into the original DataFrame
82+
self.df = self.df.merge(grouped, on='ts_local_coords', how='left')
83+
84+
# Create a mapping of ts_local_coords to index in the averaged points
85+
self.df['point_index'] = self.df.groupby('ts_local_coords').ngroup()
86+
87+
# Write the updated DataFrame back to the CSV file
88+
self.df.to_csv(self.file_path, index=False)
89+
90+
91+
def _remove_duplicates(self):
92+
# Drop duplicate rows based on 'ts_local_coords', 'global_x', 'global_y', 'global_z' columns
93+
logger.debug(f"Original rows: {self.df.shape[0]}")
94+
self.df = self.df.drop_duplicates(subset=['ts_local_coords', 'global_x', 'global_y', 'global_z'])
95+
logger.debug(f"Unique rows: {self.df.shape[0]}")
96+
97+
def _set_camera_params(self):
98+
if not self.list_cameras:
99+
return
100+
101+
logger.debug(self.list_cameras)
102+
logger.debug(self.model.camera_intrinsic)
103+
104+
self.cameras_params = []
105+
106+
for camera_name in self.list_cameras:
107+
intrinsic = self.model.get_camera_intrinsic(camera_name)
108+
if intrinsic is None:
109+
logger.warning("Intrinsic parameters not found for camera %s", camera_name)
110+
continue
111+
112+
# intrinsic: [mtx, dist, rvec, tvec]
113+
mtx, dist, rvec, tvec = intrinsic[0], intrinsic[1], intrinsic[2][0], intrinsic[3][0]
114+
rvec = rvec.reshape(3, 1)
115+
tvec = tvec.reshape(3, 1)
116+
117+
f = mtx[0, 0]
118+
k1, k2, p1, p2, k3 = dist.ravel()
119+
R = rvec.ravel()
120+
t = tvec.ravel()
121+
122+
camera_param = np.array([
123+
R[0], R[1], R[2], # Rotation
124+
t[0], t[1], t[2], # Translation
125+
f, k1, k2, p1, p2, k3 # Intrinsics
126+
], dtype=np.float64)
127+
self.cameras_params.append(camera_param)
128+
129+
def get_camera_params(self, i):
130+
return self.cameras_params[i]
131+
132+
def get_point(self, i):
133+
return self.points[i]
134+
135+
class BALOptimizer:
136+
def __init__(self, bal_problem):
137+
self.bal_problem = bal_problem
138+
self.opt_camera_params = None
139+
self.opt_points = None
140+
141+
def residuals(self, params):
142+
residuals = []
143+
n_cams = len(self.bal_problem.list_cameras)
144+
n_pts = len(self.bal_problem.points)
145+
camera_params = params[:12 * n_cams].reshape(n_cams, 12)
146+
points = params[12 * n_cams:].reshape(n_pts, 3)
147+
148+
for obs in self.bal_problem.observations:
149+
cam_idx, pt_idx, observed_x, observed_y = int(obs[0]), int(obs[1]), obs[2], obs[3]
150+
camera = camera_params[cam_idx]
151+
point = points[pt_idx]
152+
153+
point = point / 1000
154+
rvec = np.array(camera[:3])
155+
tvec = np.array(camera[3:6])
156+
focal = camera[6]
157+
mtx = np.array([[focal, 0.0, 2000.0],
158+
[0.0, focal, 1500.0],
159+
[0.0, 0.0, 1.0]], dtype=np.float32)
160+
161+
k1, k2, p1, p2, k3 = camera[7:12]
162+
dist = np.array([k1, k2, p1, p2, k3], dtype=np.float32)
163+
164+
imgpts, _ = cv2.projectPoints(point.reshape(1, 3), rvec, tvec, mtx, dist)
165+
predicted_x = imgpts[0][0][0]
166+
predicted_y = imgpts[0][0][1]
167+
168+
residuals.append(predicted_x - observed_x)
169+
residuals.append(predicted_y - observed_y)
170+
171+
return np.array(residuals)
172+
173+
def optimize(self, print_result=True):
174+
# Initial parameters vector
175+
initial_params = np.hstack([param.ravel() for param in self.bal_problem.cameras_params] + [self.bal_problem.points.ravel()])
176+
177+
# Perform optimization using leastsq
178+
result = leastsq(self.residuals, initial_params, full_output=True)
179+
opt_params = result[0]
180+
181+
# Extract Optimize camera parameters and points
182+
n_cams = len(self.bal_problem.list_cameras)
183+
n_pts = len(self.bal_problem.points)
184+
self.opt_camera_params = opt_params[:12 * n_cams].reshape(n_cams, 12)
185+
self.opt_points = opt_params[12 * n_cams:].reshape(n_pts, 3)
186+
187+
if print_result:
188+
print(f"\n************** Optimization completed. **************************")
189+
# Compute initial residuals
190+
initial_residuals = self.residuals(initial_params)
191+
initial_residuals_sum = np.sum(initial_residuals**2)
192+
average_residual = initial_residuals_sum / len(self.bal_problem.observations)
193+
print(f"** Before BA, Average residual of reproj: {np.round(average_residual, 2)} **")
194+
195+
# Compute Optimize residuals
196+
opt_residuals = self.residuals(opt_params)
197+
opt_residuals_sum = np.sum(opt_residuals**2)
198+
average_residual = opt_residuals_sum / len(self.bal_problem.observations)
199+
print(f"** After BA, Average residual of reproj: {np.round(average_residual, 2)} **")
200+
print(f"******************************************************************")
201+
202+
logger.debug(f"Optimized camera parameters: {self.opt_camera_params}")
203+
204+
for i in range(len(self.bal_problem.points)):
205+
logger.debug(f"\nPoint {i}")
206+
logger.debug(f"org : {self.bal_problem.points[i]}")
207+
logger.debug(f"opt : {self.opt_points[i]}")
208+
209+
# Map optimized points to the original DataFrame rows
210+
opt_points_df = pd.DataFrame(self.opt_points, columns=['opt_x', 'opt_y', 'opt_z'])
211+
self.bal_problem.df = self.bal_problem.df.join(opt_points_df, on='point_index', rsuffix='_opt')
212+
213+
# Save the updated DataFrame to the CSV file
214+
self.bal_problem.df.to_csv(self.bal_problem.file_path, index=False)
215+
logger.info(f"Optimized points saved to {self.bal_problem.file_path}")

0 commit comments

Comments
 (0)