-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsade_registration.py
204 lines (169 loc) · 7.07 KB
/
sade_registration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import glob
import os
import re
import sys
import ants
import numpy as np
from absl import app, flags
from ml_collections.config_flags import config_flags
from sade.datasets.loaders import get_image_files_list
from sade.datasets.transforms import get_val_transform
from tqdm import tqdm
CACHE_DIR = "/codespace/braintypicality/dataset/template_cache/"
# DATADIR = "/BEE/Connectome/ABCD/"
DATADIR = "/DATA/"
procd_ref_img_path = f"{CACHE_DIR}/cropped_niral_mni.nii.gz"
####
if not os.path.exists(procd_ref_img_path):
T1_REF_IMG_PATH = os.path.join(
CACHE_DIR, "mni_icbm152_09a/mni_icbm152_t1_tal_nlin_sym_09a.nrrd"
)
T2_REF_IMG_PATH = os.path.join(
CACHE_DIR, "mni_icbm152_09a/mni_icbm152_t2_tal_nlin_sym_09a.nrrd"
)
MASK_REF_IMG_PATH = os.path.join(
CACHE_DIR, "mni_icbm152_09a/mni_icbm152_t1_tal_nlin_sym_09a_mask.nrrd"
)
ants_mni = ants.image_read(f"{CACHE_DIR}/croppedMni152.nii.gz")
t1_ref_img = ants.image_read(T1_REF_IMG_PATH)
t2_ref_img = ants.image_read(T2_REF_IMG_PATH)
ref_img_mask = ants.image_read(MASK_REF_IMG_PATH)
# Use ANTs' tighter cropping
diff = np.array(t1_ref_img.shape) - np.array(ants_mni.shape)
crop_idxs_start, crop_idxs_end = (
1 + diff // 2,
np.array(t1_ref_img.shape) - diff // 2,
)
t1_ref_img = ants.crop_indices(t1_ref_img, crop_idxs_start, crop_idxs_end)
t2_ref_img = ants.crop_indices(t2_ref_img, crop_idxs_start, crop_idxs_end)
ref_img_mask = ants.crop_indices(ref_img_mask, crop_idxs_start, crop_idxs_end)
procd_ref_img = ants.merge_channels(
(t1_ref_img * ref_img_mask, t2_ref_img * ref_img_mask)
)
procd_ref_img.to_filename(procd_ref_img_path)
####
def run_registration(config, dataset_name):
savedir = f"/{DATADIR}/Users/amahmood/braintyp/spacing_{int(config.data.spacing_pix_dim)}-ants"
os.makedirs(savedir, exist_ok=True)
img_loader = get_val_transform(config)
procd_ref_img_path = f"{CACHE_DIR}/cropped_niral_mni.nii.gz"
ref_img_tensor = img_loader({"image": procd_ref_img_path})["image"].numpy()
ref_img_post_transform = (ants.from_numpy(ref_img_tensor[0]) + 1) / 2
fnames = get_image_files_list(
dataset_name, config.data.dir_path, config.data.splits_dir
)
for fname_dict in tqdm(fnames):
img_tensor = img_loader(fname_dict)["image"].numpy()
t1_img = (ants.from_numpy(img_tensor[0]) + 1) / 2
fname = fname_dict["image"]
sampleid = os.path.basename(fname).split(".nii.gz")[0]
if config.data.spacing_pix_dim == 1.0:
# reg_iterations = [40, 20, 10, 0]
reg_iterations = [500, 80, 40, 0]
else:
reg_iterations = [40, 20, 0]
_ = ants.registration(
fixed=ref_img_post_transform,
moving=t1_img,
# More info at https://github.com/ANTsX/ANTs/blob/master/Scripts/antsRegistrationSyN.sh
type_of_transform="antsRegistrationSyN[s]",
outprefix=f"{savedir}/{sampleid}",
write_composite_transform=True,
reg_iterations=reg_iterations,
verbose=0,
)
def extract_id(fname):
return re.match("(.*)(.nii.gz|.npz)$", fname).group(1)
def apply_registration(config, load_dir, save_dir):
os.makedirs(save_dir, exist_ok=True)
transforms_dir = f"/{DATADIR}/Users/amahmood/braintyp/spacing_{int(config.data.spacing_pix_dim)}-ants"
procd_ref_img_path = f"{CACHE_DIR}/cropped_niral_mni.nii.gz"
img_loader = get_val_transform(config)
ref_img_tensor = img_loader({"image": procd_ref_img_path})["image"].numpy()
ref_img_post_transform = (ants.from_numpy(ref_img_tensor[0]) + 1) / 2
print(
f"Applying registrations from {transforms_dir} to images in {load_dir} and saving to {save_dir}."
)
paths = glob.glob(os.path.join(load_dir, "*"))
for fname in tqdm(paths):
if ".npz" in fname:
x = np.load(fname)["heatmap"]
x = ants.from_numpy(x)
elif ".nii.gz" in fname:
x = ants.image_read(fname)
else:
raise ValueError(
f"File format of {fname} not recognized - must be .nii.gz or .npz."
)
sid = extract_id(os.path.basename(fname))
transform_mat = f"{transforms_dir}/{sid}Composite.h5"
h_aligned = ants.apply_transforms(
fixed=ref_img_post_transform,
interpolator="linear",
verbose=False,
moving=x,
transformlist=transform_mat,
)
h_aligned.to_filename(f"{save_dir}/{sid}.nii.gz")
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Sade configuration used to train the model.", lock_config=True
)
flags.DEFINE_enum(
"mode",
None,
["compute", "apply"],
"Whether to compute the registrations and save them or apply them from saved directory.",
)
flags.DEFINE_string(
"dataset",
None,
"The dataset for which to run the registration (should match names used in config and splits dir).",
)
flags.DEFINE_string(
"load_dir",
None,
"Directory containing the images to register - only used in apply mode.",
required=False,
)
flags.DEFINE_string(
"save_dir",
None,
"Directory to save the registered images - only used in apply mode.",
required=False,
)
def main(argv):
if FLAGS.mode == "compute":
config = FLAGS.config
config.data.splits_dir = (
"/ASD/ahsan_projects/Developer/braintypicality-scripts/split-keys/"
)
run_registration(FLAGS.config, FLAGS.dataset)
elif FLAGS.mode == "apply":
assert FLAGS.load_dir is not None, "load_dir must be provided in apply mode."
assert FLAGS.save_dir is not None, "save_dir must be provided in apply mode."
apply_registration(FLAGS.config, FLAGS.load_dir, FLAGS.save_dir)
else:
raise ValueError(f"Mode {FLAGS.mode} not recognized.")
"""
This script is used to compute and apply registrations to the voxelwise heatmaps using ANTs.
It is recommended to use this script as the data loader of sade crops the original images,
and therefore there is a mismatch in the image sizes between the heatmaps and the original images.
This script computes registrations on cropped input images using loaders from `sade.datasets`.
Alternatively, one could use their own registration script but should make sure to
pad the heatmaps to the original image size before applying the registrations.
Example usage:
# This will compute the registrations for the conte dataset
# and save them to the transforms directory specified in the code.
python sade_registration.py --mode compute \
--config /codespace/sade/sade/configs/ve/biggan_config.py \
--dataset conte
# This will apply the registrations from the transforms directory
python sade_registration.py --mode apply \
--config /codespace/sade/sade/configs/ve/biggan_config.py \
--dataset conte \
--load_dir /ASD/ahsan_projects/braintypicality/workdir/cuda_opt/learnable/experiments/reprod-correct/conte \
--save_dir /ASD/ahsan_projects/Developer/ds-analysis/ebds/registered-heatmaps/
"""
if __name__ == "__main__":
app.run(main)