Skip to content

Commit

Permalink
create new classification tests to cover adding additional classes
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Mar 7, 2025
1 parent 8363461 commit d85ac0a
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 113 deletions.
2 changes: 1 addition & 1 deletion USGS_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main(cfg: DictConfig):
)

comet_id = comet_logger.experiment.id
trained_model.trainer.save_checkpoint("{comet_id}.ckpt")
trained_model.trainer.save_checkpoint(f"{comet_id}.ckpt")

if __name__ == "__main__":
main()
Expand Down
3 changes: 1 addition & 2 deletions conf/classification_model/USGS.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
checkpoint:
checkpoint:
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/checkpoints/
train_csv_folder:
train_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops/
train_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/train
val_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/val
Expand Down
3 changes: 1 addition & 2 deletions conf/classification_model/finetune.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
checkpoint:
checkpoint: c1d69941aa5a4308bf1a75fda12ac6ec
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/classification/checkpoints
train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/train
train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
train_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/train
val_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/val
Expand Down
5 changes: 1 addition & 4 deletions conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ predict:
pipeline:
confidence_threshold: 0.9
limit_empty_frac: 0.01
gpus: 2
gpus: 1

propagate:
time_threshold_seconds: 5
Expand All @@ -64,12 +64,9 @@ detection_model:
val_accuracy_interval: 5

pipeline_evaluation:
detect_ground_truth_dir: /blue/ewhite/b.weinstein/BOEM/annotations/validation
classify_ground_truth_dir: /blue/ewhite/b.weinstein/BOEM/annotations/validation
# This is an average mAP threshold for now, but we may want to add a per-IoU threshold in the future
detection_true_positive_threshold: 0.8
classification_avg_score: 0.5
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated

active_learning:
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
Expand Down
3 changes: 2 additions & 1 deletion src/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def generate_pool_predictions(image_dir, patch_size=512, patch_overlap=0.1, min_
pass

preannotations = detection.predict(m=model, model_path=model_path, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap, batch_size=batch_size, crop_model=crop_model)

preannotations = pd.concat(preannotations)

# Print the number of preannotations before removing min score
preannotations = preannotations[preannotations["score"] >= min_score]

Expand Down
164 changes: 96 additions & 68 deletions src/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@
import torch
from torch.nn import functional as F

# Local imports
from src.label_studio import gather_data

def create_train_test(annotations):
return annotations.sample(frac=0.8, random_state=1), annotations.drop(
annotations.sample(frac=0.8, random_state=1).index)

def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001, num_classes=None):
#Get model with latest checkpoint dir, if none exist make a new model
if os.path.exists(checkpoint_dir):
Expand Down Expand Up @@ -44,26 +37,36 @@ def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001, num_classes=No

return m

def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001, num_classes=None):
def load(checkpoint=None, annotations=None, checkpoint_dir=None, num_classes=None):
if checkpoint:
if num_classes:
loaded_model = CropModel(checkpoint, num_classes=num_classes, lr=lr)
else:
loaded_model = CropModel(checkpoint, num_classes=len(annotations["label"].unique()), lr=lr)
loaded_model = CropModel.load_from_checkpoint(checkpoint)
num_classes = loaded_model.num_classes
elif checkpoint_dir:
loaded_model = get_latest_checkpoint(
checkpoint_dir,
num_classes=num_classes,
annotations=annotations)
loaded_model = get_latest_checkpoint(
checkpoint_dir,
num_classes=num_classes,
annotations=annotations)
num_classes = loaded_model.num_classes
else:
raise ValueError("No checkpoint or checkpoint directory found.")
loaded_model = CropModel(num_classes=num_classes)

if annotations is not None:
if annotations.label.unique().tolist() != num_classes:
finetune_model = CropModel(num_classes=len(annotations.label.unique()))

# Strip the last layer off the checkpoint model and replace it with a new layer
num_ftrs = loaded_model.model.fc.in_features
loaded_model.model.fc = torch.nn.Linear(num_ftrs, len(annotations["label"].unique()))
finetune_model.model = loaded_model.model
loaded_model = finetune_model

return loaded_model

def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_epochs=10, batch_size=4, workers=0):
def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_epochs=10, batch_size=4, workers=0, lr=0.0001):
"""Train a model on labeled images.
Args:
model (CropModel): A CropModel object.
lr: Learning rate for training.
train_dir (str): The directory containing the training images.
val_dir (str): The directory containing the validation images.
fast_dev_run (bool): Whether to run a fast development run.
Expand All @@ -76,6 +79,7 @@ def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_
"""
model.batch_size = batch_size
model.num_workers = workers
model.lr = lr

devices = torch.cuda.device_count()
model.create_trainer(logger=comet_logger, fast_dev_run=fast_dev_run, max_epochs=max_epochs, num_nodes=1, devices = devices)
Expand All @@ -95,7 +99,8 @@ def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_
label_count[label_name] = 0
if label_count[label_name] < 10:
image_name = os.path.basename(image_path)
comet_logger.experiment.log_image(image_path, name=f"{label_name}_{image_name}")
if comet_logger:
comet_logger.experiment.log_image(image_path, name=f"{label_name}_{image_name}")
label_count[label_name] += 1

