Merge pull request #96 from cleapeter/tb-run-folders
Move tensorboard event files to separate folders and copy config file
wolny authored Nov 27, 2023
2 parents ec5d1d5 + 1c1f579 commit cbcb8cc
Showing 3 changed files with 31 additions and 7 deletions.
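
The net effect of this commit: every training run now writes its TensorBoard event files into its own timestamped subfolder under <checkpoint_dir>/logs, and the YAML config used to launch the run is copied into that same folder. A hypothetical resulting layout (timestamps and file names below are illustrative, not taken from the diff):

checkpoint_dir/
└── logs/
    ├── 2023-11-26_09-00-00/
    │   ├── events.out.tfevents.…
    │   └── train_config.yml
    └── 2023-11-27_14-02-11/
        ├── events.out.tfevents.…
        └── train_config.yml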
8 changes: 5 additions & 3 deletions pytorch3dunet/train.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from pytorch3dunet.unet3d.config import load_config
+from pytorch3dunet.unet3d.config import load_config, copy_config
 from pytorch3dunet.unet3d.trainer import create_trainer
 from pytorch3dunet.unet3d.utils import get_logger
 
@@ -11,7 +11,7 @@
 
 def main():
     # Load and log experiment configuration
-    config = load_config()
+    config, config_path = load_config()
     logger.info(config)
 
     manual_seed = config.get('manual_seed', None)
@@ -23,8 +23,10 @@ def main():
     # see https://pytorch.org/docs/stable/notes/randomness.html
     torch.backends.cudnn.deterministic = True
 
-    # create trainer
+    # Create trainer
    trainer = create_trainer(config)
+    # Copy config file
+    copy_config(config, config_path)
     # Start training
     trainer.fit()
 
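Note the ordering in main() above: create_trainer runs first because instantiating the trainer's SummaryWriter is what creates the timestamped run folder on disk; copy_config then resolves the newest subfolder of <checkpoint_dir>/logs and drops the YAML there. A minimal standalone sketch of that flow, assuming a hypothetical 'checkpoints' directory and a 'config.yml' file in the working directory:

import os
import shutil
from datetime import datetime

checkpoint_dir = 'checkpoints'  # hypothetical stand-in for config['trainer']['checkpoint_dir']

# Step 1 (done inside create_trainer): creating the SummaryWriter makes the run folder.
run_dir = os.path.join(checkpoint_dir, 'logs', datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(run_dir, exist_ok=True)

# Step 2 (copy_config): the newest subfolder is the one just created,
# so the YAML lands next to this run's event files.
logs_dir = os.path.join(checkpoint_dir, 'logs')
last_run_dir = max(f.path for f in os.scandir(logs_dir) if f.is_dir())
shutil.copy2('config.yml', os.path.join(last_run_dir, 'config.yml'))  # assumes config.yml exists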
21 changes: 19 additions & 2 deletions pytorch3dunet/unet3d/config.py
@@ -1,5 +1,6 @@
 import argparse
-
+import os
+import shutil
 import torch
 import yaml
 
@@ -25,7 +26,23 @@ def load_config():
     else:
         logger.warning('CUDA not available, using CPU')
         config['device'] = 'cpu'
-    return config
+    return config, args.config
 
 
+def copy_config(config, config_path):
+    """Copies the config file to the checkpoint folder."""
+
+    def _get_last_subfolder_path(path):
+        subfolders = [f.path for f in os.scandir(path) if f.is_dir()]
+        return max(subfolders, default=None)
+
+    checkpoint_dir = os.path.join(
+        config['trainer'].pop('checkpoint_dir'), 'logs')
+    last_run_dir = _get_last_subfolder_path(checkpoint_dir)
+    config_file_name = os.path.basename(config_path)
+
+    if last_run_dir:
+        shutil.copy2(config_path, os.path.join(last_run_dir, config_file_name))
+
+
 def _load_config_yaml(config_file):
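A detail worth calling out: _get_last_subfolder_path picks the newest run via plain max() over path strings. That is only correct because the run folders are named with a zero-padded %Y-%m-%d_%H-%M-%S timestamp (see trainer.py below), so lexicographic order coincides with chronological order. A small sketch with hypothetical paths:

# Lexicographic max() on zero-padded timestamp names selects the newest run.
subfolders = [
    'checkpoints/logs/2023-11-26_09-00-00',
    'checkpoints/logs/2023-11-27_08-30-45',
    'checkpoints/logs/2023-11-27_14-02-11',
]
print(max(subfolders, default=None))  # checkpoints/logs/2023-11-27_14-02-11

# With no runs yet, max(..., default=None) returns None and copy_config skips the copy.
print(max([], default=None))  # None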
9 changes: 7 additions & 2 deletions pytorch3dunet/unet3d/trainer.py
@@ -1,9 +1,9 @@
 import os
-
 import torch
 import torch.nn as nn
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
 
 from pytorch3dunet.datasets.utils import get_train_loaders
 from pytorch3dunet.unet3d.losses import get_loss_criterion
@@ -115,7 +115,12 @@ def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterio
         else:
             self.best_eval_score = float('+inf')
 
-        self.writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, 'logs'))
+        self.writer = SummaryWriter(
+            log_dir=os.path.join(
+                checkpoint_dir, 'logs',
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            )
+        )
 
         assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
         self.tensorboard_formatter = tensorboard_formatter
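For context, SummaryWriter creates its log_dir (including missing parent directories) as soon as it is instantiated, which is what gives each run its own folder before any scalars are logged. A minimal sketch, assuming a hypothetical 'checkpoints' directory:

import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

checkpoint_dir = 'checkpoints'  # hypothetical; the project takes this from the trainer config
log_dir = os.path.join(checkpoint_dir, 'logs', datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
writer = SummaryWriter(log_dir=log_dir)  # creates e.g. checkpoints/logs/2023-11-27_14-02-11/
writer.add_scalar('train/loss', 0.5, global_step=0)
writer.close()  # flushes the event file into the run folder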
