Skip to content

Commit e962a1e

Browse files
committed
Fix legend and annotation placement in fitting & merging plots
1 parent b014ad3 commit e962a1e

9 files changed

Lines changed: 160 additions & 29 deletions

interface_DBA_dye_to_host_fitting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import tkinter as tk
77
from tkinter import filedialog
88
from pltstyle import create_plots
9+
from matplotlib.transforms import Bbox
10+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
911

1012
def run_dba_dye_to_host_fitting(file_path, results_dir, h0_in_M, rmse_threshold_factor, r2_threshold, save_plots, display_plots, plots_dir, save_results, results_save_dir, number_of_fit_trials):
1113
# Convert initial concentration to µM units
@@ -224,7 +226,7 @@ def calculate_fit_metrics(Signal_observed, Signal_computed):
224226
ax.plot(d0_values, Signal_observed, 'o', label='Observed Signal')
225227
ax.plot(fitting_curve_x, fitting_curve_y, '--', color='blue', alpha=0.6, label='Simulated Fitting Curve')
226228
ax.set_title(f'Observed vs. Simulated Fitting Curve for Replica {replica_index}')
227-
ax.legend(loc='best', bbox_to_anchor=(0.02, 0.98))
229+
ax.legend(loc='best')
228230

229231
# TODO: double check whether x 10^6 is needed for Id and Ihd
230232
# TODO: should Kd be multiplied by 10^6 or 10^-6?
@@ -235,8 +237,7 @@ def calculate_fit_metrics(Signal_observed, Signal_computed):
235237
f"$RMSE$: {rmse:.3f}\n"
236238
f"$R^2$: {r_squared:.3f}")
237239

238-
ax.annotate(param_text, xy=(0.8, 0.04), xycoords='axes fraction', fontsize=10,
239-
ha='left', va='bottom', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5))
240+
place_annotation_opposite_legend(ax, param_text)
240241

241242
if save_plots:
242243
plot_file = os.path.join(plots_dir, f"fit_plot_replica_{replica_index}.png")

interface_DBA_host_to_dye_fitting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import matplotlib.pyplot as plt
77
from datetime import datetime
88
from pltstyle import create_plots
9+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
10+
from matplotlib.transforms import Bbox
911

1012
# Add number_of_fit_trials to function parameters
1113
def run_dba_host_to_dye_fitting(file_path, results_dir, d0_in_M, rmse_threshold_factor, r2_threshold, save_plots, display_plots, plots_dir, save_results, results_save_dir, number_of_fit_trials):
@@ -229,7 +231,7 @@ def calculate_fit_metrics(Signal_observed, Signal_computed):
229231
ax.plot(h0_values, Signal_observed, 'o', label='Observed Signal')
230232
ax.plot(fitting_curve_x, fitting_curve_y, '--', color='blue', alpha=0.6, label='Simulated Fitting Curve')
231233
ax.set_title(f'Observed vs. Simulated Fitting Curve for Replica {replica_index}')
232-
ax.legend(loc='best', bbox_to_anchor=(0.02, 0.98))
234+
ax.legend(loc='best')
233235

234236
param_text = (f"$K_d$: {median_params[1] * 1e6:.2e} $M^{{-1}}$\n"
235237
f"$I_0$: {median_params[0]:.2e}\n"
@@ -238,8 +240,7 @@ def calculate_fit_metrics(Signal_observed, Signal_computed):
238240
f"$RMSE$: {rmse:.3f}\n"
239241
f"$R^2$: {r_squared:.3f}")
240242

241-
ax.annotate(param_text, xy=(0.8, 0.04), xycoords='axes fraction', fontsize=10,
242-
ha='left', va='bottom', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5))
243+
place_annotation_opposite_legend(ax, param_text)
243244

244245
if save_plots:
245246
plot_file = os.path.join(plots_dir, f"fit_plot_replica_{replica_index}.png")

interface_DyeAlone_fitting.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from scipy.stats import linregress, ttest_1samp, t
1010
import matplotlib.pyplot as plt
1111
from matplotlib.ticker import FuncFormatter
12+
from matplotlib.transforms import Bbox
1213

1314
# Local imports
1415
from pltstyle import create_plots
@@ -131,6 +132,52 @@ def round_to_sigfigs(value, sigfigs=4):
131132
return f"{value:.{sigfigs}g}"
132133
return value
133134

