Commit: 8 changed files with 822 additions and 0 deletions.
get_features.py (the script invoked by the submission command below): loads the Prov-GigaPath tile encoder via timm, extracts embeddings for every tile image, and writes per-slide HDF5 and PyTorch feature files.
```python
import os
import glob

import click
import h5py
import numpy as np
import timm
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


# ImageNet normalization
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trnsfrms_val = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)


class roi_dataset(Dataset):
    """Dataset over all tile images (*.jpg) found under a list of slide directories."""

    def __init__(self, slide_ls):
        super().__init__()
        self.slide_ls = slide_ls
        self.tile_ls = []
        for slide in self.slide_ls:
            self.tile_ls.extend(glob.glob(os.path.join(slide, '*.jpg')))
        self.transform = trnsfrms_val

    def __len__(self):
        return len(self.tile_ls)

    def __getitem__(self, idx):
        # Tiles are stored as <slide_id>/..._<x>_<y>.jpg, so the slide id and
        # the tile's grid coordinates can be recovered from the path.
        slide_id = self.tile_ls[idx].split('/')[-2]
        image = Image.open(self.tile_ls[idx]).convert('RGB')
        image = self.transform(image)
        fname = self.tile_ls[idx].split('/')[-1]
        spatial_x = int(fname.split('_')[-2])
        spatial_y = int(fname.split('_')[-1].split('.')[0])
        return image, slide_id, spatial_x, spatial_y


def save_hdf5(output_path, asset_dict, attr_dict=None, mode='w'):
    """Write or append each array in asset_dict to a resizable HDF5 dataset."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    file = h5py.File(output_path, mode)
    for key, val in asset_dict.items():
        data_shape = val.shape
        if key not in file:
            # First write: create a chunked dataset that can grow along axis 0.
            data_type = val.dtype
            chunk_shape = (1,) + data_shape[1:]
            maxshape = (None,) + data_shape[1:]
            dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape,
                                       chunks=chunk_shape, dtype=data_type)
            dset[:] = val
            if attr_dict is not None and key in attr_dict:
                for attr_key, attr_val in attr_dict[key].items():
                    dset.attrs[attr_key] = attr_val
        else:
            # Subsequent writes: grow the dataset and append the new rows.
            dset = file[key]
            dset.resize(len(dset) + data_shape[0], axis=0)
            dset[-data_shape[0]:] = val
    file.close()
    return output_path


@click.command()
@click.option('--split', type=str, help='path to the split file (.txt)')
@click.option('--batchsize', type=int, help='batch size for inference')
@click.option('--feature_dir', type=str, help='path to the save directory')
def inference(split, batchsize, feature_dir):
    slide_ls = [line.rstrip('\n') for line in open(split)]
    os.remove(split)  # remove the split file after reading
    test_dataset = roi_dataset(slide_ls)
    database_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)
    # Change the hub name here to use a different feature extractor.
    os.environ['HF_HOME'] = './model_cache'
    model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
    model.cuda()
    model.eval()

    count = 0
    print('Inference begins...')
    with torch.no_grad():
        for batch, slide_id, spatial_x, spatial_y in database_loader:
            print(f'{count}/{len(database_loader)}')
            batch = batch.cuda()

            features = model(batch).cpu().numpy()
            slide_id = np.array(slide_id)
            spatial_x = np.array(spatial_x)
            spatial_y = np.array(spatial_y)
            # A batch can straddle slide boundaries, so split it by slide id
            # and append each group to that slide's HDF5 file.
            for sid in np.unique(slide_id):
                mask = slide_id == sid
                asset_dict = {'features': features[mask],
                              'pos_x': spatial_x[mask],
                              'pos_y': spatial_y[mask]}
                output_path = os.path.join(feature_dir, 'h5_files', sid + '.h5')
                save_hdf5(output_path, asset_dict, attr_dict=None, mode='a')
            count += 1

    # Convert each slide's HDF5 features into a standalone .pt tensor file.
    h5_ls = [os.path.join(feature_dir, 'h5_files', os.path.basename(item)) for item in slide_ls]
    os.makedirs(os.path.join(feature_dir, 'pt_files'), exist_ok=True)
    for h5file in h5_ls:
        pt_path = os.path.join(feature_dir, 'pt_files', os.path.basename(h5file) + '.pt')
        if not os.path.exists(pt_path):
            with h5py.File(h5file + '.h5', 'r') as file:
                features = torch.from_numpy(file['features'][:])
            torch.save(features, pt_path)


if __name__ == '__main__':
    inference()
```
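For downstream use, here is a minimal sketch of reading the saved outputs back. The slide id and save path are placeholders, and the 1536 feature dimension matches Prov-GigaPath's tile embeddings and the `1536` in the save-directory naming used by the submission script below:

```python
import h5py
import torch

feature_dir = '/save/path/ProvGigaPath_256_1536_datasetname'  # placeholder path
slide_id = 'slide_001'                                        # placeholder slide id

# HDF5 output: one row per tile, with the tile's grid position alongside it.
with h5py.File(f'{feature_dir}/h5_files/{slide_id}.h5', 'r') as f:
    features = f['features'][:]            # (num_tiles, 1536) embeddings
    pos_x, pos_y = f['pos_x'][:], f['pos_y'][:]

# PyTorch output: the same feature matrix saved as a tensor.
features_t = torch.load(f'{feature_dir}/pt_files/{slide_id}.pt')
print(features.shape, features_t.shape, pos_x.shape, pos_y.shape)
```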
The companion LSF submission script, which shards the not-yet-processed slides into temporary list files and dispatches one get_features.py job per shard:
```python
import glob
import os
import subprocess
import time

import numpy as np


dataset = 'datasetname'
tile_dir = f'/path/to/tile/images/{dataset}/20x_256/tiles'  # Change this to the path of the tiles
save_dir = f'/save/path/ProvGigaPath_256_1536_{dataset}'  # Change this to where the embeddings should be saved
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, 'pt_files'), exist_ok=True)
os.makedirs(os.path.join(save_dir, 'h5_files'), exist_ok=True)

# Skip slides whose .pt features already exist.
existing_files = os.listdir(os.path.join(save_dir, 'pt_files'))
existing_files = [item.split('.')[0] for item in existing_files]
slide_list_ = glob.glob(os.path.join(tile_dir, '*'))
slide_list = [item for item in slide_list_ if os.path.basename(item) not in existing_files]

sub_num = 50  # Number of slides to process in one job
job_num = int(np.ceil(len(slide_list) / sub_num))  # Number of jobs to submit
for i in range(job_num):
    start = i * sub_num
    end = min((i + 1) * sub_num, len(slide_list))
    slide_sub_list = slide_list[start:end]  # Slides to process in this job
    # Record this shard's slides in a temporary list file consumed by get_features.py.
    list_loc_tmp = f'./{dataset}_list_{i}.txt'
    with open(list_loc_tmp, 'w') as f:
        for item in slide_sub_list:
            f.write("%s\n" % item)

    batchsize = 768
    cmd = f"python -W ignore get_features.py --split '{list_loc_tmp}' --batchsize {batchsize} --feature_dir {save_dir}"
    bsub_cmd = (f'bsub -gpu num=1:j_exclusive=yes:gmem=23.5G -R "rusage[mem=20G]" '
                f'-L /bin/bash -q gpu -J {dataset}_{i} -o ./log_{i}.log -e ./log_{i}.err '
                f'"source ~/.bashrc && {cmd}"')
    try:
        # check=True makes subprocess.run raise CalledProcessError on failure;
        # without it the except branch below could never trigger.
        subprocess.run(bsub_cmd, shell=True, check=True)
        time.sleep(3)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred while submitting job {i}: {e}")
```
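On a machine without an LSF scheduler, the same shards can be processed sequentially. A minimal sketch under that assumption, reusing the naming conventions of the script above (the shard count and paths are placeholders):

```python
import subprocess

# Hypothetical local fallback (no bsub): run each shard's command in sequence.
dataset = 'datasetname'                                 # same name as in the script above
save_dir = f'/save/path/ProvGigaPath_256_1536_{dataset}'  # placeholder save path
job_num = 4                                             # number of shard list files written above
batchsize = 768

for i in range(job_num):
    subprocess.run(
        ['python', '-W', 'ignore', 'get_features.py',
         '--split', f'./{dataset}_list_{i}.txt',
         '--batchsize', str(batchsize),
         '--feature_dir', save_dir],
        check=True,  # surface failures instead of continuing silently
    )
```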
main_create_tiles.py (the script invoked by the tiling submission command below): a click CLI around slide_lib.segment_tiling.
```python
import glob
import os

import click

from slide_lib import segment_tiling


@click.command()
@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
@click.option('--source_list', type=str, help='path to the source slide image (.svs) list to be processed')
@click.option('--save_dir', type=str, help='path to the save directory')
@click.option('--patch_size', type=int, default=256, help='patch size')
@click.option('--step_size', type=int, default=256, help='step size')
@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
@click.option('--index', type=int, default=None, help='job index forwarded to segment_tiling')
def batch_tiling(source_dir: str, source_list: str, save_dir: str, patch_size: int,
                 step_size: int, mag: int, index: int) -> None:
    """
    Tile whole-slide images stored in .svs/.ndpi/.scn format at the desired magnification.

    Parameters
    ----------
    source_dir : str
        Path to the source slide image (.svs) directory.
    source_list : str
        Path to the source slide image (.svs) list to be processed.
        Takes precedence over source_dir when given.
    save_dir : str
        Path to the save directory.
    patch_size : int
        Patch size.
    step_size : int
        Step size.
    mag : int
        Magnification for patch extraction.
    index : int
        Job index forwarded to segment_tiling.

    Returns
    -------
    None
    """
    if source_list:
        slide_list = [line.rstrip('\n') for line in open(source_list)]
        os.remove(source_list)  # remove the temporary list file after reading

    tile_save_dir = os.path.join(save_dir, 'tiles')
    mask_save_dir = os.path.join(save_dir, 'masks')
    stitch_save_dir = os.path.join(save_dir, 'stitches')

    directories = {'source': slide_list if source_list else glob.glob(source_dir),
                   'save_dir': save_dir,
                   'tile_save_dir': tile_save_dir,
                   'mask_save_dir': mask_save_dir,
                   'stitch_save_dir': stitch_save_dir}

    for key, val in directories.items():
        if key == 'source':
            continue
        os.makedirs(val, exist_ok=True)

    total_time = segment_tiling(**directories, patch_size=patch_size, mag_level=mag,
                                step_size=step_size, index=index)
    print(f"The average processing time for each slide is {total_time:.2f} seconds.")


if __name__ == '__main__':
    batch_tiling()
```
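For reference, a single shard can be tiled directly with a command mirroring the one assembled by the submission script below (all paths are placeholders):

```
python -W ignore main_create_tiles.py --index 0 --source_list ./dataset_name_list_0.txt \
    --save_dir /save/path/dataset_name/20x_256 --patch_size 256 --step_size 256 --mag 20
```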
The LSF submission script for tiling, which walks the slide directory, shards unprocessed slides into temporary list files, and submits one main_create_tiles.py job per shard:
```python
import os
import subprocess
import time

import numpy as np


dataset = 'dataset_name'
slide_dir = f'/path/to/slide/images/{dataset}'  # Change this to the path of the slide images
save_dir = f'/save/path/{dataset}/20x_256'  # Change this to where the tiles/masks/stitches go
os.makedirs(save_dir, exist_ok=True)
os.makedirs('./log_file/', exist_ok=True)

# Skip slides that already have a tile directory.
try:
    existing_files = os.listdir(os.path.join(save_dir, 'tiles'))
except FileNotFoundError:
    existing_files = []

formats = ['.svs', '.ndpi', '.scn']  # Add more formats if needed
slide_list = []
for root, dirs, files in os.walk(slide_dir):
    for file in files:
        if any(file.endswith(fmt) for fmt in formats):
            if os.path.splitext(file)[0] not in existing_files:
                slide_list.append(os.path.join(root, file))

sub_num = 5  # Number of slides to process in one job
job_num = int(np.ceil(len(slide_list) / sub_num))  # Number of jobs to submit

for i in range(job_num):
    start = i * sub_num
    end = min((i + 1) * sub_num, len(slide_list))
    slide_sub_list = slide_list[start:end]
    # Record the slides to be processed in a temporary txt file.
    list_loc_tmp = f'./{dataset}_list_{i}.txt'
    with open(list_loc_tmp, 'w') as f:
        for item in slide_sub_list:
            f.write("%s\n" % item)

    cmd = (f'python -W ignore main_create_tiles.py --index {i} --source_list {list_loc_tmp} '
           f'--save_dir {save_dir} --patch_size 256 --step_size 256 --mag 20')
    # Write the job logs into the log_file directory created above.
    bsub_cmd = (f'bsub -R "rusage[mem=30G]" -J {dataset}_{i} -q long '
                f'-o ./log_file/log_{i}.out -e ./log_file/log_{i}.err {cmd}')
    try:
        # check=True makes subprocess.run raise CalledProcessError on failure.
        subprocess.run(bsub_cmd, shell=True, check=True)
        time.sleep(1)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred while submitting job {i}: {e}")
```
slide_lib/__init__.py, which re-exports the package's utilities (including the segment_tiling function used above) and the colour constants below:
```python
from .utils import *
from .constants_color import *
```
slide_lib/constants_color.py:
```python
# Reference RGB triples for pen-marker colours found on slides
# (exported via slide_lib for use by its filtering utilities).
PENS_RGB = {
    "red": [
        (150, 80, 90),
        (110, 20, 30),
        (185, 65, 105),
        (195, 85, 125),
        (220, 115, 145),
        (125, 40, 70),
        (200, 120, 150),
        (100, 50, 65),
        (85, 25, 45),
    ],
    "green": [
        (150, 160, 140),
        (70, 110, 110),
        (45, 115, 100),
        (30, 75, 60),
        (195, 220, 210),
        (225, 230, 225),
        (170, 210, 200),
        (20, 30, 20),
        (50, 60, 40),
        (30, 50, 35),
        (65, 70, 60),
        (100, 110, 105),
        (165, 180, 180),
        (140, 140, 150),
        (185, 195, 195),
    ],
    "blue": [
        (60, 120, 190),
        (120, 170, 200),
        (175, 210, 230),
        (145, 210, 210),
        (37, 95, 160),
        (30, 65, 130),
        (130, 155, 180),
        (40, 35, 85),
        (30, 20, 65),
        (90, 90, 140),
        (60, 60, 120),
        (110, 110, 175),
    ],
}
```
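The filters that consume these constants live in slide_lib.utils, which is not shown in this commit view. As a rough illustration only, threshold triples like these are conventionally applied in the style of deep-histopath's pen filters; the function name and the comparison directions below are assumptions about that convention, not this library's API:

```python
import numpy as np

def pen_mask(rgb: np.ndarray, pens_rgb: dict) -> np.ndarray:
    """Return a boolean mask that is True where a pixel looks like pen ink.

    rgb is an (H, W, 3) uint8 image. Illustrative assumption: a red-pen pixel
    has red above a floor with green/blue below ceilings, and green/blue pens
    follow the symmetric pattern (as in deep-histopath's pen filters).
    """
    r = rgb[..., 0].astype(int)
    g = rgb[..., 1].astype(int)
    b = rgb[..., 2].astype(int)
    mask = np.zeros(rgb.shape[:2], dtype=bool)
    for r_t, g_t, b_t in pens_rgb["red"]:
        mask |= (r > r_t) & (g < g_t) & (b < b_t)
    for r_t, g_t, b_t in pens_rgb["green"]:
        mask |= (r < r_t) & (g > g_t) & (b > b_t)
    for r_t, g_t, b_t in pens_rgb["blue"]:
        mask |= (r < r_t) & (g < g_t) & (b > b_t)
    return mask
```

A tile would then be flagged as pen-contaminated when, say, pen_mask(tile, PENS_RGB).mean() exceeds a chosen fraction.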