22import pytorch_lightning as pl
33import torch
44from torch .optim import Adam , SGD
5+ import matplotlib .pyplot as plt
56
67from membrain_pick .networks .diffusion_net import DiffusionNet
78from membrain_pick .clustering .mean_shift_utils import MeanShiftForwarder
@@ -30,11 +31,17 @@ def __init__(self,
3031 fixed_time = None ,
3132 one_D_conv_first = False ,
3233 clamp_diffusion = False ,
34+ out_plot_file = None ,
3335 visualize_diffusion = False ,
3436 visualize_grad_rotations = False ,
3537 visualize_grad_features = False ):
3638 super ().__init__ ()
3739 self .max_epochs = max_epochs
40+ self .epoch_losses = {
41+ "train" : [],
42+ "val" : []
43+ }
44+ self .out_plot_file = out_plot_file
3845 # Initialize the DiffusionNet with the given arguments.
3946 self .model = DiffusionNet (C_in = C_in ,
4047 C_out = C_out ,
@@ -132,7 +139,7 @@ def training_step(self, batch, batch_idx):
132139 # Log training loss
133140 self .total_train_loss += loss .detach ()
134141 self .train_batches += 1
135- print (f"Training loss: { loss } " )
142+ # print(f"Training loss: {loss}")
136143 return loss
137144
138145 def validation_step (self , batch , batch_idx ):
@@ -156,12 +163,25 @@ def validation_step(self, batch, batch_idx):
156163 def on_train_epoch_end (self ):
157164 # Log the average training loss
158165 avg_train_loss = self .total_train_loss / self .train_batches
166+ print ("Train epoch loss: " , avg_train_loss )
167+ self .epoch_losses ["train" ].append (avg_train_loss )
159168 self .log ('train_loss' , avg_train_loss )
160169
161170 def on_validation_epoch_end (self ):
162171 # Log the average validation loss
163172 avg_val_loss = self .total_val_loss / self .val_batches
173+ print ("Validation epoch loss: " , avg_val_loss )
174+ self .epoch_losses ["val" ].append (avg_val_loss )
164175 self .log ('val_loss' , avg_val_loss )
176+ self .plot_losses ()
177+
178+
179+ def plot_losses (self ):
180+ plt .figure ()
181+ plt .plot (self .epoch_losses ["train" ], label = "Train loss" )
182+ plt .plot (self .epoch_losses ["val" ], label = "Validation loss" )
183+ plt .legend ()
184+ plt .savefig (self .out_plot_file )
165185
166186
167187def unpack_batch (batch ):
0 commit comments