135+
def place_annotation_safely(ax, text, fontsize=10, margin=10, **kwargs):
136+
"""
137+
Place annotation in a corner that avoids overlap with legend and data points.
138+
Tries all corners, prefers the first non-overlapping, else picks the one with least overlap.
139+
"""
140+
renderer = ax.figure.canvas.get_renderer()
141+
legend = ax.get_legend()
142+
legend_bbox = legend.get_window_extent(renderer) if legend else None
143+
# Get data points in display coords
144+
data_disp = ax.transData.transform(np.column_stack((ax.get_lines()[0].get_xdata(), ax.get_lines()[0].get_ydata())))
145+
data_bboxes = [Bbox.from_bounds(x - margin, y - margin, 2*margin, 2*margin) for x, y in data_disp]
146+
# Candidate positions: (x, y, ha, va)
147+
candidates = [
148+
(0.01, 0.99, 'left', 'top'), # upper left
149+
(0.99, 0.99, 'right', 'top'), # upper right
150+
(0.01, 0.01, 'left', 'bottom'), # lower left
151+
(0.99, 0.01, 'right', 'bottom') # lower right
152+
]
153+
best_ann = None
154+
min_overlap = float('inf')
155+
for x, y, ha, va in candidates:
156+
ann = ax.annotate(
157+
text, xy=(x, y), xycoords='axes fraction',
158+
ha=ha, va=va, fontsize=fontsize,
159+
bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5),
160+
**kwargs
161+
)
162+
plt.draw() # Needed to get the bbox
163+
ann_bbox = ann.get_window_extent(renderer)
164+
overlap = 0
165+
# Check overlap with legend
166+
if legend_bbox is not None and ann_bbox.overlaps(legend_bbox):
167+
overlap += ann_bbox.intersection(legend_bbox).area
168+
# Check overlap with data points
169+
for db in data_bboxes:
170+
if ann_bbox.overlaps(db):
171+
overlap += ann_bbox.intersection(db).area
172+
if overlap == 0:
173+
return ann # Found a good spot
174+
if overlap < min_overlap:
175+
min_overlap = overlap
176+
best_ann = ann
177+
else:
178+
ann.remove()
179+
return best_ann # Return the least-overlapping position if all overlap
180+
134181
# Main function to perform the fitting and plotting
135182
def perform_fitting(input_file_path, output_file_path, save_plots, display_plots, plots_dir):
136183
if not output_file_path.endswith(".txt"):
@@ -150,7 +197,7 @@ def perform_fitting(input_file_path, output_file_path, save_plots, display_plots
150197
Id_mean, Id_lower_bound, Id_upper_bound, Id_stdev = prediction_interval(retained_slopes, avg_slope)
151198
I0_mean, I0_lower_bound, I0_upper_bound, I0_stdev = prediction_interval(retained_intercepts, avg_intercept)
152199

153-
if len(retained_slopes) == 1:
200+
if len(retained_slopes == 1):
154201
Id_mean = retained_slopes[0]
155202
I0_mean = retained_intercepts[0]
156203
Id_lower_bound = "not applicable"
@@ -204,7 +251,10 @@ def scientific_notation(val, pos=0):
204251
handles.append(avg_fit_plot[0])
205252
labels.append(rf'Average Fit: $Y = {formatter(avg_slope)}X + {formatter(avg_intercept)}$')
206253

207-
ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(0, 1))
254+
ax.legend(handles, labels, loc='best')
255+
256+
param_text = f"Id: {Id_mean:.3e}\nI0: {I0_mean:.3e}"
257+
place_annotation_safely(ax, param_text)
208258

209259
if save_plots:
210260
# Create a unique plot filename based on the input file

interface_GDA_fitting.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import numpy as np
77
from scipy.optimize import brentq, minimize
88
import matplotlib.pyplot as plt
9+
from matplotlib.transforms import Bbox
910

1011
from pltstyle import create_plots
12+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
1113

