refactor pipeline to reduce prediction code
bw4sz committed Mar 6, 2025
1 parent 934a58c commit 8363461
Showing 12 changed files with 156 additions and 326 deletions.
5 changes: 4 additions & 1 deletion USGS_classification.py
@@ -50,13 +50,16 @@ def main(cfg: DictConfig):
train_df, validation_df = train_test_split(crop_annotations)

comet_logger = CometLogger(project_name=cfg.comet.project, workspace=cfg.comet.workspace)
preprocess_and_train_classification(
trained_model = preprocess_and_train_classification(
config=cfg,
train_df=train_df,
validation_df=validation_df,
comet_logger=comet_logger
)

comet_id = comet_logger.experiment.id
trained_model.trainer.save_checkpoint(f"{comet_id}.ckpt")

if __name__ == "__main__":
main()

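A minimal sketch of reloading the checkpoint saved above, assuming the returned model is deepforest's CropModel (a LightningModule) and that its hyperparameters are stored in the checkpoint; the filename is illustrative:

```python
from deepforest.model import CropModel

# Hypothetical: the checkpoint is named after the Comet experiment id.
model = CropModel.load_from_checkpoint("abc123.ckpt")
model.eval()  # inference mode for downstream prediction
```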
12 changes: 7 additions & 5 deletions conf/classification_model/USGS.yaml
@@ -1,12 +1,14 @@
checkpoint:
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/checkpoints/
train_csv_folder:
train_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops
crop_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/
train_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops/
train_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/train
val_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/val

under_sample_ratio: 0
trainer:
fast_dev_run: False
max_epochs: 100
lr: 0.00001
batch_size: 16
workers: 10
lr: 0.0001
batch_size: 12
workers: 12
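As a quick sanity check, the updated values can be read back with OmegaConf (the library behind these Hydra configs); a sketch assuming the file layout above:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/classification_model/USGS.yaml")
assert cfg.trainer.lr == 0.0001
assert cfg.trainer.batch_size == 12

# The single shared crop_image_dir is now split by role:
print(cfg.train_crop_image_dir)  # .../classification/crops/train
print(cfg.val_crop_image_dir)    # .../classification/crops/val
```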
6 changes: 4 additions & 2 deletions conf/classification_model/finetune.yaml
@@ -1,8 +1,10 @@
checkpoint:
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/classification/checkpoints
train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/train
train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/
train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
train_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/train
val_crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/val

under_sample_ratio: 0
trainer:
fast_dev_run: False
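Both files sit in the conf/classification_model/ config group, so the variant is presumably selected at compose time; a sketch using Hydra's compose API, with the group and root config names assumed from the paths in this commit:

```python
from hydra import compose, initialize

# Illustrative: build the full config with the finetune variant selected.
with initialize(config_path="conf", version_base=None):
    cfg = compose(config_name="config",
                  overrides=["classification_model=finetune"])

print(cfg.classification_model.train_crop_image_dir)
# /blue/ewhite/b.weinstein/BOEM/classification/crops/train
```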
7 changes: 3 additions & 4 deletions conf/config.yaml
@@ -17,7 +17,6 @@ label_studio:
url: "https://labelstudio.naturecast.org/"
folder_name: "/pgsql/retrieverdash/everglades-label-studio/everglades-data"
images_to_annotate_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
annotated_images_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
csv_dir_train: /blue/ewhite/b.weinstein/BOEM/annotations/train
csv_dir_validation: /blue/ewhite/b.weinstein/BOEM/annotations/validation
instances:
@@ -34,11 +33,12 @@ predict:
patch_size: 1000
patch_overlap: 0
min_score: 0.4
batch_size: 48
batch_size: 32

pipeline:
confidence_threshold: 0.9
limit_empty_frac: 0.01
gpus: 2

propagate:
time_threshold_seconds: 5
@@ -77,7 +77,7 @@ active_learning:
n_images: 50
patch_size: 1000
patch_overlap: 0
min_score: 0.1
min_score: 0.2
model_checkpoint:
target_labels:
- "Object"
@@ -86,7 +86,6 @@
evaluation:
dask_client:
pool_limit: 500
gpus: 2

active_testing:
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
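The active-learning min_score bump (0.1 to 0.2) tightens the pool filter applied in src/active_learning.py below; a toy illustration with a stand-in DataFrame:

```python
import pandas as pd

preannotations = pd.DataFrame({
    "image_path": ["a.jpg", "b.jpg", "c.jpg"],
    "score": [0.15, 0.25, 0.90],
})

# Borderline detections such as 0.15 now fall out of the candidate pool.
pool = preannotations[preannotations["score"] >= 0.2]
print(pool.image_path.tolist())  # ['b.jpg', 'c.jpg']
```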
131 changes: 6 additions & 125 deletions src/active_learning.py
@@ -5,101 +5,7 @@
import dask.array as da
import pandas as pd


def choose_test_images(image_dir, strategy, n=10, patch_size=512, patch_overlap=0, min_score=0.5, model=None, model_path=None, dask_client=None, target_labels=None, pool_limit=1000, batch_size=1, comet_logger=None):
"""Choose images to annotate.
Args:
evaluation (dict): A dictionary of evaluation metrics.
image_dir (str): The path to a directory of images.
strategy (str): The strategy for choosing images. Available strategies are:
- "random": Choose images randomly from the pool.
- "most-detections": Choose images with the most detections based on predictions.
- "target-labels": Choose images with target labels.
n (int, optional): The number of images to choose. Defaults to 10.
dask_client (dask.distributed.Client, optional): A Dask client for parallel processing. Defaults to None.
patch_size (int, optional): The size of the image patches to predict on. Defaults to 512.
patch_overlap (float, optional): The amount of overlap between image patches. Defaults to 0.1.
min_score (float, optional): The minimum score for a prediction to be included. Defaults to 0.1.
model (main.deepforest, optional): A trained deepforest model. Defaults to None.
model_path (str, optional): The path to the model checkpoint file. Defaults to None. Only used in combination with dask
target_labels: (list, optional): A list of target labels to filter images by. Defaults to None.
pool_limit (int, optional): The maximum number of images to consider. Defaults to 1000.
batch_size (int, optional): The batch size for prediction. Defaults to 1.
comet_logger (CometLogger, optional): A CometLogger object. Defaults to None.
Returns:
list: A list of image paths.
pd.DataFrame: A DataFrame of preannotations.
"""
pool = glob.glob(os.path.join(image_dir,"*")) # Get all images in the data directory
# Remove .csv files from the pool
pool = [image for image in pool if not image.endswith('.csv')]

# Remove crop dir
try:
pool.remove(os.path.join(image_dir,"crops"))
except ValueError:
pass

#subsample
if len(pool) > pool_limit:
pool = random.sample(pool, pool_limit)

if strategy=="random":
chosen_images = random.sample(pool, n)
preannotations = None
return chosen_images, None
elif strategy in ["most-detections","target-labels"]:
# Predict all images
if model_path is None:
raise ValueError("A model is required for the 'most-detections' or 'target-labels' strategy.")
if dask_client:
# load model on each client
def update_sys_path():
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dask_client.run(update_sys_path)

# Load model on each client
dask_pool = da.from_array(pool, chunks=len(pool)//len(dask_client.ncores()))
blocks = dask_pool.to_delayed().ravel()
block_futures = []
for block in blocks:
block_future = dask_client.submit(detection.predict,image_paths=block.compute(), patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path)
block_futures.append(block_future)
# Get results
dask_results = []
for block_result in block_futures:
block_result = block_result.result()
dask_results.append(pd.concat(block_result))
preannotations = pd.concat(dask_results)
else:
preannotations = detection.predict(model=model, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap, batch_size=batch_size)
preannotations = pd.concat(preannotations)

if comet_logger:
comet_logger.log_table("active_testing_pool", preannotations)

print("There are {} preannotations before removing min score".format(preannotations.shape[0]))
preannotations = preannotations[preannotations["score"] >= min_score]

if strategy == "most-detections":
# Sort images by total number of predictions
chosen_images = preannotations.groupby("image_path").size().sort_values(ascending=False).head(n).index.tolist()
elif strategy == "target-labels":
# Filter images by target labels
chosen_images = preannotations[preannotations.label.isin(target_labels)].groupby("image_path").size().sort_values(ascending=False).head(n).index.tolist()
else:
raise ValueError("Invalid strategy. Must be one of 'random', 'most-detections', or 'target-labels'.")
# Get full path
chosen_images = [os.path.join(image_dir, image) for image in chosen_images]
else:
raise ValueError("Invalid strategy. Must be one of 'random', 'most-detections', or 'target-labels'.")

# Get preannotations for chosen images
chosen_preannotations = preannotations[preannotations["image_path"].isin(chosen_images)]
return chosen_images, chosen_preannotations

def human_review(predictions, min_score=0.1, confident_threshold=0.5):
def human_review(predictions, min_score=0.2, confident_threshold=0.5):
"""
Predict on images and divide into confident and uncertain predictions.
Args:
@@ -110,21 +16,21 @@ def human_review(predictions, min_score=0.1, confident_threshold=0.5):
tuple: A tuple of confident and uncertain predictions.
"""

predictions[predictions["score"] > min_score]
predictions[predictions["cropmodel_score"] > min_score]

# Split predictions into confident and uncertain
uncertain_predictions = predictions[
predictions["score"] <= confident_threshold]
predictions["cropmodel_score"] <= confident_threshold]

confident_predictions = predictions[
~predictions["image_path"].isin(
uncertain_predictions["image_path"])]

return confident_predictions, uncertain_predictions

def generate_pool_predictions(image_dir, patch_size=512, patch_overlap=0.1, min_score=0.1, model=None, model_path=None, dask_client=None, batch_size=16, comet_logger=None, pool_limit=1000, crop_model=None):
def generate_pool_predictions(image_dir, patch_size=512, patch_overlap=0.1, min_score=0.1, model=None, model_path=None, dask_client=None, batch_size=16, pool_limit=1000, crop_model=None):
"""
Generate predictions for the training pool.
Generate predictions for the flight pool.
Args:
image_dir (str): The path to a directory of images.
@@ -157,32 +63,7 @@ def generate_pool_predictions(image_dir, patch_size=512, patch_overlap=0.1, min_
except ValueError:
pass

if dask_client:
# load model on each client
def update_sys_path():
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dask_client.run(update_sys_path)

# Load model on each client
dask_pool = da.from_array(pool, chunks=len(pool) // len(dask_client.ncores()))
blocks = dask_pool.to_delayed().ravel()
block_futures = []
for block in blocks:
block_future = dask_client.submit(detection.predict, image_paths=block.compute(), patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path, crop_model=crop_model)
block_futures.append(block_future)
# Get results
dask_results = []
for block_result in block_futures:
block_result = block_result.result()
dask_results.append(pd.concat(block_result))
preannotations = pd.concat(dask_results)
else:
preannotations = detection.predict(m=model, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap, batch_size=batch_size, crop_model=crop_model)
preannotations = pd.concat(preannotations)

if comet_logger:
comet_logger.experiment.log_table("active_training_pool", preannotations)
preannotations = detection.predict(m=model, model_path=model_path, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap, batch_size=batch_size, crop_model=crop_model)

# Print the number of preannotations before removing min score
preannotations = preannotations[preannotations["score"] >= min_score]
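A sketch of how the slimmed-down module might be driven end to end, now that generate_pool_predictions delegates all batching to detection.predict instead of managing Dask futures itself; the models, paths, and thresholds here are illustrative:

```python
from deepforest import main as df_main
from deepforest.model import CropModel

from src.active_learning import generate_pool_predictions, human_review

# Placeholder models; in the real pipeline these come from trained checkpoints.
detection_model = df_main.deepforest()
crop_model = CropModel(num_classes=2)

preannotations = generate_pool_predictions(
    image_dir="/blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27",
    patch_size=1000,
    patch_overlap=0,
    min_score=0.2,
    model=detection_model,
    crop_model=crop_model,
)

# Split on the crop-classifier score; thresholds mirror conf/config.yaml.
confident, uncertain = human_review(
    preannotations, min_score=0.2, confident_threshold=0.9
)
```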
35 changes: 22 additions & 13 deletions src/classification.py
@@ -5,6 +5,7 @@

from deepforest.model import CropModel
import torch
from torch.nn import functional as F

# Local imports
from src.label_studio import gather_data
@@ -98,17 +99,25 @@ def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_
label_count[label_name] += 1

model.trainer.fit(model)

# Compute confusion matrix and upload to cometml
image_dataset = []

dl = model.predict_dataloader(model.val_ds)

# Iterate over dl and get batched predictions
y_true = []
y_predicted = []
for index, (image,label) in enumerate(model.val_ds):
image_path, label = model.val_ds.imgs[index]
original_image = Image.open(image_path)
image_dataset += [original_image]
y_true += [label]
y_predicted += [model(image.unsqueeze(0)).argmax().item()]
image_dataset = []

for batch in dl:
images, labels = batch
outputs = model(images)
_, preds = torch.max(outputs, 1)

y_true.extend(labels.cpu().numpy())
y_predicted.extend(preds.cpu().numpy())

for image in images:
image_dataset.append(Image.fromarray(image.permute(1, 2, 0).cpu().numpy().astype('uint8')))

labels = model.val_ds.classes

# Log the confusion matrix to Comet
@@ -183,18 +192,18 @@ def preprocess_and_train_classification(config, train_df=None, validation_df=Non
model=loaded_model,
annotations=train_df,
root_dir=config.classification_model.train_image_dir,
save_dir=config.classification_model.crop_image_dir)
save_dir=config.classification_model.train_crop_image_dir)

preprocess_images(
model=loaded_model,
annotations=validation_df,
root_dir=config.classification_model.train_image_dir,
save_dir=config.classification_model.crop_image_dir)
save_dir=config.classification_model.val_crop_image_dir)

trained_model = train(
batch_size=config.classification_model.trainer.batch_size,
train_dir=config.classification_model.crop_image_dir,
val_dir=config.classification_model.crop_image_dir,
train_dir=config.classification_model.train_crop_image_dir,
val_dir=config.classification_model.val_crop_image_dir,
model=loaded_model,
fast_dev_run=config.classification_model.trainer.fast_dev_run,
max_epochs=config.classification_model.trainer.max_epochs,
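The new validation loop predicts batch-by-batch from the dataloader instead of image-by-image. A condensed sketch of the same pattern, with an inference-mode guard added (an assumption; the diff does not show one) and the Comet confusion-matrix call the surrounding code logs to:

```python
import torch

model.eval()
y_true, y_predicted = [], []

with torch.no_grad():  # assumed: no gradients needed while logging metrics
    for images, labels in dl:
        preds = model(images).argmax(dim=1)
        y_true.extend(labels.cpu().numpy())
        y_predicted.extend(preds.cpu().numpy())

# comet_ml's Experiment supports confusion-matrix logging directly.
comet_logger.experiment.log_confusion_matrix(
    y_true=y_true,
    y_predicted=y_predicted,
    labels=model.val_ds.classes,
)
```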
26 changes: 11 additions & 15 deletions src/detection.py
@@ -31,7 +31,7 @@ def evaluate(model, test_csv, image_root_dir):
"""
# create trainer
devices = torch.cuda.device_count()
strategy = "ddp" if devices > 1 else None
strategy = "ddp" if devices > 1 else "auto"
model.create_trainer(num_nodes=1, devices=devices, strategy=strategy)
model.config["validation"]["csv_file"] = test_csv
model.config["validation"]["root_dir"] = image_root_dir
@@ -164,11 +164,11 @@ def train(model, train_annotations, test_annotations, train_image_dir, comet_log
model.config[key] = value

devices = torch.cuda.device_count()
strategy = "ddp" if devices > 1 else None
strategy = "ddp" if devices > 1 else "auto"
comet_logger.experiment.log_parameters(model.config)
comet_logger.experiment.log_table("train.csv", train_annotations)
comet_logger.experiment.log_table("test.csv", test_annotations)
model.create_trainer(logger=comet_logger, num_nodes=1, accelerator="gpu", strategy="ddp", devices=2)
model.create_trainer(logger=comet_logger, num_nodes=1, accelerator="gpu", strategy=strategy, devices=devices)

non_empty_train_annotations = read_file(model.config["train"]["csv_file"], root_dir=train_image_dir)
# Sanity check for debug
Expand All @@ -187,23 +187,19 @@ def train(model, train_annotations, test_annotations, train_image_dir, comet_log
visualize.plot_annotations(sample_validation_annotations_for_image, savedir=tmpdir)
comet_logger.experiment.log_image(os.path.join(tmpdir, filename),metadata={"name":filename,"context":'validation_images'})

with comet_logger.experiment.context_manager("detection"):
model.trainer.fit(model)
model.trainer.fit(model)

for image_path in test_annotations.image_path.unique():
prediction = model.predict_image(path = os.path.join(train_image_dir, image_path))
if prediction is None:
continue
visualize.plot_results(prediction, savedir=tmpdir)
comet_logger.experiment.log_image(os.path.join(tmpdir, image_path))

with comet_logger.experiment.context_manager("post-training prediction"):
for image_path in test_annotations.image_path.unique():
prediction = model.predict_image(path = os.path.join(train_image_dir, image_path))
if prediction is None:
continue
visualize.plot_results(prediction, savedir=tmpdir)
comet_logger.experiment.log_image(os.path.join(tmpdir, image_path))

return model

def fix_taxonomy(df):
df["label"] = "Object"
#df["label"] = df.label.replace('Turtle', 'Reptile')
#df["label"] = df.label.replace('Cetacean', 'Mammal')

return df

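The strategy fix ("auto" instead of None for the single-device case) matches PyTorch Lightning 2.x, where Trainer no longer accepts strategy=None; a standalone sketch of the equivalent device logic:

```python
import torch
from pytorch_lightning import Trainer

devices = torch.cuda.device_count()
# Lightning 2.x expects a string or Strategy instance; "auto" picks the
# single-process default, while "ddp" is only valid with multiple devices.
strategy = "ddp" if devices > 1 else "auto"

trainer = Trainer(
    accelerator="gpu" if devices > 0 else "cpu",
    devices=devices if devices > 0 else 1,
    strategy=strategy,
    num_nodes=1,
)
```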
6 changes: 0 additions & 6 deletions src/label_studio.py
@@ -48,12 +48,6 @@ def check_for_new_annotations(sftp_client, url, project_name, csv_dir, images_to

# Move annotated images out of local pool
if new_annotations is not None:
move_images(src_dir=images_to_annotate_dir, dst_dir=annotated_images_dir, annotations=new_annotations)
# Get any images from the server that are not in the images_to_annotate_dir
for image in new_annotations["image_path"].unique():
if not os.path.exists(os.path.join(annotated_images_dir, image)):
download_images(sftp_client=sftp_client, image_names=[image], folder_name=folder_name, local_image_dir=annotated_images_dir)

delete_completed_tasks(label_studio_project=label_studio_project)

else:
(Diffs for the remaining changed files were not loaded.)