model.trainer.fit(model)
Expand All @@ -115,18 +120,18 @@ def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_
y_true.extend(labels.cpu().numpy())
y_predicted.extend(preds.cpu().numpy())

for image in images:
image_dataset.append(Image.fromarray(image.permute(1, 2, 0).cpu().numpy().astype('uint8')))

labels = model.val_ds.classes
images = [model.val_ds.imgs[i][0] for i in range(len(model.val_ds.imgs))]
image_dataset = [Image.open(image) for image in images]

# Log the confusion matrix to Comet
comet_logger.experiment.log_confusion_matrix(
y_true=y_true,
y_predicted=y_predicted,
images=image_dataset,
labels=labels,
)
if comet_logger:
comet_logger.experiment.log_confusion_matrix(
y_true=y_true,
y_predicted=y_predicted,
images=image_dataset,
labels=labels,
)

return model

Expand All @@ -137,7 +142,7 @@ def preprocess_images(model, annotations, root_dir, save_dir):
# Remove any negative values
annotations = annotations[(annotations['xmin'] >= 0) & (annotations['ymin'] >= 0) & (annotations['xmax'] >= 0) & (annotations['ymax'] >= 0)]
boxes = annotations[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()

# Expand by 20 pixels on all sides
boxes = [[box[0]-20, box[1]-20, box[2]+20, box[3]+20] for box in boxes]

Expand All @@ -149,66 +154,89 @@ def preprocess_images(model, annotations, root_dir, save_dir):

model.write_crops(boxes=boxes, root_dir=root_dir, images=images, labels=labels, savedir=save_dir)

def preprocess_and_train_classification(config, train_df=None, validation_df=None, comet_logger=None):
def preprocess_and_train(
train_df,
validation_df,
checkpoint,
checkpoint_dir,
train_image_dir,
train_crop_image_dir,
val_crop_image_dir,
lr=0.0001,
batch_size=4,
fast_dev_run=False,
max_epochs=10,
workers=0,
comet_logger=None
):
"""Preprocess data and train a crop model.
Args:
config: Configuration object containing training parameters
train_df (pd.DataFrame): A DataFrame containing training annotations.
validation_df (pd.DataFrame): A DataFrame containing validation annotations.
comet_logger: CometLogger object for logging experiments
checkpoint (str): Path to the checkpoint file.
checkpoint_dir (str): Directory containing checkpoint files.
train_image_dir (str): Directory containing images to be cropped.
train_crop_image_dir (str): Directory to save cropped training images.
val_crop_image_dir (str): Directory to save cropped validation images.
lr (float): Learning rate for training.
batch_size (int): Batch size for training.
fast_dev_run (bool): Whether to run a fast development run.
max_epochs (int): Maximum number of epochs for training.
workers (int): Number of workers for data loading.
train_df (pd.DataFrame): DataFrame containing training annotations.
validation_df (pd.DataFrame): DataFrame containing validation annotations.
comet_logger: CometLogger object for logging experiments.
Returns:
trained_model: Trained model object
trained_model: Trained model object.
"""
# Get and split annotations
if train_df is None:
annotations = gather_data(config.classification_model.train_csv_folder)
else:
annotations = train_df

num_classes = len(annotations["label"].unique())

# Remove the empty frames
annotations = annotations[~(annotations.label.astype(str)== "0")]
annotations = annotations[annotations.label != "FalsePositive"]

if validation_df is None:
train_df, validation_df = create_train_test(annotations)
else:
train_df = annotations[~annotations["image_path"].
isin(validation_df["image_path"])]
num_classes = len(train_df["label"].unique())

if train_df.empty:
# Load existing model
loaded_model = load(
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
num_classes=num_classes,
)
return loaded_model

# Load existing model
loaded_model = load(
checkpoint=config.classification_model.checkpoint,
checkpoint_dir=config.classification_model.checkpoint_dir,
annotations=annotations,
lr=config.classification_model.trainer.lr,
num_classes=num_classes
)
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
annotations=train_df,
num_classes=num_classes,
)

if train_df.empty:
print("No annotations found.")
return loaded_model

# Preprocess train and validation data
preprocess_images(
model=loaded_model,
annotations=train_df,
root_dir=config.classification_model.train_image_dir,
save_dir=config.classification_model.train_crop_image_dir)
root_dir=train_image_dir,
save_dir=train_crop_image_dir
)

preprocess_images(
model=loaded_model,
annotations=validation_df,
root_dir=config.classification_model.train_image_dir,
save_dir=config.classification_model.val_crop_image_dir)
root_dir=train_image_dir,
save_dir=val_crop_image_dir
)

trained_model = train(
batch_size=config.classification_model.trainer.batch_size,
train_dir=config.classification_model.train_crop_image_dir,
val_dir=config.classification_model.val_crop_image_dir,
batch_size=batch_size,
train_dir=train_crop_image_dir,
val_dir=val_crop_image_dir,
model=loaded_model,
fast_dev_run=config.classification_model.trainer.fast_dev_run,
max_epochs=config.classification_model.trainer.max_epochs,
fast_dev_run=fast_dev_run,
max_epochs=max_epochs,
comet_logger=comet_logger,
workers=config.classification_model.trainer.workers
)
workers=workers,
)

return trained_model
Loading

0 comments on commit d85ac0a

Please sign in to comment.