Skip to content

Commit

Permalink
USGS classfication
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Mar 1, 2025
1 parent f36706e commit b391d84
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 8 deletions.
13 changes: 9 additions & 4 deletions USGS_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf/classification_model", config_name="USGS")
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
classification_cfg = cfg.classification
savedir = classification_cfg.savedir
# Override the classification_model config with USGS.yaml
hydra.compose(config_name="config", overrides=["+classification_model=@conf/classification_model/USGS"])

classification_cfg = cfg.classification_model

# From the detection script
savedir = "/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops"
train = pd.read_csv(os.path.join(savedir, "train.csv"))
test = pd.read_csv(os.path.join(savedir, "test.csv"))

comet_logger = CometLogger(project_name=classification_cfg.project_name, workspace=classification_cfg.workspace)
comet_logger = CometLogger(project_name=cfg.project, workspace=cfg.workspace)
preprocess_and_train_classification(
config=cfg,
train_df=train,
Expand Down
3 changes: 2 additions & 1 deletion conf/classification_model/USGS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ classification_model:
trainer:
fast_dev_run: True
max_epochs: 1
lr: 0.00001
lr: 0.00001
batch_size: 16
3 changes: 2 additions & 1 deletion conf/classification_model/finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ classification_model:
trainer:
fast_dev_run: True
max_epochs: 1
lr: 0.00001
lr: 0.00001
batch_size: 16
5 changes: 4 additions & 1 deletion src/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001, num_

return loaded_model

def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_epochs=10):
def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_epochs=10, batch_size=4):
"""Train a model on labeled images.
Args:
model (CropModel): A CropModel object.
Expand All @@ -67,10 +67,12 @@ def train(model, train_dir, val_dir, comet_logger=None, fast_dev_run=False, max_
fast_dev_run (bool): Whether to run a fast development run.
max_epochs (int): The maximum number of epochs to train for.
comet_logger (CometLogger): A CometLogger object.
batch_size (int): The batch size for training.
Returns:
main.deepforest: A trained deepforest model.
"""
model.batch_size = batch_size
model.create_trainer(logger=comet_logger, fast_dev_run=fast_dev_run, max_epochs=max_epochs)

# Get the data stored from the write_crops processing.
Expand Down Expand Up @@ -171,6 +173,7 @@ def preprocess_and_train_classification(config, train_df=None, validation_df=Non
save_dir=config.classification_model.crop_image_dir)

trained_model = train(
batch_size=config.classification_model.trainer.batch_size,
train_dir=config.classification_model.crop_image_dir,
val_dir=config.classification_model.crop_image_dir,
model=loaded_model,
Expand Down
3 changes: 3 additions & 0 deletions src/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def preprocess_images(
if not os.path.exists(root_dir):
raise FileNotFoundError(f"Root directory not found: {root_dir}")

# Remove any annotations with xmin == xmax
annotations = annotations[annotations.xmin != annotations.xmax]

os.makedirs(save_dir, exist_ok=True)

crop_annotations = []
Expand Down
2 changes: 1 addition & 1 deletion src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def run(self):
# Human review
new_review_annotations = self.check_new_annotations("review")
self.review_annotations = label_studio.gather_data(self.config.label_studio.instances.review.csv_dir)
self.comet_logger.experiment.log_table(tabular_data=self.review_annotations, name="human_reviewed_annotations.csv")
self.comet_logger.experiment.log_table(tabular_data=self.review_annotations, filename="human_reviewed_annotations.csv")

if new_val_annotations is None:
if self.config.force_upload:
Expand Down
1 change: 1 addition & 0 deletions submit_USGS.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ source activate BOEM
cd ~/BOEM/
#python prepare_USGS.py
srun python USGS_backbone.py --batch_size 12 --workers 16

21 changes: 21 additions & 0 deletions submit_USGS_classification.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash
#SBATCH --job-name=BOEM_USGS # Job name
#SBATCH --mail-type=END # Mail events
#SBATCH [email protected] # Where to send mail
#SBATCH --account=ewhite
#SBATCH --nodes=1 # Number of MPI ran
#SBATCH --cpus-per-task=16
#SBATCH --mem=150GB
#SBATCH --time=48:00:00 #Time limit hrs:min:sec
#SBATCH --output=/home/b.weinstein/logs/BOEM%j.out # Standard output and error log
#SBATCH --error=/home/b.weinstein/logs/BOEM%j.err
#SBATCH --partition=gpu
#SBATCH --ntasks-per-node=4
#SBATCH --gpus=4

source activate BOEM

cd ~/BOEM/
#python prepare_USGS.py
srun python USGS_classification.py

0 comments on commit b391d84

Please sign in to comment.