Commit

all module files added
DaruiJin committed Nov 20, 2024
1 parent 9d28522 commit 12c0bb8
Showing 8 changed files with 822 additions and 0 deletions.
118 changes: 118 additions & 0 deletions preprocessing/feature_extraction/get_features.py
@@ -0,0 +1,118 @@
import os
import h5py
import click
import glob
import torch
import timm
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from timm.models.vision_transformer import VisionTransformer


# imagenet normalization
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trnsfrms_val = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
]
)

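# The dataset below assumes one tile directory per slide, with filenames ending
# in '_{x}_{y}.jpg' so that the grid coordinates can be parsed back out in __getitem__.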
class roi_dataset(Dataset):
def __init__(self, slide_ls):
super().__init__()
self.slide_ls = slide_ls
self.tile_ls = []
for slide in self.slide_ls:
self.tile_ls.extend(glob.glob(os.path.join(slide, '*.jpg')))
self.transform = trnsfrms_val

def __len__(self):
return len(self.tile_ls)

    def __getitem__(self, idx):
        tile_path = self.tile_ls[idx]
        slide_id = os.path.basename(os.path.dirname(tile_path))  # parent directory names the slide
        image = self.transform(Image.open(tile_path).convert('RGB'))
        stem = os.path.splitext(os.path.basename(tile_path))[0]
        spatial_x = int(stem.split('_')[-2])  # grid coordinates encoded in the filename
        spatial_y = int(stem.split('_')[-1])
        return image, slide_id, spatial_x, spatial_y


def save_hdf5(output_path, asset_dict, attr_dict=None, mode='w'):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
file = h5py.File(output_path, mode)
for key, val in asset_dict.items():
data_shape = val.shape
if key not in file:
data_type = val.dtype
chunk_shape = (1, ) + data_shape[1:]
maxshape = (None, ) + data_shape[1:]
dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type)
dset[:] = val
if attr_dict is not None:
if key in attr_dict.keys():
for attr_key, attr_val in attr_dict[key].items():
dset.attrs[attr_key] = attr_val
else:
dset = file[key]
dset.resize(len(dset) + data_shape[0], axis=0)
dset[-data_shape[0]:] = val
file.close()
return output_path
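
# Illustrative usage (hypothetical paths): repeated calls with mode='a' resize each
# dataset along axis 0, so per-batch features for one slide accumulate in a single file:
#   save_hdf5('features/slide_01.h5', {'features': feats}, mode='a')  # first call creates datasets
#   save_hdf5('features/slide_01.h5', {'features': more}, mode='a')   # later calls append rows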


@click.command()
@click.option('--split', type=str, help='path to the split file (.txt)')
@click.option('--batchsize', type=int, help='batch size for inference')
@click.option('--feature_dir', type=str, help='path to the save directory')
# @click.option('--ckpt', type=str, help='path to the save directory')
def inference(split, batchsize, feature_dir):
    with open(split) as f:
        slide_ls = [line.rstrip('\n') for line in f]
    os.remove(split)  # remove the temporary split file after reading
    test_dataset = roi_dataset(slide_ls)
    database_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)
    # cache Hugging Face downloads locally; change the hub name below to use a different model
    os.environ['HF_HOME'] = './model_cache'
model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
model.cuda()
model.eval()
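    # The tile encoder returns one embedding per image; for Prov-GigaPath this is
    # 1,536-dimensional, matching the '256_1536' save-directory naming in run.py.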

count = 0
print('Inference begins...')
with torch.no_grad():
for batch, slide_id, spatial_x, spatial_y in database_loader:
print(f'{count}/{len(database_loader)}')
batch = batch.cuda()

features = model(batch)
features = features.cpu().numpy()
id_set = list(np.unique(np.array(slide_id)))
spatial_x = np.array(spatial_x)
spatial_y = np.array(spatial_y)
            for sid in id_set:  # 'sid' avoids shadowing the builtin id()
                mask = np.array(slide_id) == sid
                output_path = os.path.join(feature_dir, 'h5_files', sid + '.h5')
                asset_dict = {'features': features[mask], 'pos_x': spatial_x[mask], 'pos_y': spatial_y[mask]}
                save_hdf5(output_path, asset_dict, attr_dict=None, mode='a')
count += 1

    h5_ls = [os.path.join(feature_dir, 'h5_files', os.path.basename(item)) for item in slide_ls]  # slide names, no extension yet
os.makedirs(os.path.join(feature_dir, 'pt_files'), exist_ok=True)
    for h5file in h5_ls:
        pt_path = os.path.join(feature_dir, 'pt_files', os.path.basename(h5file) + '.pt')
        if os.path.exists(pt_path):
            continue
        with h5py.File(h5file + '.h5', "r") as file:  # context manager ensures the handle is closed
            features = torch.from_numpy(file['features'][:])
        torch.save(features, pt_path)

if __name__ == '__main__':
inference()
39 changes: 39 additions & 0 deletions preprocessing/feature_extraction/run.py
@@ -0,0 +1,39 @@
import subprocess
import time
import os
import glob
import numpy as np


dataset = 'datasetname'
tile_dir = f'/path/to/tile/images/{dataset}/20x_256/tiles' # Change this to the path of the tiles
save_dir = f'/save/path/ProvGigaPath_256_1536_{dataset}' # Change this to the path where you want to save the embeddings
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, 'pt_files'), exist_ok=True)
os.makedirs(os.path.join(save_dir, 'h5_files'), exist_ok=True)

existing_files = os.listdir(os.path.join(save_dir, 'pt_files'))
existing_files = [os.path.splitext(item)[0] for item in existing_files]
slide_list_ = glob.glob(os.path.join(tile_dir, '*'))
slide_list = [item for item in slide_list_ if os.path.basename(item) not in existing_files]

sub_num = 50 # Number of slides to process in one job
job_num = int(np.ceil(len(slide_list)/sub_num)) # Number of jobs to submit
for i in range(job_num):
start = i*sub_num
    end = min((i + 1) * sub_num, len(slide_list))
slide_sub_list = slide_list[start:end] # List of slides to process in this job
list_loc_tmp = f'./{dataset}_list_{i}.txt'
with open(list_loc_tmp, 'w') as f:
for item in slide_sub_list:
f.write("%s\n" % item)

batchsize = 768
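    # A batch of 768 tiles appears sized for a ~24 GB GPU, consistent with the
    # gmem=23.5G request in the bsub command below.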
cmd = f"python -W ignore get_features.py --split '{list_loc_tmp}' --batchsize {batchsize} --feature_dir {save_dir}"
bsub_cmd = f'bsub -gpu num=1:j_exclusive=yes:gmem=23.5G -R "rusage[mem=20G]" -L /bin/bash -q gpu -J {dataset}_{i} -o ./log_{i}.log -e ./log_{i}.err "source ~/.bashrc && {cmd}"'
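    # Submission assumes an LSF scheduler; without one, the inner command could be
    # run directly instead, e.g. subprocess.run(cmd, shell=True, check=True).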
    try:
        subprocess.run(bsub_cmd, shell=True, check=True)  # check=True makes failures raise CalledProcessError
        time.sleep(3)
except subprocess.CalledProcessError as e:
print(f"An error occurred while submitting job {i}: {e}")

63 changes: 63 additions & 0 deletions preprocessing/tiling/main_create_tiles.py
@@ -0,0 +1,63 @@
import os
import glob
import click
from slide_lib import segment_tiling


@click.command()
@click.option('--source_dir', type=str, help='path to the source slide image (.svs) directory')
@click.option('--source_list', type=str, help='path to the source slide image (.svs) list to be processed')
@click.option('--save_dir', type=str, help='path to the save directory')
@click.option('--patch_size', type=int, default=256, help='patch size')
@click.option('--step_size', type=int, default=256, help='step size')
@click.option('--mag', type=int, default=20, help='magnification for patch extraction')
@click.option('--index', type=int, default=None)
def batch_tiling(source_dir: str, source_list: str, save_dir: str, patch_size: int, step_size: int, mag: int, index: int) -> None:
    """
    Tile whole-slide images stored in .svs/.ndpi/.scn format at the desired magnification.

    Parameters
    ----------
    source_dir : str
        Glob pattern matching the source slide images (passed to glob.glob).
    source_list : str
        Path to a text file listing the slide images to process.
    save_dir : str
        Path to the save directory.
    patch_size : int
        Patch size in pixels.
    step_size : int
        Step size in pixels between adjacent patches.
    mag : int
        Magnification for patch extraction.
    index : int
        Job index (forwarded to segment_tiling).

    Returns
    -------
    None
    """

if source_list:
slide_list = [line.rstrip('\n') for line in open(source_list)]
os.remove(source_list)

tile_save_dir = os.path.join(save_dir, 'tiles')
mask_save_dir = os.path.join(save_dir, 'masks')
stitch_save_dir = os.path.join(save_dir, 'stitches')

directories = {'source': slide_list if source_list else glob.glob(source_dir),
'save_dir': save_dir,
'tile_save_dir': tile_save_dir,
'mask_save_dir': mask_save_dir,
'stitch_save_dir': stitch_save_dir}

for key, val in directories.items():
if key == 'source':
continue
os.makedirs(val, exist_ok=True)

    total_time = segment_tiling(**directories, patch_size=patch_size, mag_level=mag, step_size=step_size, index=index)
print(f"The average processing time for each slide is {total_time:.2f} seconds.")


if __name__ == '__main__':
batch_tiling()
46 changes: 46 additions & 0 deletions preprocessing/tiling/run.py
@@ -0,0 +1,46 @@
import subprocess
import time
import os
import numpy as np


dataset = 'dataset_name'
slide_dir = f'/path/to/slide/images/{dataset}' # Change this to the path of the slide images
save_dir = f'/save/path/{dataset}/20x_256' # Change this to the path where you want to save the tiles/masks/stitches
os.makedirs(save_dir, exist_ok=True)
os.makedirs('./log_file/', exist_ok=True)

try:
existing_files = os.listdir(os.path.join(save_dir, 'tiles')) # Check if the slide has already been processed
except FileNotFoundError:
existing_files = []

formats = ['.svs', '.ndpi', '.scn'] # Add more formats if needed
slide_list = []
for root, dirs, files in os.walk(slide_dir):
for file in files:
if any(file.endswith(fmt) for fmt in formats):
if os.path.splitext(file)[0] not in existing_files:
slide_list.append(os.path.join(root, file))

sub_num = 5 # Number of slides to process in one job
job_num = int(np.ceil(len(slide_list)/sub_num)) # Number of jobs to submit
# tiling_params = {'patch_size': 256, 'step_size': 256, 'mag': 20} # Parameters for tiling, change if needed

for i in range(job_num):
start = i*sub_num
    end = min((i + 1) * sub_num, len(slide_list))
slide_sub_list = slide_list[start:end]
# slide_sub_list = '|'.join(slide_sub_list)
list_loc_tmp = f'./{dataset}_list_{i}.txt'
with open(list_loc_tmp, 'w') as f: # record the slides to be processed in a txt file
for item in slide_sub_list:
f.write("%s\n" % item)

cmd = f'python -W ignore main_create_tiles.py --index {i} --source_list {list_loc_tmp} --save_dir {save_dir} --patch_size 256 --step_size 256 --mag 20'
    bsub_cmd = f'bsub -R "rusage[mem=30G]" -J {dataset}_{i} -q long -o ./log_file/log_{i}.out -e ./log_file/log_{i}.err {cmd}'  # write logs into the ./log_file/ directory created above
    try:
        subprocess.run(bsub_cmd, shell=True, check=True)  # check=True makes failures raise CalledProcessError
        time.sleep(1)
except subprocess.CalledProcessError as e:
print(f"An error occurred while submitting job {i}: {e}")
2 changes: 2 additions & 0 deletions preprocessing/tiling/slide_lib/__init__.py
@@ -0,0 +1,2 @@
from .utils import *
from .constants_color import *
44 changes: 44 additions & 0 deletions preprocessing/tiling/slide_lib/constants_color.py
@@ -0,0 +1,44 @@
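# RGB reference triplets used to flag pen-marker annotations (red/green/blue pens)
# on slides; presumably consumed by the segmentation code in slide_lib.utils to
# filter marked regions.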
PENS_RGB = {
"red": [
(150, 80, 90),
(110, 20, 30),
(185, 65, 105),
(195, 85, 125),
(220, 115, 145),
(125, 40, 70),
(200, 120, 150),
(100, 50, 65),
(85, 25, 45),
],
"green": [
(150, 160, 140),
(70, 110, 110),
(45, 115, 100),
(30, 75, 60),
(195, 220, 210),
(225, 230, 225),
(170, 210, 200),
(20, 30, 20),
(50, 60, 40),
(30, 50, 35),
(65, 70, 60),
(100, 110, 105),
(165, 180, 180),
(140, 140, 150),
(185, 195, 195),
],
"blue": [
(60, 120, 190),
(120, 170, 200),
(175, 210, 230),
(145, 210, 210),
(37, 95, 160),
(30, 65, 130),
(130, 155, 180),
(40, 35, 85),
(30, 20, 65),
(90, 90, 140),
(60, 60, 120),
(110, 110, 175),
],
}