1214
def format_value(value):
1315
return f"{value:.0f}" if value > 10 else f"{value:.2f}"
14-
16+
1517
# Function to run the fitting process
1618
def run_fitting(file_path, results_file_path, Kd_in_M, h0_in_M, g0_in_M, number_of_fit_trials, rmse_threshold_factor, r2_threshold, save_plots, display_plots, plots_dir, save_results, results_save_dir):
1719
try:
@@ -149,7 +151,7 @@ def run_fitting(file_path, results_file_path, Kd_in_M, h0_in_M, g0_in_M, number_
149151
ax.plot(d0_values, Signal_observed, 'o', label='Observed Signal')
150152
ax.plot(fitting_curve_x, fitting_curve_y, '--', color='blue', alpha=0.6, label='Simulated Fitting Curve')
151153
ax.set_title(f'Observed vs. Simulated Fitting Curve for Replica {replica_index}')
152-
ax.legend(loc='best', bbox_to_anchor=(0.02, 0.98))
154+
ax.legend(loc='best')
153155

154156
# Annotate plot with median parameter values and fit metrics
155157
param_text = (f"$K_g$: {median_params[1] * 1e6:.2e} $M^{{-1}}$\n"
@@ -159,8 +161,7 @@ def run_fitting(file_path, results_file_path, Kd_in_M, h0_in_M, g0_in_M, number_
159161
f"$RMSE$: {format_value(rmse)}\n"
160162
f"$R^2$: {r_squared:.3f}")
161163

162-
ax.annotate(param_text, xy=(0.8, 0.04), xycoords='axes fraction', fontsize=10,
163-
ha='left', va='bottom', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5))
164+
place_annotation_opposite_legend(ax, param_text)
164165

165166
if save_plots:
166167
plot_file = os.path.join(plots_dir, f"fit_plot_replica_{replica_index}.png")

interface_IDA_fitting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import matplotlib.pyplot as plt
77
from scipy.optimize import brentq, minimize
88
from pltstyle import create_plots
9-
9+
from matplotlib.transforms import Bbox
10+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
11+
1012
def run_ida_fitting(file_path, results_file_path, Kd_in_M, h0_in_M, g0_in_M, number_of_fit_trials, rmse_threshold_factor, r2_threshold, save_plots, display_plots, plots_dir, save_results, results_save_dir):
1113

1214

@@ -190,7 +192,7 @@ def split_replicas(data):
190192
ax.plot(g0_values, Signal_observed, 'o', label='Observed Signal')
191193
ax.plot(fitting_curve_x, fitting_curve_y, '--', color='blue', alpha=0.6, label='Simulated Fitting Curve')
192194
ax.set_title(f'Observed vs. Simulated Fitting Curve for Replica {replica_index}')
193-
ax.legend()
195+
ax.legend(loc='best')
194196

195197
param_text = (f"$K_g$: {median_params[1] * 1e6:.2e} $M^{{-1}}$\n"
196198
f"$I_0$: {median_params[0]:.2e}\n"
@@ -199,9 +201,7 @@ def split_replicas(data):
199201
f"$RMSE$: {rmse:.3f}\n"
200202
f"$R^2$: {r_squared:.3f}")
201203

202-
ax.annotate(param_text, xy=(0.97, 0.95), xycoords='axes fraction', fontsize=10,
203-
ha='right', va='top', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5),
204-
multialignment='left')
204+
place_annotation_opposite_legend(ax, param_text)
205205

206206
if save_plots:
207207
plot_file = os.path.join(plots_dir, f"fit_plot_replica_{replica_index}.png")

interface_dba_merge_fits.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import matplotlib.pyplot as plt
77
from datetime import datetime
88
from pltstyle import create_plots # Import the create_plots function
9+
from matplotlib.transforms import Bbox
10+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
911

1012

1113
def format_value(value):
@@ -306,10 +308,9 @@ def run_dba_merge_fits(results_dir, outlier_relative_threshold, rmse_threshold_f
306308
f"$I_{{hd}}$: {avg_params[3]:.2e} $M^{{-1}}$ (STDEV: {stdev_params[3]:.2e})\n"
307309
f"$RMSE$: {format_value(rmse)}\n"
308310
f"$R^2$: {r_squared:.3f}")
309-
ax2.annotate(param_text, xy=(0.95, 0.05), xycoords='axes fraction', fontsize=10,
310-
ha='right', va='bottom', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5), multialignment='left')
311+
place_annotation_opposite_legend(ax2, param_text)
311312

312-
ax2.legend()
313+
ax2.legend(loc='best')
313314
fig2.tight_layout()
314315
if save_plots:
315316
save_plot(fig2, "averaged_fitting_plot.png", results_dir)

interface_gda_merge_fits.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from scipy.optimize import brentq
88
from datetime import datetime
99
from pltstyle import create_plots # Import the create_plots function
10+
from matplotlib.transforms import Bbox
11+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
1012

