
Commit

update modality models: load from json; add new model: predict from normalized intensities

wasserth committed Feb 24, 2025
1 parent 3a763a9 commit b709255
Showing 16 changed files with 91 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,4 +1,5 @@
## Master
* make `totalseg_get_modality` work with normalized intensities (images which do not have original HU values anymore)


## Release 2.7.0
11 changes: 10 additions & 1 deletion setup.py
@@ -14,7 +14,16 @@
package_data={"totalsegmentator":
["resources/totalsegmentator_snomed_mapping.csv",
"resources/contrast_phase_classifiers_2024_07_19.pkl",
"resources/modality_classifiers_2024_10_04.pkl",
"resources/modality_classifiers_2025_02_24.json.0",
"resources/modality_classifiers_2025_02_24.json.1",
"resources/modality_classifiers_2025_02_24.json.2",
"resources/modality_classifiers_2025_02_24.json.3",
"resources/modality_classifiers_2025_02_24.json.4",
"resources/modality_classifiers_normalized_2025_02_24.json.0",
"resources/modality_classifiers_normalized_2025_02_24.json.1",
"resources/modality_classifiers_normalized_2025_02_24.json.2",
"resources/modality_classifiers_normalized_2025_02_24.json.3",
"resources/modality_classifiers_normalized_2025_02_24.json.4",
"resources/ct_brain_atlas_1mm.nii.gz"]
},
install_requires=[
66 changes: 63 additions & 3 deletions totalsegmentator/bin/totalseg_get_modality.py
@@ -7,10 +7,12 @@
import pickle
from pprint import pprint
import pkg_resources
import xgboost as xgb

import nibabel as nib
import numpy as np

from totalsegmentator.python_api import totalsegmentator
from totalsegmentator.config import send_usage_stats_application

"""
@@ -29,6 +31,7 @@ def get_features(nifti_img):
return [mean, std, min, max]


# Uses only image-level intensity features. Fast and highly accurate if the image intensities are not normalized (original HU values)
def get_modality(img: nib.Nifti1Image):
"""
Predict modality
@@ -41,8 +44,58 @@ def get_modality(img: nib.Nifti1Image):
features = get_features(img) # 5s for big ct image
# print(f"features took: {time.time() - st:.2f}s")

classifier_path = pkg_resources.resource_filename('totalsegmentator', 'resources/modality_classifiers_2024_10_04.pkl')
clfs = pickle.load(open(classifier_path, "rb"))
classifier_path = pkg_resources.resource_filename('totalsegmentator', 'resources/modality_classifiers_2025_02_24.json')
clfs = {}
for fold in range(5): # assuming 5 folds
clf = xgb.XGBClassifier()
clf.load_model(f"{classifier_path}.{fold}")
clfs[fold] = clf

# ensemble across folds
preds = []
for fold, clf in clfs.items():
preds.append(clf.predict([features])[0])
preds = np.array(preds)
preds = np.mean(preds)
prediction_str = "ct" if preds < 0.5 else "mr"
probability = 1 - preds if preds < 0.5 else preds
return {"modality": prediction_str,
"probability": float(probability)}


# Uses normalized intensities within ROIs; slower, but also works if the HU values have been normalized
def get_modality_from_rois(img: nib.Nifti1Image):
"""
Predict modality
returns:
prediction: "ct" | "mr"
probability: float
"""
st = time.time()

organs = ["brain", "esophagus", "colon", "spinal_cord",
"scapula_left", "scapula_right",
"femur_left", "femur_right", "hip_left", "hip_right",
"gluteus_maximus_left", "gluteus_maximus_right",
"autochthon_left", "autochthon_right",
"iliopsoas_left", "iliopsoas_right"]

seg_img, stats = totalsegmentator(img, None, ml=True, fast=True, statistics=True, task="total_mr",
roi_subset=None, statistics_exclude_masks_at_border=False,
quiet=True, stats_aggregation="median", statistics_normalized_intensities=True)

features = []
for organ in organs:
features.append(stats[organ]["intensity"])
# print(f"TS took: {time.time() - st:.2f}s")

classifier_path = pkg_resources.resource_filename('totalsegmentator', 'resources/modality_classifiers_normalized_2025_02_24.json')
clfs = {}
for fold in range(5): # assuming 5 folds
clf = xgb.XGBClassifier()
clf.load_model(f"{classifier_path}.{fold}")
clfs[fold] = clf

# ensemble across folds
preds = []
@@ -74,9 +127,16 @@ def main():
parser.add_argument("-q", dest="quiet", action="store_true",
help="Print no output to stdout", default=False)

# Use this option to get the modality of an image that has been normalized (i.e. no longer contains original HU values)
parser.add_argument("-n", dest="normalized_intensities", action="store_true",
help="Use normalized intensities within rois for prediction", default=False)

args = parser.parse_args()

res = get_modality(nib.load(args.input_file))
if args.normalized_intensities:
res = get_modality_from_rois(nib.load(args.input_file))
else:
res = get_modality(nib.load(args.input_file))

if not args.quiet:
print("Result:")
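
A minimal usage sketch of the two prediction paths added in this file. The import path mirrors the file location in this diff and the file names are placeholders, so treat both as assumptions; the CLI equivalent is `totalseg_get_modality` with or without the new `-n` flag.

import nibabel as nib

from totalsegmentator.bin.totalseg_get_modality import get_modality, get_modality_from_rois

# Image with original HU values: fast, image-level intensity features
res = get_modality(nib.load("ct_image.nii.gz"))
print(res)  # {"modality": "ct" | "mr", "probability": float}

# Image with normalized intensities (CLI flag -n): slower, runs TotalSegmentator
# in fast mode first and classifies on median ROI intensities
res_norm = get_modality_from_rois(nib.load("normalized_image.nii.gz"))
print(res_norm)
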
6 changes: 4 additions & 2 deletions totalsegmentator/nnunet.py
@@ -317,7 +317,8 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
crop_addon=[3,3,3], roi_subset=None, output_type="nifti",
statistics=False, quiet=False, verbose=False, test=0, skip_saving=False,
device="cuda", exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False, stats_aggregation="mean", remove_small_blobs=False):
v1_order=False, stats_aggregation="mean", remove_small_blobs=False,
normalized_intensities=False):
"""
crop: string or a nibabel image
resample: None or float (target spacing for all dimensions) or list of floats
@@ -614,7 +615,8 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
stats_file = None
stats = get_basic_statistics(img_pred.get_fdata(), img_in_rsp, stats_file,
quiet, task_name, exclude_masks_at_border, roi_subset,
metric=stats_aggregation)
metric=stats_aggregation,
normalized_intensities=normalized_intensities)
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

if resample is not None:
9 changes: 6 additions & 3 deletions totalsegmentator/python_api.py
@@ -71,7 +71,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
skip_saving=False, device="gpu", license_number=None,
statistics_exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False, fastest=False, roi_subset_robust=None, stats_aggregation="mean",
remove_small_blobs=False):
remove_small_blobs=False, statistics_normalized_intensities=False):
"""
Run TotalSegmentator from within Python.
@@ -626,7 +626,8 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
quiet=quiet, verbose=verbose, test=test, skip_saving=skip_saving, device=device,
exclude_masks_at_border=statistics_exclude_masks_at_border,
no_derived_masks=no_derived_masks, v1_order=v1_order,
stats_aggregation=stats_aggregation, remove_small_blobs=remove_small_blobs)
stats_aggregation=stats_aggregation, remove_small_blobs=remove_small_blobs,
normalized_intensities=statistics_normalized_intensities)
seg = seg_img.get_fdata().astype(np.uint8)

try:
@@ -650,7 +651,9 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
stats_file = None
stats = get_basic_statistics(seg, ct_img, stats_file,
quiet, task, statistics_exclude_masks_at_border,
roi_subset, metric=stats_aggregation)
roi_subset,
metric=stats_aggregation,
normalized_intensities=statistics_normalized_intensities)
# get_radiomics_features_for_entire_dir(input, output, output / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

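
A hedged sketch of how the new `statistics_normalized_intensities` flag is used from the Python API, mirroring the call made in `get_modality_from_rois` above; the input file name is a placeholder.

import nibabel as nib
from totalsegmentator.python_api import totalsegmentator

img = nib.load("mr_or_ct_image.nii.gz")  # placeholder path

# With statistics=True and output=None the API returns the segmentation image and
# the statistics dict (as used in get_modality_from_rois); intensities are
# aggregated on min-max-normalized values instead of raw HU.
seg_img, stats = totalsegmentator(img, None, ml=True, fast=True, statistics=True,
                                  task="total_mr", stats_aggregation="median",
                                  statistics_normalized_intensities=True, quiet=True)

print(stats["brain"]["intensity"])  # normalized median intensity, rounded to 5 decimals
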
10 binary files not shown.
10 changes: 7 additions & 3 deletions totalsegmentator/statistics.py
@@ -97,7 +97,8 @@ def get_basic_statistics(seg: np.array,
task: str="total",
exclude_masks_at_border: bool=True,
roi_subset: list=None,
metric: str="mean"):
metric: str="mean",
normalized_intensities: bool=False):
"""
ct_file: path to a ct_file or a nifti file object
"""
@@ -106,6 +107,9 @@
spacing = ct_img.header.get_zooms()
vox_vol = spacing[0] * spacing[1] * spacing[2]

if normalized_intensities:
ct = (ct - ct.min()) / (ct.max() - ct.min())

class_map_stats = class_map[task]
if roi_subset is not None:
class_map_stats = {k: v for k, v in class_map_stats.items() if v in roi_subset}
@@ -125,9 +129,9 @@
st = time.time()
if metric == "mean":
# stats[mask_name]["intensity"] = ct[roi_mask > 0].mean().round(2) if roi_mask.sum() > 0 else 0.0 # 3.0s
stats[mask_name]["intensity"] = np.average(ct, weights=roi_mask).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.03s
stats[mask_name]["intensity"] = np.average(ct, weights=roi_mask).round(5) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.03s
elif metric == "median":
stats[mask_name]["intensity"] = np.median(ct[roi_mask > 0]).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.014s
stats[mask_name]["intensity"] = np.median(ct[roi_mask > 0]).round(5) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.014s
# print(f"took: {time.time()-st:.4f}s")

if file_out is not None:
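
For clarity, the new `normalized_intensities` branch rescales the whole volume to the [0, 1] range before the per-ROI aggregation, which is presumably why the rounding precision was raised from 2 to 5 decimals. A standalone sketch of the rescaling:

import numpy as np

ct = np.array([-1000.0, 0.0, 40.0, 1000.0])  # toy HU values

# Min-max normalization as added in get_basic_statistics
ct_norm = (ct - ct.min()) / (ct.max() - ct.min())
print(ct_norm.round(5))  # 0.0, 0.5, 0.52, 1.0
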
