Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions models/UNetModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,21 @@ def AUC_metric(self, y_true, y_pred):
AUC = self.AUC_metrics.result()

return AUC

class SaveEveryNEpochs(tf.keras.callbacks.Callback):
def __init__(self, save_every=15):
super(SaveEveryNEpochs, self).__init__()
self.save_every = save_every

def on_epoch_end(self, epoch, logs=None):
if (epoch + 1) % self.save_every == 0: # Save at the end of every 15 epochs
filepath = f"unet_model_epoch_{epoch + 1:02d}.h5"
self.model.save(filepath)
print(f"Checkpoint saved at epoch {epoch + 1}.")

def train(self, train_dataset, val_dataset, epochs=20, batch_size=64, buffer_size=1000, val_subsplits=1, lr=0.001):
checkpoint_callback = SaveEveryNEpochs(save_every=15)

self.model.compile(optimizer='adam',
# loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
loss=self.dice_loss,
Expand All @@ -249,9 +262,10 @@ def train(self, train_dataset, val_dataset, epochs=20, batch_size=64, buffer_siz

val_batches = val_images.batch(batch_size)

model_history = self.model.fit(train_batches,
epochs=epochs,
validation_data=val_batches)
model_history = self.model.fit(train_batches,
epochs=epochs,
validation_data=val_batches,
callbacks=[checkpoint_callback])

return model_history

Expand Down
73 changes: 73 additions & 0 deletions models/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import json
import tensorflow as tf
from models.UNetModule import UNet
from src.data_loader import create_dataset

def evaluate_segmentation(model, test_dataset):
"""Evaluate the segmentation model on the test dataset.

Args:
model: Trained segmentation model.
test_dataset: Dataset object for testing.

Returns:
metrics: A dictionary containing evaluation metrics and loss.
"""
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

total_loss = 0
total_samples = 0

for images, labels in test_dataset:
predictions = model(images, training=False)
loss = loss_object(labels, predictions)
total_loss += loss.numpy() * len(images)
total_samples += len(images)
metric_accuracy.update_state(labels, predictions)

avg_loss = total_loss / total_samples
accuracy = metric_accuracy.result().numpy()

metrics = {
"average_loss": avg_loss,
"accuracy": accuracy,
}
return metrics

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate a segmentation model on a test dataset.")
parser.add_argument("-model_type", "--model_type", type=str, default="unet", help="Type of the model (e.g., 'unet').")
parser.add_argument("-model_path", "--model_path", type=str, required=True, help="Path to the trained model weights.")
parser.add_argument("-data_path", "--data_path", type=str, required=True, help="Path to the test dataset.")
parser.add_argument("-output_path", "--output_path", type=str, default=None, help="Path to save evaluation results (optional).")
parser.add_argument("-batch_size", "--batch_size", type=int, default=16, help="Batch size for testing.")
args = parser.parse_args()

# Load the testing dataset
test_dataset = create_dataset(args.data_path)
test_dataset = test_dataset.batch(args.batch_size)

# Load the model
if args.model_type.lower() == "unet":
# Assuming input_channels and num_classes can be inferred
model = UNet(input_channels=16, num_classes=2) # Modify these as per the specific model requirements
else:
raise ValueError(f"Unsupported model type: {args.model_type}")

model.load_weights(args.model_path)

# Evaluate the model
metrics = evaluate_segmentation(model, test_dataset)

# Print metrics to the console
print("Evaluation Results:")
for key, value in metrics.items():
print(f"{key}: {value}")

# Optionally save metrics to a file
if args.output_path:
with open(args.output_path, "w") as f:
json.dump(metrics, f, indent=4)
print(f"Results saved to {args.output_path}")
13 changes: 6 additions & 7 deletions models/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ def train_model(data_dir: str = './data/raw_data/STARCOP_train_easy', max_boxes=

# callbacks
checkpoint = ModelCheckpoint(
filepath='best_model.keras',
save_best_only=True,
filepath='bounding_box_model_epoch_{epoch:02d}.keras',
save_best_only=False,
monitor='val_loss',
mode='min'
mode='min',
save_freq='epoch' # Save model after each epoch
)

lr_schedule = LearningRateScheduler(lambda epoch: 1e-3 * 0.95 ** epoch) # adjusts the learning rate for each epoch
Expand All @@ -45,11 +46,9 @@ def train_model(data_dir: str = './data/raw_data/STARCOP_train_easy', max_boxes=

model.fit(
train_dataset,
epochs=1,
epochs=10, # Increased epochs to see model saving after each epoch
steps_per_epoch=50,
# validation_data=val_dataset,
# validation_steps=17,
callbacks=[checkpoint, lr_schedule]
callbacks=[checkpoint, lr_schedule, tensorboard]
)

one_batch = dataset.skip(50).take(1)
Expand Down