1113
def format_value(value):
1214
return f"{value:.0f}" if value > 10 else f"{value:.2f}"
@@ -226,7 +228,7 @@ def run_gda_merge_fits(results_dir, outlier_relative_threshold, rmse_threshold_f
226228
if len(outlier_indices) > 0:
227229
ax1.plot(concentrations[outlier_indices], signals[outlier_indices], 'x', color=colors[idx], markersize=8, label=f"Replica {idx + 1} Outliers")
228230

229-
ax1.legend(loc='best', bbox_to_anchor=(0.02, 0.98))
231+
ax1.legend(loc='best')
230232
fig1.tight_layout()
231233
if save_plots:
232234
save_plot(fig1, "all_replicas_fitting_plot_with_outliers.png", results_dir)
@@ -310,10 +312,9 @@ def run_gda_merge_fits(results_dir, outlier_relative_threshold, rmse_threshold_f
310312
f"$I_{{hd}}$: {avg_params[3]:.2e} $M^{{-1}}$ (STDEV: {stdev_params[3]:.2e})\n"
311313
f"$RMSE$: {format_value(rmse)}\n"
312314
f"$R^2$: {r_squared:.3f}")
313-
ax2.annotate(param_text, xy=(0.97, 0.04), xycoords='axes fraction', fontsize=10,
314-
ha='right', va='bottom', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5), multialignment='left')
315+
place_annotation_opposite_legend(ax2, param_text)
315316

316-
ax2.legend(loc='best', bbox_to_anchor=(0.02, 0.98))
317+
ax2.legend(loc='best')
317318
fig2.tight_layout()
318319
if save_plots:
319320
save_plot(fig2, "averaged_fitting_plot.png", results_dir)

interface_ida_merge_fits.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from datetime import datetime
88
from scipy.optimize import brentq
99
from pltstyle import create_plots
10+
from matplotlib.transforms import Bbox
11+
from plot_replica import place_annotation_safely, place_annotation_opposite_legend
1012

1113
def format_value(value):
1214
return f"{value:.0f}" if value > 10 else f"{value:.2f}"
@@ -329,11 +331,9 @@ def export_averaged_data(avg_concentrations, avg_signals, avg_fitting_curve_x, a
329331
f"$I_{{hd}}$: {avg_params[3]:.2e} $M^{{-1}}$ (STDEV: {stdev_params[3]:.2e})\n"
330332
f"$RMSE$: {format_value(rmse)}\n"
331333
f"$R^2$: {r_squared:.3f}")
332-
ax2.annotate(param_text, xy=(0.97, 0.95), xycoords='axes fraction', fontsize=10,
333-
ha='right', va='top', bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="lightgrey", alpha=0.5),
334-
multialignment='left')
334+
place_annotation_opposite_legend(ax2, param_text)
335335

336-
ax2.legend()
336+
ax2.legend(loc='best')
337337
fig2.tight_layout()
338338
if save_plots:
339339
save_plot(fig2, "averaged_fitting_plot.png", results_dir)

plot_replica.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pltstyle
66
import tkinter as tk
77
import re
8+
import numpy as np
9+
from matplotlib.transforms import Bbox
810

911
from pathlib import Path
1012
from tkinter import filedialog, messagebox
@@ -13,6 +15,80 @@
1315
from bmg_to_txt import read_bmg_xlsx, extract_concentration_vector
1416

1517

