Skip to content

Commit 9ec3441

Browse files
authored
Merge pull request #33 from LorenzLamm/remove_prints
Remove prints
2 parents 4a5be57 + 1995457 commit 9ec3441

8 files changed

Lines changed: 30 additions & 8 deletions

File tree

src/membrain_pick/clustering/mean_shift_inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def mean_shift_for_scores(
4545
)
4646
else:
4747
raise ValueError("Unknown method for mean shift clustering.")
48-
print("Found", out_pos.shape[0], "clusters.")
4948
return out_pos, out_p_num
5049

5150

src/membrain_pick/dataloading/mesh_partitioning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def load_from_cache(cur_cache_path: str) -> Optional[Dict[str, np.ndarray]]:
508508
The loaded partitioning data if successful, None otherwise.
509509
"""
510510
if os.path.isfile(cur_cache_path):
511-
print(f"Loading partitioning data from {cur_cache_path}")
511+
# print(f"Loading partitioning data from {cur_cache_path}")
512512
return np.load(cur_cache_path)
513513
else:
514514
return None

src/membrain_pick/mesh_projections/mesh_conversion_wrappers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ def mesh_for_tomo_mb_folder(
154154
mb_files = [
155155
os.path.join(mb_folder, f) for f in os.listdir(mb_folder) if f.endswith(".mrc")
156156
]
157-
print(mb_files)
158157

159158
if tomo is None:
160159
tomo = load_tomogram(tomo_file)

src/membrain_pick/napari_utils/surforama_cli_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def normalize_surface_values(surface_values, value_range=None):
213213
np.percentile(surface_values, cutoff_pct * 100),
214214
np.percentile(surface_values, (1 - cutoff_pct) * 100),
215215
)
216-
print("Normalized value range: ", value_range)
217216
normalized_values = (surface_values - value_range[0]) / (
218217
value_range[1] - value_range[0] + np.finfo(float).eps
219218
)

src/membrain_pick/networks/diffusion_net/geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def get_operators(verts, faces, k_eig=128, op_cache_dir=None, normals=None, over
471471
# If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
472472
# entry below more eigenvalues
473473
if overwrite_cache:
474-
print(" overwriting cache by request")
474+
# print(" overwriting cache by request")
475475
os.remove(search_path)
476476
break
477477

@@ -516,7 +516,7 @@ def read_sp_mat(prefix):
516516
break
517517

518518
except FileNotFoundError:
519-
print(" cache miss -- constructing operators")
519+
# print(" cache miss -- constructing operators")
520520
break
521521

522522
except Exception as E:

src/membrain_pick/optimization/diffusion_training_pylit.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytorch_lightning as pl
33
import torch
44
from torch.optim import Adam, SGD
5+
import matplotlib.pyplot as plt
56

67
from membrain_pick.networks.diffusion_net import DiffusionNet
78
from membrain_pick.clustering.mean_shift_utils import MeanShiftForwarder
@@ -30,11 +31,17 @@ def __init__(self,
3031
fixed_time=None,
3132
one_D_conv_first=False,
3233
clamp_diffusion=False,
34+
out_plot_file=None,
3335
visualize_diffusion=False,
3436
visualize_grad_rotations=False,
3537
visualize_grad_features=False):
3638
super().__init__()
3739
self.max_epochs = max_epochs
40+
self.epoch_losses = {
41+
"train": [],
42+
"val": []
43+
}
44+
self.out_plot_file = out_plot_file
3845
# Initialize the DiffusionNet with the given arguments.
3946
self.model = DiffusionNet(C_in=C_in,
4047
C_out=C_out,
@@ -132,7 +139,7 @@ def training_step(self, batch, batch_idx):
132139
# Log training loss
133140
self.total_train_loss += loss.detach()
134141
self.train_batches += 1
135-
print(f"Training loss: {loss}")
142+
# print(f"Training loss: {loss}")
136143
return loss
137144

138145
def validation_step(self, batch, batch_idx):
@@ -156,12 +163,25 @@ def validation_step(self, batch, batch_idx):
156163
def on_train_epoch_end(self):
157164
# Log the average training loss
158165
avg_train_loss = self.total_train_loss / self.train_batches
166+
print("Train epoch loss: ", avg_train_loss)
167+
self.epoch_losses["train"].append(avg_train_loss)
159168
self.log('train_loss', avg_train_loss)
160169

161170
def on_validation_epoch_end(self):
162171
# Log the average validation loss
163172
avg_val_loss = self.total_val_loss / self.val_batches
173+
print("Validation epoch loss: ", avg_val_loss)
174+
self.epoch_losses["val"].append(avg_val_loss)
164175
self.log('val_loss', avg_val_loss)
176+
self.plot_losses()
177+
178+
179+
def plot_losses(self):
180+
plt.figure()
181+
plt.plot(self.epoch_losses["train"], label="Train loss")
182+
plt.plot(self.epoch_losses["val"], label="Validation loss")
183+
plt.legend()
184+
plt.savefig(self.out_plot_file)
165185

166186

167187
def unpack_batch(batch):

src/membrain_pick/orientation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
import numpy as np
34
import scipy.spatial as spatial
@@ -14,6 +15,7 @@
1415
from membrain_seg.segmentation.dataloading.data_utils import load_tomogram
1516

1617

18+
1719
def orientation_from_mesh(coordinates, mesh):
1820
"""
1921
Get the orientation of a point cloud from a mesh.
@@ -39,7 +41,7 @@ def orientation_from_mesh(coordinates, mesh):
3941
distances, vertex_indices = tree.query(coordinates)
4042

4143
if np.any(distances > 200):
42-
print(
44+
logging.warning(
4345
"Warning: Some points are more than 200 units away from the mesh. This might be an error. Check rescaling factors."
4446
)
4547

src/membrain_pick/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def train(
4848
val_path = os.path.join(data_dir, "val")
4949
cache_dir_mb = os.path.join(training_dir, "mesh_cache")
5050
log_dir = os.path.join(training_dir, "logs")
51+
out_plot_file = os.path.join(training_dir, "plots", f"training_curves_{project_name}_{sub_name}.png")
52+
os.makedirs(os.path.join(training_dir, "plots"), exist_ok=True)
5153

5254
# Create the data module
5355
data_module = MemSegDiffusionNetDataModule(
@@ -85,6 +87,7 @@ def train(
8587
device=device,
8688
one_D_conv_first=one_D_conv_first,
8789
max_epochs=max_epochs,
90+
out_plot_file=out_plot_file,
8891
)
8992

9093
checkpointing_name = project_name + "_" + sub_name

0 commit comments

Comments (0)