From 919d616b519f2b2492fb5567db5b6a3e9c3bb51a Mon Sep 17 00:00:00 2001
From: daruijin
Date: Thu, 21 Nov 2024 13:44:45 +0100
Subject: [PATCH] initial commit

---
 aggregator_train_val/config.yaml              |  1 +
 aggregator_train_val/data_module.py           |  2 +-
 aggregator_train_val/model_module.py          |  6 +-
 .../{main.py => model_run.py}                 | 12 +--
 pipeline.py                                   | 95 +++++++++++++++++++
 .../feature_extraction/get_features.py        | 21 ++--
 preprocessing/tiling/main_create_tiles.py     | 27 +++---
 preprocessing/tiling/slide_lib/__init__.py    |  3 +-
 .../tiling/slide_lib/segment_patching.py      |  2 +-
 preprocessing/tiling/slide_lib/utils.py       |  2 +-
 10 files changed, 139 insertions(+), 32 deletions(-)
 rename aggregator_train_val/{main.py => model_run.py} (95%)
 create mode 100644 pipeline.py

diff --git a/aggregator_train_val/config.yaml b/aggregator_train_val/config.yaml
index 3e3ba7c..279f550 100644
--- a/aggregator_train_val/config.yaml
+++ b/aggregator_train_val/config.yaml
@@ -31,6 +31,7 @@ Model:
   aug: False
   soft_labels: False
   dim_age_embed: 32
+  cl_w: 20
 
 Optimizer:
   opt: lookahead_radam
diff --git a/aggregator_train_val/data_module.py b/aggregator_train_val/data_module.py
index 3dd9e62..421acb0 100644
--- a/aggregator_train_val/data_module.py
+++ b/aggregator_train_val/data_module.py
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset, DataLoader
 import ast
 import pandas as pd
-from utils import *
+from .utils import *
 from dataclasses import dataclass
 from typing import Union
 
diff --git a/aggregator_train_val/model_module.py b/aggregator_train_val/model_module.py
index 6fd539c..ae4b5fd 100644
--- a/aggregator_train_val/model_module.py
+++ b/aggregator_train_val/model_module.py
@@ -13,8 +13,8 @@
 import pytorch_lightning as pl
 from nystrom_attention import NystromAttention
 
-from Optimizer import create_optimizer
-from utils import cross_entropy_torch, update_ema_variables, set_seed
+from .Optimizer import create_optimizer
+from .utils import cross_entropy_torch, update_ema_variables, set_seed
 
 
 class TransLayer(nn.Module):
@@ -420,7 +420,7 @@ def on_test_epoch_end(self):
             metrics[keys] = values.cpu().numpy()
         result = pd.DataFrame([metrics])
-        result.to_csv(self.log_path / f'result{self.fold}.csv')
+        result.to_csv(self.log_path / f'result.csv')
         self.test_step_outputs.clear()
 
     def load_model(self):
diff --git a/aggregator_train_val/main.py b/aggregator_train_val/model_run.py
similarity index 95%
rename from aggregator_train_val/main.py
rename to aggregator_train_val/model_run.py
index ce8bec9..a37bfe1 100644
--- a/aggregator_train_val/main.py
+++ b/aggregator_train_val/model_run.py
@@ -4,9 +4,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer
 
-from utils import *
-from model_module import ModelModule
-from data_module import DataModule
+from .utils import *
+from .model_module import ModelModule
+from .data_module import DataModule
 
 
 # Parse command-line arguments to set parameters for the script
@@ -37,7 +37,7 @@ def parse_arguments():
 
 
 # Main function that orchestrates the training/testing process
-def main(cfg):
+def model_run(cfg):
     # Set random seed for reproducibility
     if cfg['General']['mode'] != 'train':
         set_seed(cfg['General']['seed'])
@@ -73,7 +73,7 @@ def main(cfg):
 
     else:
         # Test the model using the latest checkpoints
-        latest_checkpoint_path = max(cfg['General']['log_path'].glob('*.ckpt'), key=os.path.getctime)
+        latest_checkpoint_path = max(cfg.log_path.glob('*.ckpt'), key=os.path.getctime)
        print(f'Testing with checkpoint: {latest_checkpoint_path}')
         loaded_model = model.load_from_checkpoint(checkpoint_path=latest_checkpoint_path, Data=cfg['Data'])
         trainer.test(model=loaded_model, datamodule=data_module)
@@ -104,6 +104,6 @@ def main(cfg):
     cfg.resume = args.resume
 
     # Run the main function
-    main(cfg)
+    model_run(cfg)
 
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..277b8b3
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,95 @@
+import os
+import glob
+import yaml
+import argparse
+from preprocessing.tiling.main_create_tiles import tile_slide_images
+from preprocessing.feature_extraction.get_features import extract_features
+from aggregator_train_val.model_run import model_run
+from aggregator_train_val.utils import read_yaml
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description='End-to-end pipeline')
+    parser.add_argument('--tiling', action='store_true', help='Whether to perform tiling')
+    parser.add_argument('--feature_extraction', action='store_true', help='Whether to perform feature extraction')
+    parser.add_argument('--model_run', action='store_true', help='Whether to perform training/testing')
+    args, unknown = parser.parse_known_args()
+    if not (args.tiling or args.feature_extraction or args.model_run):
+        parser.error("At least one of --tiling, --feature_extraction, or --model_run must be provided")
+
+    if args.tiling:
+        tiling_group = parser.add_argument_group('Tiling arguments')
+        tiling_group.add_argument('--slide_dir', type=str, help='path to the source slide image (.svs) directory')
+        tiling_group.add_argument('--slide_list', type=str, help='path to the source slide image list (.txt) to be processed')
+        tiling_group.add_argument('--tile_savedir', type=str, default='./tiles/', help='path to the save directory')
+
+    if args.feature_extraction:
+        feature_group = parser.add_argument_group('Feature extraction arguments')
+        feature_group.add_argument('--tile_dir', type=str, default=None,
+                                   help='path to a text file (.txt) listing the tile folders. If not provided, it will be generated from the tiling results')
+        feature_group.add_argument('--batchsize', type=int, default=768, help='batch size for inference')
+        feature_group.add_argument('--feature_dir', type=str, default='./features/', help='path to the save directory')
+
+    if args.model_run:
+        model_group = parser.add_argument_group('Model run arguments')
+        model_group.add_argument('--dataset', type=str, default=None,
+                                 help='Path to the dataset directory. If not provided, the dataset will be generated from the feature extraction results')
+        model_group.add_argument('--label', type=str, default='./labels/labels.csv',
+                                 help='Path to the slide label CSV file, which should contain columns including slide, family, probability vector, age, and location')
+        model_group.add_argument('--split', type=str, default=None,
+                                 help='Path to the dataset split file (YAML) containing train and test slide IDs, structured as {"train": [slide_id], "test": [slide_id]}. If not provided, the file will be generated from the dataset')
+        model_group.add_argument('--mode', type=str, default='train', help='Operation mode: train or test')
+        model_group.add_argument('--exp_name', type=str, default='default_exp', help='Identifier for the experiment')
+        model_group.add_argument('--output_dir', type=str, default='./predictions', help='Directory to save predictions')
+        model_group.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
+        model_group.add_argument('--config', type=str, default='./aggregator_train_val/config.yaml', help='Path to configuration file')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    # Run the pipeline
+    args = parse_arguments()
+
+    if args.tiling:
+        tile_slide_images(source_dir=args.slide_dir, source_list=args.slide_list, save_dir=args.tile_savedir)
+
+    if args.feature_extraction:
+        if args.tile_dir is None:
+            try:
+                with open('./tile_list.txt', 'w') as f:
+                    for item in glob.glob(os.path.join(args.tile_savedir, 'tiles', '*')):
+                        f.write("%s\n" % item)
+                args.tile_dir = './tile_list.txt'
+            except Exception:
+                print('No tile folder found, please provide the tile folder path')
+                exit()
+
+        extract_features(split=args.tile_dir, batchsize=args.batchsize, feature_dir=args.feature_dir)
+
+    if args.model_run:
+        if args.dataset is None:
+            try:
+                args.dataset = os.path.join(args.feature_dir, 'pt_files')
+            except Exception:
+                print('No dataset found, please provide the dataset path')
+                exit()
+        if args.split is None:
+            try:
+                testset = os.listdir(args.dataset)
+                testset = [os.path.splitext(item)[0] for item in testset]
+                with open('./split.yaml', 'w') as f:
+                    yaml.dump({'train': [], 'test': testset}, f)
+                args.split = './split.yaml'
+            except Exception:
+                print('No split file found, please provide the split file path')
+                exit()
+        cfg = read_yaml(args.config)
+        cfg['Data']['data_dir'] = args.dataset
+        cfg['Data']['data_split'] = args.split
+        cfg['Data']['label_file'] = args.label
+        cfg['General']['mode'] = args.mode
+        cfg['Model']['exp_name'] = args.exp_name
+        cfg['Model']['preds_save'] = args.output_dir
+        cfg['resume'] = args.resume
+        model_run(cfg)
\ No newline at end of file
diff --git a/preprocessing/feature_extraction/get_features.py b/preprocessing/feature_extraction/get_features.py
index 808ab14..6239e5b 100644
--- a/preprocessing/feature_extraction/get_features.py
+++ b/preprocessing/feature_extraction/get_features.py
@@ -8,7 +8,6 @@
 import numpy as np
 from torchvision import transforms
 from torch.utils.data import Dataset
-from timm.models.vision_transformer import VisionTransformer
 
 
 # imagenet normalization
@@ -66,16 +65,12 @@ def save_hdf5(output_path, asset_dict, attr_dict=None, mode='w'):
     return output_path
 
 
-@click.command()
-@click.option('--split', type=str, help='path to the split file (.txt)')
-@click.option('--batchsize', type=int, help='batch size for inference')
-@click.option('--feature_dir', type=str, help='path to the save directory')
-# @click.option('--ckpt', type=str, help='path to the save directory')
-def inference(split, batchsize, feature_dir):
+def extract_features(split, batchsize=768, feature_dir='./features'):
     slide_ls = [line.rstrip('\n') for line in open(split)]
     os.remove(split) # remove the split file after reading
     test_datat=roi_dataset(slide_ls)
     database_loader = torch.utils.data.DataLoader(test_datat, batch_size=batchsize, shuffle=False)
+    # change the model name here to use a different feature extractor
     os.environ['HF_HOME'] = './model_cache'
     model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
@@ -113,6 +108,18 @@ def inference(split, batchsize, feature_dir):
         features = file['features'][:]
         features = torch.from_numpy(features)
         torch.save(features, os.path.join(feature_dir, 'pt_files', os.path.basename(h5file)+'.pt'))
+        file.close()
+
+    print('Feature extraction done!')
+
+@click.command()
+@click.option('--split', type=str, help='path to the split file (.txt)')
+@click.option('--batchsize', type=int, help='batch size for inference')
+@click.option('--feature_dir', type=str, help='path to the save directory')
+# @click.option('--ckpt', type=str, help='path to the save directory')
+def inference(split, batchsize, feature_dir):
+    extract_features(split, batchsize, feature_dir)
+
 
 if __name__ == '__main__':
     inference()
\ No newline at end of file
diff --git a/preprocessing/tiling/main_create_tiles.py b/preprocessing/tiling/main_create_tiles.py
index 39456fc..71075b6 100644
--- a/preprocessing/tiling/main_create_tiles.py
+++ b/preprocessing/tiling/main_create_tiles.py
@@ -1,18 +1,10 @@
 import os
 import glob
 import click
-from slide_lib import segment_tiling
+from .slide_lib import segment_tiling
 
 
-@click.command()
-@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
-@click.option('--source_list', type=str, help='path to the source slide image (.svs) list to be processed')
-@click.option('--save_dir', type=str, help='path to the save directory')
-@click.option('--patch_size', type=int, default=256, help='patch size')
-@click.option('--step_size', type=int, default=256, help='step size')
-@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
-@click.option('--index', type=int, default=None)
-def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size: int, step_size: int, mag: int, index: int) -> None:
+def tile_slide_images(source_dir: str, source_list: str, save_dir: str, patch_size: int=256, step_size: int=256, mag: int=20, index: int=0) -> None:
     """
     Tile whole slide images stored in the .svs/.ndpi/.scn format at the desired magnification.
 
@@ -21,7 +13,7 @@ def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size:
     source_dir : str
         Path to the source slide image (.svs) directory.
     source_list : str
-        Path to the source slide image (.svs) list to be processed.
+        Path to the source slide image list (.txt) to be processed.
     save_dir : str
         Path to the save directory.
     patch_size : int
@@ -58,6 +50,17 @@ def batch_tiling(source_dir: str, source_list: list, save_dir: str, patch_size:
     total_time = segment_tiling(**directories, patch_size=patch_size, mag_level=mag, step_size= step_size, index=index)
     print(f"The average processing time for each slide is {total_time:.2f} seconds.")
 
+
+@click.command()
+@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
+@click.option('--source_list', type=str, help='path to the source slide image list (.txt) to be processed')
+@click.option('--save_dir', type=str, help='path to the save directory')
+@click.option('--patch_size', type=int, default=256, help='patch size')
+@click.option('--step_size', type=int, default=256, help='step size')
+@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
+@click.option('--index', type=int, default=0)
+def generate_tiles(source_dir, source_list, save_dir, patch_size, step_size, mag, index):
+    tile_slide_images(source_dir, source_list, save_dir, patch_size, step_size, mag, index)
 
 if __name__ == '__main__':
-    batch_tiling()
+    generate_tiles()
diff --git a/preprocessing/tiling/slide_lib/__init__.py b/preprocessing/tiling/slide_lib/__init__.py
index c046a00..9e7137d 100644
--- a/preprocessing/tiling/slide_lib/__init__.py
+++ b/preprocessing/tiling/slide_lib/__init__.py
@@ -1,2 +1,3 @@
 from .utils import *
-from .constants_color import *
\ No newline at end of file
+from .constants_color import *
+from .segment_patching import *
\ No newline at end of file
diff --git a/preprocessing/tiling/slide_lib/segment_patching.py b/preprocessing/tiling/slide_lib/segment_patching.py
index 6928630..03e6017 100644
--- a/preprocessing/tiling/slide_lib/segment_patching.py
+++ b/preprocessing/tiling/slide_lib/segment_patching.py
@@ -6,7 +6,7 @@
 import pandas as pd
 from PIL import Image
 import multiprocessing as mp
-from slide_lib import *
+from .utils import *
 
 
 def segment(wsi: openslide.OpenSlide)->tuple[list, list, Image.Image, float]:
diff --git a/preprocessing/tiling/slide_lib/utils.py b/preprocessing/tiling/slide_lib/utils.py
index 2d3e1f1..1ea4765 100644
--- a/preprocessing/tiling/slide_lib/utils.py
+++ b/preprocessing/tiling/slide_lib/utils.py
@@ -8,7 +8,7 @@
 from typing import Union
 import skimage.filters as sk_filters
 from PIL import Image
-from slide_lib import PENS_RGB
+from .constants_color import PENS_RGB
 import multiprocessing as mp
 import concurrent.futures
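
Note (not part of the patch): the three entry points introduced above can also be driven directly from Python rather than through the pipeline.py CLI. Below is a minimal sketch under stated assumptions: the calls mirror what pipeline.py wires together, and all paths, the experiment name, and the split/label files are hypothetical placeholders.

# Sketch only: programmatic equivalent of
# `python pipeline.py --tiling --feature_extraction --model_run`.
import glob
import os

from preprocessing.tiling.main_create_tiles import tile_slide_images
from preprocessing.feature_extraction.get_features import extract_features
from aggregator_train_val.model_run import model_run
from aggregator_train_val.utils import read_yaml

# 1) Tile every slide listed in slides.txt at the default 20x magnification, 256 px patches.
tile_slide_images(source_dir='./slides/', source_list='./slides.txt', save_dir='./tiles/')

# 2) List the tile folders in a text file (one folder per line) and extract features.
#    Note that extract_features() deletes this list file after reading it.
with open('./tile_list.txt', 'w') as f:
    for folder in glob.glob(os.path.join('./tiles/', 'tiles', '*')):
        f.write(folder + '\n')
extract_features(split='./tile_list.txt', batchsize=768, feature_dir='./features/')

# 3) Point the aggregator config at the extracted .pt features and run the model.
cfg = read_yaml('./aggregator_train_val/config.yaml')
cfg['Data']['data_dir'] = os.path.join('./features/', 'pt_files')
cfg['Data']['data_split'] = './split.yaml'        # {'train': [...], 'test': [...]}
cfg['Data']['label_file'] = './labels/labels.csv'
cfg['General']['mode'] = 'test'                   # or 'train'
cfg['Model']['exp_name'] = 'demo_exp'
cfg['Model']['preds_save'] = './predictions'
cfg['resume'] = False
model_run(cfg)

Because the modules now use package-relative imports (from .utils import *, from .slide_lib import segment_tiling, and so on), they are meant to be imported from the repository root, e.g. via pipeline.py or python -m aggregator_train_val.model_run, rather than executed as standalone scripts from inside their own directories.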