
Commit
initial commit
DaruiJin committed Nov 21, 2024
1 parent 12c0bb8 commit 919d616
Showing 10 changed files with 139 additions and 32 deletions.
1 change: 1 addition & 0 deletions aggregator_train_val/config.yaml
@@ -31,6 +31,7 @@ Model:
   aug: False
   soft_labels: False
   dim_age_embed: 32
+  cl_w: 20
 
 Optimizer:
   opt: lookahead_radam
2 changes: 1 addition & 1 deletion aggregator_train_val/data_module.py
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset, DataLoader
 import ast
 import pandas as pd
-from utils import *
+from .utils import *
 from dataclasses import dataclass
 from typing import Union
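
Note: the switch from "from utils import *" to "from .utils import *" (here and in the modules below) makes aggregator_train_val importable as a package, which the new pipeline.py relies on; the bare form only resolved when scripts were launched from inside that directory. A minimal illustration, assuming the repository root is on sys.path:

# With the package-relative imports above, this succeeds from the repo root:
from aggregator_train_val import data_module
# With the old absolute imports, the same line raised ModuleNotFoundError,
# because 'utils' lives inside aggregator_train_val/ rather than on sys.path.
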
6 changes: 3 additions & 3 deletions aggregator_train_val/model_module.py
@@ -13,8 +13,8 @@
 import pytorch_lightning as pl
 from nystrom_attention import NystromAttention
 
-from Optimizer import create_optimizer
-from utils import cross_entropy_torch, update_ema_variables, set_seed
+from .Optimizer import create_optimizer
+from .utils import cross_entropy_torch, update_ema_variables, set_seed
 
 
 class TransLayer(nn.Module):
@@ -420,7 +420,7 @@ def on_test_epoch_end(self):
             metrics[keys] = values.cpu().numpy()
 
         result = pd.DataFrame([metrics])
-        result.to_csv(self.log_path / f'result{self.fold}.csv')
+        result.to_csv(self.log_path / f'result.csv')
         self.test_step_outputs.clear()
 
     def load_model(self):
aggregator_train_val/model_run.py
@@ -4,9 +4,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
 
-from utils import *
-from model_module import ModelModule
-from data_module import DataModule
+from .utils import *
+from .model_module import ModelModule
+from .data_module import DataModule
 
 
 # Parse command-line arguments to set parameters for the script
@@ -37,7 +37,7 @@ def parse_arguments():
 
 
 # Main function that orchestrates the training/testing process
-def main(cfg):
+def model_run(cfg):
     # Set random seed for reproducibility
     if cfg['General']['mode'] != 'train':
         set_seed(cfg['General']['seed'])
@@ -73,7 +73,7 @@ def main(cfg):
 
     else:
         # Test the model using the latest checkpoints
-        latest_checkpoint_path = max(cfg['General']['log_path'].glob('*.ckpt'), key=os.path.getctime)
+        latest_checkpoint_path = max(cfg.log_path.glob('*.ckpt'), key=os.path.getctime)
         print(f'Testing with checkpoint: {latest_checkpoint_path}')
         loaded_model = model.load_from_checkpoint(checkpoint_path=latest_checkpoint_path, Data=cfg['Data'])
         trainer.test(model=loaded_model, datamodule=data_module)
@@ -104,6 +104,6 @@ def main(cfg):
     cfg.resume = args.resume
 
     # Run the main function
-    main(cfg)
+    model_run(cfg)
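
With main renamed to model_run and the imports made package-relative, the training entry point can be driven programmatically, which is exactly what the new pipeline.py does. A minimal sketch of direct use, overriding the same config keys pipeline.py sets (paths here are hypothetical):

from aggregator_train_val.model_run import model_run
from aggregator_train_val.utils import read_yaml

# Load the shipped config, then override the keys pipeline.py also sets.
cfg = read_yaml('./aggregator_train_val/config.yaml')
cfg['Data']['data_dir'] = './features/pt_files'  # hypothetical feature directory
cfg['General']['mode'] = 'test'                  # 'train' or 'test'
cfg['resume'] = False
model_run(cfg)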


95 changes: 95 additions & 0 deletions pipeline.py
@@ -0,0 +1,95 @@
import os
import glob
import yaml
import argparse
from preprocessing.tiling.main_create_tiles import tile_slide_images
from preprocessing.feature_extraction.get_features import extract_features
from aggregator_train_val.model_run import model_run
from aggregator_train_val.utils import read_yaml


def parse_arguments():
    parser = argparse.ArgumentParser(description='End to end pipeline')
    parser.add_argument('--tiling', action='store_true', help='Whether to perform tiling')
    parser.add_argument('--feature_extraction', action='store_true', help='Whether to perform feature extraction')
    parser.add_argument('--model_run', action='store_true', help='Whether to perform training/testing')
    args, unknown = parser.parse_known_args()
    if not (args.tiling or args.feature_extraction or args.model_run):
        parser.error("At least one of --tiling, --feature_extraction, or --model_run must be True")

    if args.tiling:
        tiling_group = parser.add_argument_group('Tiling arguments')
        tiling_group.add_argument('--slide_dir', type=str, help='path to the source slide image (.svs) directory')
        tiling_group.add_argument('--slide_list', type=str, help='path to the source slide image list (.txt) to be processed')
        tiling_group.add_argument('--tile_savedir', type=str, default='./tiles/', help='path to the save directory')

    if args.feature_extraction:
        feature_group = parser.add_argument_group('Feature extraction arguments')
        feature_group.add_argument('--tile_dir', type=str, default=None,
                                   help='path to the tile folder (.txt). If not provided, the tile folder will be generated from the tiling results')
        feature_group.add_argument('--batchsize', type=int, default=768, help='batch size for inference')
        feature_group.add_argument('--feature_dir', type=str, default='./features/', help='path to the save directory')

    if args.model_run:
        model_group = parser.add_argument_group('Model run arguments')
        model_group.add_argument('--dataset', type=str, default=None,
                                 help='Path to the dataset directory. If not provided, the dataset will be generated from the feature extraction results')
        model_group.add_argument('--label', type=str, default='./labels/labels.csv',
                                 help='Path to the slide label CSV file, which should contain columns including slide, family, probability vector, age, and location')
        model_group.add_argument('--split', type=str, default=None,
                                 help='Path to the dataset split file (YAML) containing train and test slide IDs, structured as {"train": [slide_id], "test": [slide_id]}. If not provided, the file will be generated from the dataset')
        model_group.add_argument('--mode', type=str, default='train', help='Operation mode: train or test')
        model_group.add_argument('--exp_name', type=str, default='default_exp', help='Identifier for the experiment')
        model_group.add_argument('--output_dir', type=str, default='./predictions', help='Directory to save predictions')
        model_group.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
        model_group.add_argument('--config', type=str, default='./aggregator_train_val/config.yaml', help='Path to configuration file')

    return parser.parse_args()


if __name__ == '__main__':
    # Run the pipeline
    args = parse_arguments()

    if args.tiling:
        tile_slide_images(source_dir=args.slide_dir, source_list=args.slide_list, save_dir=args.tile_savedir)

    if args.feature_extraction:
        if args.tile_dir is None:
            try:
                with open('./tile_list.txt', 'w') as f:
                    for item in glob.glob(os.path.join(args.tile_savedir, 'tiles', '*')):
                        f.write("%s\n" % item)
                args.tile_dir = './tile_list.txt'
            except:
                print('No tile folder found, please provide the tile folder path')
                exit()

        extract_features(split=args.tile_dir, batchsize=args.batchsize, feature_dir=args.feature_dir)

    if args.model_run:
        if args.dataset is None:
            try:
                args.dataset = os.path.join(args.feature_dir, 'pt_files')
            except:
                print('No dataset found, please provide the dataset path')
                exit()
        if args.split is None:
            try:
                testset = os.listdir(args.dataset)
                testset = [os.path.splitext(item)[0] for item in testset]
                with open('./split.yaml', 'w') as f:
                    yaml.dump({'train': [], 'test': testset}, f)
                args.split = './split.yaml'
            except:
                print('No split file found, please provide the split file path')
                exit()
        cfg = read_yaml(args.config)
        cfg['Data']['data_dir'] = args.dataset
        cfg['Data']['data_split'] = args.split
        cfg['Data']['label_file'] = args.label
        cfg['General']['mode'] = args.mode
        cfg['Model']['exp_name'] = args.exp_name
        cfg['Model']['preds_save'] = args.output_dir
        cfg['resume'] = args.resume
        model_run(cfg)
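
parse_arguments above uses a two-pass argparse pattern: a first parse_known_args() call reads only the three stage flags, stage-specific argument groups are then registered conditionally, and a final parse_args() validates the full command line. A stripped-down sketch of the same pattern (flag names here are illustrative):

import argparse

parser = argparse.ArgumentParser(description='two-pass parsing sketch')
parser.add_argument('--tiling', action='store_true')

# First pass: only the stage flags are defined; unknown options are ignored.
args, _unknown = parser.parse_known_args()

if args.tiling:
    # Stage-specific options exist only when the stage is requested.
    group = parser.add_argument_group('Tiling arguments')
    group.add_argument('--slide_dir', type=str, default='./slides/')

args = parser.parse_args()  # full, strict parse now that all options are registered
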
21 changes: 14 additions & 7 deletions preprocessing/feature_extraction/get_features.py
@@ -8,7 +8,6 @@
 import numpy as np
 from torchvision import transforms
 from torch.utils.data import Dataset
-from timm.models.vision_transformer import VisionTransformer
 
 
 # imagenet normalization
@@ -66,16 +65,12 @@ def save_hdf5(output_path, asset_dict, attr_dict=None, mode='w'):
     return output_path
 
 
-@click.command()
-@click.option('--split', type=str, help='path to the split file (.txt)')
-@click.option('--batchsize', type=int, help='batch size for inference')
-@click.option('--feature_dir', type=str, help='path to the save directory')
-# @click.option('--ckpt', type=str, help='path to the save directory')
-def inference(split, batchsize, feature_dir):
+def extract_features(split, batchsize=768, feature_dir='./features'):
     slide_ls = [line.rstrip('\n') for line in open(split)]
+    os.remove(split) # remove the split file after reading
     test_datat=roi_dataset(slide_ls)
     database_loader = torch.utils.data.DataLoader(test_datat, batch_size=batchsize, shuffle=False)
 
     # change the name to the model you want to use here
     os.environ['HF_HOME'] = './model_cache'
     model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
@@ -113,6 +108,18 @@ def inference(split, batchsize, feature_dir):
         features = file['features'][:]
         features = torch.from_numpy(features)
         torch.save(features, os.path.join(feature_dir, 'pt_files', os.path.basename(h5file)+'.pt'))
+        file.close()
 
+    print('Feature extraction done!')
+
+
+@click.command()
+@click.option('--split', type=str, help='path to the split file (.txt)')
+@click.option('--batchsize', type=int, help='batch size for inference')
+@click.option('--feature_dir', type=str, help='path to the save directory')
+# @click.option('--ckpt', type=str, help='path to the save directory')
+def inference(split, batchsize, feature_dir):
+    extract_features(split, batchsize, feature_dir)
 
 
 if __name__ == '__main__':
     inference()
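
Separating extract_features from its click wrapper means the extractor can be imported as well as run from the command line. A minimal sketch of programmatic use, mirroring what pipeline.py generates (paths are hypothetical; note that extract_features deletes the list file after reading it):

import glob
import os

from preprocessing.feature_extraction.get_features import extract_features

# Write one tile folder per line, as pipeline.py does after tiling.
with open('./tile_list.txt', 'w') as f:
    for item in glob.glob(os.path.join('./tiles', 'tiles', '*')):
        f.write("%s\n" % item)

# Reads ./tile_list.txt, removes it, runs inference with the prov-gigapath
# encoder, and saves per-slide feature tensors under <feature_dir>/pt_files/.
extract_features(split='./tile_list.txt', batchsize=768, feature_dir='./features/')
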
27 changes: 15 additions & 12 deletions preprocessing/tiling/main_create_tiles.py
@@ -1,18 +1,10 @@
 import os
 import glob
 import click
-from slide_lib import segment_tiling
+from .slide_lib import segment_tiling
 
 
-@click.command()
-@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
-@click.option('--source_list', type=str, help='path to the source slide image (.svs) list to be processed')
-@click.option('--save_dir', type=str, help='path to the save directory')
-@click.option('--patch_size', type=int, default=256, help='patch size')
-@click.option('--step_size', type=int, default=256, help='step size')
-@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
-@click.option('--index', type=int, default=None)
-def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size: int, step_size: int, mag: int, index: int) -> None:
+def tile_slide_images(source_dir: str, source_list: list, save_dir: str, patch_size: int=256, step_size: int=256, mag: int=20, index: int=0) -> None:
"""
Tile whole slide images stored in the .svs/.ndpi/.scn format at the desired magnification.
Expand All @@ -21,7 +13,7 @@ def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size:
source_dir : str
Path to the source slide image (.svs) directory.
source_list : str
Path to the source slide image (.svs) list to be processed.
Path to the source slide image list (.txt) to be processed.
save_dir : str
Path to the save directory.
patch_size : int
Expand Down Expand Up @@ -58,6 +50,17 @@ def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size:
     total_time = segment_tiling(**directories, patch_size=patch_size, mag_level=mag, step_size= step_size, index=index)
     print(f"The average processing time for each slide is {total_time:.2f} seconds.")
 
+
+@click.command()
+@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
+@click.option('--source_list', type=str, help='path to the source slide image (.svs) list to be processed')
+@click.option('--save_dir', type=str, help='path to the save directory')
+@click.option('--patch_size', type=int, default=256, help='patch size')
+@click.option('--step_size', type=int, default=256, help='step size')
+@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
+@click.option('--index', type=int, default=0)
+def generate_tiles(source_dir, source_list, save_dir, patch_size, step_size, mag, index):
+    tile_slide_images(source_dir, source_list, save_dir, patch_size, step_size, mag, index)
 
 
 if __name__ == '__main__':
-    batch_tiling()
+    generate_tiles()
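
The same wrapper split applies to tiling: generate_tiles keeps the CLI surface while tile_slide_images becomes importable, as pipeline.py requires. A short sketch of direct use (paths are hypothetical):

from preprocessing.tiling.main_create_tiles import tile_slide_images

tile_slide_images(
    source_dir='./slides/',           # directory of .svs/.ndpi/.scn slides
    source_list='./slide_list.txt',   # text file listing slides to process
    save_dir='./tiles/',
    patch_size=256,                   # tile edge length in pixels
    step_size=256,                    # stride between tiles
    mag=20,                           # magnification for patch extraction
)
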
3 changes: 2 additions & 1 deletion preprocessing/tiling/slide_lib/__init__.py
@@ -1,2 +1,3 @@
 from .utils import *
-from .constants_color import *
+from .constants_color import *
+from .segment_patching import *
2 changes: 1 addition & 1 deletion preprocessing/tiling/slide_lib/segment_patching.py
@@ -6,7 +6,7 @@
 import pandas as pd
 from PIL import Image
 import multiprocessing as mp
-from slide_lib import *
+from .utils import *
 
 
 def segment(wsi: openslide.OpenSlide)->tuple[list, list, Image.Image, float]:
2 changes: 1 addition & 1 deletion preprocessing/tiling/slide_lib/utils.py
@@ -8,7 +8,7 @@
 from typing import Union
 import skimage.filters as sk_filters
 from PIL import Image
-from slide_lib import PENS_RGB
+from .constants_color import PENS_RGB
 import multiprocessing as mp
 import concurrent.futures