18+
def place_annotation_safely(ax, text, data_x, data_y, line2d=None, marker_size=12, margin=8, **annotate_kwargs):
19+
"""
20+
Place annotation box in a location that avoids overlap with legend, data points, and optionally a line.
21+
- data_x, data_y: arrays of data points (for scatter/markers)
22+
- line2d: optional, a matplotlib Line2D object to check for overlap with the line itself
23+
- marker_size: pixel size of marker to avoid (default 12)
24+
- margin: extra pixels to add around annotation box (default 8)
25+
"""
26+
renderer = ax.figure.canvas.get_renderer()
27+
corners = [
28+
(0.01, 0.99, {'xycoords': 'axes fraction', 'va': 'top', 'ha': 'left'}),
29+
(0.99, 0.99, {'xycoords': 'axes fraction', 'va': 'top', 'ha': 'right'}),
30+
(0.01, 0.01, {'xycoords': 'axes fraction', 'va': 'bottom', 'ha': 'left'}),
31+
(0.99, 0.01, {'xycoords': 'axes fraction', 'va': 'bottom', 'ha': 'right'}),
32+
]
33+
legend = ax.get_legend()
34+
legend_bbox = legend.get_window_extent(renderer) if legend else None
35+
# Data points as bboxes (expanded for marker size)
36+
data_disp = ax.transData.transform(np.column_stack([data_x, data_y]))
37+
data_bboxes = [Bbox.from_bounds(x-marker_size, y-marker_size, 2*marker_size, 2*marker_size) for x, y in data_disp]
38+
# If a line is provided, rasterize it to points and add bboxes
39+
line_bboxes = []
40+
if line2d is not None:
41+
line_x, line_y = line2d.get_data()
42+
line_disp = ax.transData.transform(np.column_stack([line_x, line_y]))
43+
for x, y in line_disp:
44+
line_bboxes.append(Bbox.from_bounds(x-2, y-2, 4, 4))
45+
best = None
46+
min_overlap = float('inf')
47+
for x, y, opts in corners:
48+
ann = ax.annotate(text, (x, y), bbox=dict(boxstyle='round', fc='w', ec='k', alpha=0.8), annotation_clip=False, **opts, **annotate_kwargs)
49+
ax.figure.canvas.draw()
50+
ann_bbox = ann.get_window_extent(renderer).expanded(1.05, 1.1).padded(margin)
51+
overlap = 0
52+
if legend_bbox and ann_bbox.overlaps(legend_bbox):
53+
overlap += 1e6
54+
for db in data_bboxes:
55+
if ann_bbox.overlaps(db):
56+
overlap += 1
57+
for lb in line_bboxes:
58+
if ann_bbox.overlaps(lb):
59+
overlap += 1
60+
if overlap == 0:
61+
return ann
62+
if overlap < min_overlap:
63+
min_overlap = overlap
64+
best = ann
65+
ann.remove()
66+
return best # fallback: least overlap
67+
68+
69+
def place_annotation_opposite_legend(ax, text, offset_frac=0.03, **annotate_kwargs):
70+
"""
71+
Place annotation box in the corner diagonally opposite to the legend, with a margin from axes.
72+
offset_frac: fraction of axes width/height to offset from the edge (default 0.03 = 3%)
73+
"""
74+
legend = ax.get_legend()
75+
loc = getattr(legend, '_loc', 'upper right') if legend else 'upper right'
76+
loc_map = {
77+
'upper right': (offset_frac, offset_frac, {'xycoords': 'axes fraction', 'va': 'bottom', 'ha': 'left'}),
78+
'upper left': (1-offset_frac, offset_frac, {'xycoords': 'axes fraction', 'va': 'bottom', 'ha': 'right'}),
79+
'lower left': (1-offset_frac, 1-offset_frac, {'xycoords': 'axes fraction', 'va': 'top', 'ha': 'right'}),
80+
'lower right': (offset_frac, 1-offset_frac, {'xycoords': 'axes fraction', 'va': 'top', 'ha': 'left'}),
81+
'best': (offset_frac, offset_frac, {'xycoords': 'axes fraction', 'va': 'bottom', 'ha': 'left'}),
82+
}
83+
x, y, opts = loc_map.get(loc, loc_map['upper right'])
84+
return ax.annotate(
85+
text, (x, y),
86+
bbox=dict(boxstyle='round', fc='w', ec='k', alpha=0.8),
87+
annotation_clip=False,
88+
**opts, **annotate_kwargs
89+
)
90+
91+
1692
def plot_all_replica(raw_data_path : str, robot_file_path : str, save_dir : str):
1793
"""this is for ONE excel file"""
1894
raw_data_path, robot_file_path, save_dir = map(Path, (raw_data_path, robot_file_path, save_dir))
@@ -25,7 +101,7 @@ def plot_all_replica(raw_data_path : str, robot_file_path : str, save_dir : str)
25101
fig, ax = pltstyle.create_plots(plot_title=plot_title)
26102
# ax.boxplot(data, tick_labels=concentration_vector, patch_artist=True)
27103
ax.plot(concentration_vector, data.values.T, label=[f'Replica {i+1}' for i in range(data.shape[0])])#
28-
ax.legend()
104+
ax.legend(loc='best')
29105

30106
# save plot of different analytes to separate folders
31107
save_dir_analyte = (save_dir / ' '.join(raw_data_path.stem.split('_')[1:4]))

0 commit comments

Comments
 (0)