diff --git a/models/UNetModule.py b/models/UNetModule.py
index 6e5e200..a7d0b18 100644
--- a/models/UNetModule.py
+++ b/models/UNetModule.py
@@ -222,8 +222,25 @@ def AUC_metric(self, y_true, y_pred):
         AUC = self.AUC_metrics.result()
 
         return AUC
+
+    class SaveEveryNEpochs(tf.keras.callbacks.Callback):
+        """Keras callback that saves the full model every `save_every` epochs."""
+
+        def __init__(self, save_every=15):
+            super().__init__()
+            self.save_every = save_every
+
+        def on_epoch_end(self, epoch, logs=None):
+            # `epoch` is 0-based, so `epoch + 1` is the number of completed epochs.
+            if (epoch + 1) % self.save_every == 0:  # Save at the end of every `save_every` epochs
+                filepath = f"unet_model_epoch_{epoch + 1:02d}.h5"
+                self.model.save(filepath)
+                print(f"Checkpoint saved at epoch {epoch + 1}.")
 
     def train(self, train_dataset, val_dataset, epochs=20, batch_size=64, buffer_size=1000, val_subsplits=1, lr=0.001):
+        # The callback class is nested in this class, so it is looked up through `self`.
+        checkpoint_callback = self.SaveEveryNEpochs(save_every=15)
+
         self.model.compile(optimizer='adam',
                            # loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                            loss=self.dice_loss,
@@ -249,9 +266,10 @@ def train(self, train_dataset, val_dataset, epochs=20, batch_size=64, buffer_siz
 
         val_batches = val_images.batch(batch_size)
 
-        model_history = self.model.fit(train_batches,
-                                       epochs=epochs,
-                                       validation_data=val_batches)
+        model_history = self.model.fit(train_batches,
+                                       epochs=epochs,
+                                       validation_data=val_batches,
+                                       callbacks=[checkpoint_callback])
 
         return model_history
 
diff --git a/models/test_segmentation.py b/models/test_segmentation.py
new file mode 100644
index 0000000..9de93d4
--- /dev/null
+++ b/models/test_segmentation.py
@@ -0,0 +1,83 @@
+import argparse
+import json
+import tensorflow as tf
+from models.UNetModule import UNet
+from src.data_loader import create_dataset
+
+def evaluate_segmentation(model, test_dataset):
+    """Evaluate the segmentation model on the test dataset.
+
+    Args:
+        model: Trained segmentation model, called as `model(images, training=False)`.
+        test_dataset: Batched dataset yielding `(images, labels)` pairs.
+
+    Returns:
+        metrics: A dictionary with "average_loss" and "accuracy" as plain
+            Python floats, so it can be passed straight to `json.dump`.
+
+    Raises:
+        ValueError: If `test_dataset` yields no samples.
+    """
+    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    metric_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
+
+    total_loss = 0.0
+    total_samples = 0
+
+    for images, labels in test_dataset:
+        predictions = model(images, training=False)
+        loss = loss_object(labels, predictions)
+        # Weight each batch's loss by its size so a short final batch is not over-counted.
+        batch_size = len(images)
+        total_loss += float(loss.numpy()) * batch_size
+        total_samples += batch_size
+        metric_accuracy.update_state(labels, predictions)
+
+    if total_samples == 0:
+        raise ValueError("test_dataset yielded no samples; cannot compute metrics.")
+
+    # NumPy float32 scalars are not JSON serializable, so convert to built-in floats.
+    avg_loss = total_loss / total_samples
+    accuracy = float(metric_accuracy.result().numpy())
+
+    metrics = {
+        "average_loss": avg_loss,
+        "accuracy": accuracy,
+    }
+    return metrics
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate a segmentation model on a test dataset.")
+    parser.add_argument("-model_type", "--model_type", type=str, default="unet", help="Type of the model (e.g., 'unet').")
+    parser.add_argument("-model_path", "--model_path", type=str, required=True, help="Path to the trained model weights.")
+    parser.add_argument("-data_path", "--data_path", type=str, required=True, help="Path to the test dataset.")
+    parser.add_argument("-output_path", "--output_path", type=str, default=None, help="Path to save evaluation results (optional).")
+    parser.add_argument("-batch_size", "--batch_size", type=int, default=16, help="Batch size for testing.")
+    args = parser.parse_args()
+
+    # Load the testing dataset
+    test_dataset = create_dataset(args.data_path)
+    test_dataset = test_dataset.batch(args.batch_size)
+
+    # Load the model
+    if args.model_type.lower() == "unet":
+        # Assuming input_channels and num_classes can be inferred
+        model = UNet(input_channels=16, num_classes=2)  # Modify these as per the specific model requirements
+    else:
+        raise ValueError(f"Unsupported model type: {args.model_type}")
+
+    model.load_weights(args.model_path)
+
+    # Evaluate the model
+    metrics = evaluate_segmentation(model, test_dataset)
+
+    # Print metrics to the console
+    print("Evaluation Results:")
+    for key, value in metrics.items():
+        print(f"{key}: {value}")
+
+    # Optionally save metrics to a file
+    if args.output_path:
+        with open(args.output_path, "w") as f:
+            json.dump(metrics, f, indent=4)
+        print(f"Results saved to {args.output_path}")
diff --git a/models/train_model.py b/models/train_model.py
index ab5a293..5a736cd 100644
--- a/models/train_model.py
+++ b/models/train_model.py
@@ -30,10 +30,11 @@ def train_model(data_dir: str = './data/raw_data/STARCOP_train_easy', max_boxes=
 
     # callbacks
     checkpoint = ModelCheckpoint(
-        filepath='best_model.keras',
-        save_best_only=True,
+        filepath='bounding_box_model_epoch_{epoch:02d}.keras',
+        save_best_only=False,
         monitor='val_loss',
-        mode='min'
+        mode='min',
+        save_freq='epoch'  # Save model after each epoch
     )
 
     lr_schedule = LearningRateScheduler(lambda epoch: 1e-3 * 0.95 ** epoch)  # adjusts the learning rate for each epoch
@@ -45,11 +46,9 @@ def train_model(data_dir: str = './data/raw_data/STARCOP_train_easy', max_boxes=
 
     model.fit(
         train_dataset,
-        epochs=1,
+        epochs=10,  # Increased epochs to see model saving after each epoch
         steps_per_epoch=50,
-        # validation_data=val_dataset,
-        # validation_steps=17,
-        callbacks=[checkpoint, lr_schedule]
+        callbacks=[checkpoint, lr_schedule, tensorboard]
     )
 
     one_batch = dataset.skip(50).take(1)