Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added multiple feature in inference and fixed bugs in training #28

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions DeepLearner/DeepLearner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import glob
import os
import logging
import vtk, qt, ctk, slicer
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin
from pathlib import Path
import webbrowser
from PIL import Image

from DeepLearnerLib.CONSTANTS import DEFAULT_FILE_PATHS
from DeepLearnerLib.Asynchrony import Asynchrony
Expand Down Expand Up @@ -59,6 +61,7 @@ def __init__(self, parent=None):
self.counter = None
self.webWidget = None
self.tb_log = None
self.w = None
self._asynchrony = None
self._finishCallback = None
self._running = False
Expand Down Expand Up @@ -295,14 +298,19 @@ def populateInputDirectory(self, mode="train"):
counter[f"{tp}_{feat}"] += 1
else:
counter[f"{tp}_{feat}"] = 1
if self.w is None:
rep_file = glob.glob(os.path.join(self.trainDir, sub, tp, feat, "*.png"))[0]
self.w = Image.open(rep_file).size
for f in feat_set:
self.modalityComboBox.addItem(f)
for tp in timepoints:
self.sessionComboBox.addItem(tp)
self.counter = counter

def processInputText(self, text):
return text.split(":")[1].strip().split(",")
text_lst = text.split(":")[1].strip().split(",")
text_lst = [t.strip() for t in text_lst]
return text_lst

def populateTrainDirectory(self):
self.populateInputDirectory("train")
Expand All @@ -327,18 +335,7 @@ def populateDataDirectory(self):
DEFAULT_FILE_PATHS["FEATURE_DIRS"] = self.processInputText(self.modalityComboBox.currentText)
DEFAULT_FILE_PATHS["TIME_POINTS"] = self.processInputText(self.sessionComboBox.currentText)
print(DEFAULT_FILE_PATHS)
# self.ui.msg = qt.QMessageBox()
# self.ui.msg.setIcon(qt.QMessageBox.Information)
# total_samples = self.counter[f"{self.ui.TimePointCombo.currentText}_{self.ui.FeatureNameCombo.currentText}"]
# self.ui.msg.setText(f"{total_samples} training samples available for session: "
# f"{self.ui.TimePointCombo.currentText}, modality: {self.ui.FeatureNameCombo.currentText}")
# self.ui.msg.setWindowTitle("Number of data samples")
# self.ui.msg.setInformativeText("Is it enough for training?")
# self.ui.msg.setStandardButtons(qt.QMessageBox.Yes | qt.QMessageBox.No)
# self.ui.msg.setDefaultButton(qt.QMessageBox.Yes)
# ret = self.ui.msg.exec()
return DEFAULT_FILE_PATHS, True
# return DEFAULT_FILE_PATHS, ret == qt.QMessageBox.Yes

def checkOutputDirectory(self):
"""
Expand Down Expand Up @@ -372,9 +369,10 @@ def onApplyButton(self):
self.ui.StartTrain.enabled = False
self.InputDirPushButton.enabled = False
self.ui.trainingProgressBar.setValue(0)
feat_dim = len(DEFAULT_FILE_PATHS["FEATURE_DIRS"]) * len(DEFAULT_FILE_PATHS["TIME_POINTS"]) * 2
self._asynchrony = Asynchrony(
lambda: self.logic.process(
in_channels=2,
in_channels=feat_dim,
num_classes=2,
model_name=self.model,
batch_size=int(self.ui.batchSizeSpinBox.value),
Expand All @@ -389,7 +387,8 @@ def onApplyButton(self):
monitor=self.ui.monitorLineEdit.text,
pos_weight=float(self.ui.posWLineEdit.text),
file_paths=file_paths,
ui=self.ui
ui=self.ui,
w=self.w
)
)
self._asynchrony.Start()
Expand Down Expand Up @@ -510,7 +509,8 @@ def process(
monitor="validation/valid_loss",
pos_weight=1.0,
file_paths=None,
ui=None
ui=None,
w=None
):
"""
Run the processing algorithm.
Expand All @@ -535,6 +535,7 @@ def process(
"validation/precision", "train/recall", "validation/recall"
:param file_paths: A dictionary containing relevant paths to training data. (Default: FILE_PATH object in CONSTANTS.py)
:param ui: The UI object
:param w: image size
"""
args = {
"batch_size": batch_size,
Expand All @@ -553,7 +554,8 @@ def process(
"monitor": monitor,
"pos_weight": pos_weight,
"qtProgressBarObject": ui.trainingProgressBar,
"file_paths": file_paths
"file_paths": file_paths,
"w": w
}
import time
from DeepLearnerLib.training.EfficientNetTrainer import cli_main
Expand Down
3 changes: 2 additions & 1 deletion DeepLearner/DeepLearnerLib/CONSTANTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"FEATURE_DIRS": ["eacsf"],
"FILE_SUFFIX": ["_flat", "_flat"],
"TIME_POINTS": ["V06"],
"FILE_EXT": ".png"
"FILE_EXT": ".png",
"IMAGE_SIZE": 128
}

# HYPERPARAMERS = {
Expand Down
4 changes: 2 additions & 2 deletions DeepLearner/DeepLearnerLib/data_utils/CustomDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ def __len__(self):
return len(self.image_files)

def __getitem__(self, idx):
return np.concatenate(self.transforms(self.image_files[idx]), axis=0), \
self.labels[idx]
input = np.concatenate(self.transforms(self.image_files[idx]), axis=0)
return input, self.labels[idx]
2 changes: 2 additions & 0 deletions DeepLearner/DeepLearnerLib/data_utils/GeomCnnDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from DeepLearnerLib.data_utils.utils import get_image_files_single_scalar
from DeepLearnerLib.data_utils.CustomDataset import GeomCnnDataset
from sklearn.model_selection import train_test_split
from PIL import Image


class GeomCnnDataModule(pl.LightningDataModule):
Expand All @@ -32,6 +33,7 @@ def __init__(self,
data_tuple=None,
file_paths=None):
super(GeomCnnDataModule, self).__init__()
self.w = None
self.batch_size = batch_size
self.num_workers = num_workers
self.val_frac = val_frac
Expand Down
42 changes: 21 additions & 21 deletions DeepLearner/DeepLearnerLib/data_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@ def get_image_files_single_scalar(data_dir="TRAIN_DATA_DIR", FILE_PATHS=None):
if FILE_PATHS is None:
FILE_PATHS = DEFAULT_FILE_PATHS
subject_ids = sorted(os.listdir(FILE_PATHS[data_dir]))
scalars = FILE_PATHS["FEATURE_DIRS"][0]
time_points = FILE_PATHS["TIME_POINTS"][0]
scalars = FILE_PATHS["FEATURE_DIRS"]
time_points = FILE_PATHS["TIME_POINTS"]
attr = get_attributes(FILE_PATHS)
count = {"HR-neg": 0, "HR-ASD": 1}
print(attr.head())
count = {0: 0, 1: 1}
for sub in subject_ids:
feat_tuple = []
sub_path = os.path.join(FILE_PATHS[data_dir], sub, time_points)
if not os.path.isdir(sub_path) or not os.path.isdir(os.path.join(sub_path, scalars)):
continue
n_feat = [os.path.join(sub_path, f) for f in os.listdir(sub_path)
if os.path.isdir(os.path.join(sub_path, f))]
if len(n_feat) == 0:
if not os.path.isdir(os.path.join(FILE_PATHS[data_dir], sub)):
continue
feat_tuple = []
sub_paths = [os.path.join(FILE_PATHS[data_dir], sub, t) for t in time_points]
sub_attr = attr.loc[attr["CandID"] == int(sub)]
if sub_attr.size == 0:
continue
group = sub_attr["group"].values[0]
if "LR" in group:
continue
elif group == "HR-neg":
labels.append(0)
else:
labels.append(1)
group = int(sub_attr["group"].values[0])
labels.append(group)
count[group] += 1
feat_tuple.append(os.path.join(sub_path, scalars, "left_" + scalars +
FILE_PATHS["FILE_SUFFIX"][0]) + FILE_PATHS["FILE_EXT"])
feat_tuple.append(os.path.join(sub_path, scalars, "right_" + scalars +
FILE_PATHS["FILE_SUFFIX"][0]) + FILE_PATHS["FILE_EXT"])
for sub_path in sub_paths:
if not os.path.isdir(sub_path):
continue
n_feat = [os.path.join(sub_path, f) for f in scalars
if os.path.isdir(os.path.join(sub_path, f))]
if len(n_feat) == 0:
continue
for s in scalars:
feat_tuple.append(os.path.join(sub_path, s, "left_" + s +
FILE_PATHS["FILE_SUFFIX"][0]) + FILE_PATHS["FILE_EXT"])
feat_tuple.append(os.path.join(sub_path, s, "right_" + s +
FILE_PATHS["FILE_SUFFIX"][1]) + FILE_PATHS["FILE_EXT"])
file_names.append(feat_tuple)
print(count)
return file_names, labels
Expand Down
4 changes: 2 additions & 2 deletions DeepLearner/DeepLearnerLib/models/cnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class SimpleCNN(nn.Module):
def __init__(self, in_channels=2, dropout=0.05, n_classes=2):
def __init__(self, in_channels=2, dropout=0.05, n_classes=2, w=512):
super(SimpleCNN, self).__init__()
self.conv1 = Convolution(
spatial_dims=2,
Expand All @@ -30,7 +30,7 @@ def __init__(self, in_channels=2, dropout=0.05, n_classes=2):
)
self.mxpool = nn.MaxPool2d(4)
self.out_head = nn.Sequential(
nn.Linear(1024, 64),
nn.Linear(w//64 * w//64 * 16, 64),
nn.PReLU(),
nn.Linear(64, 64),
nn.PReLU(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch.nn as nn
import torch
import torchmetrics
from DeepLearnerLib.Asynchrony import Asynchrony


class ImageClassifier(pl.LightningModule):
Expand Down
64 changes: 35 additions & 29 deletions DeepLearner/DeepLearnerLib/training/EfficientNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@
except ImportError:
slicer.util.pip_install('scikit-learn==0.24.2')

try:
import PIL
except ImportError:
slicer.util.pip_install("Pillow==8.3.1")

import torch.nn
from monai.networks.nets import EfficientNetBN, DenseNet121, DenseNet, SEResNet50

from DeepLearnerLib.models.cnn_model import SimpleCNN
from DeepLearnerLib.pl_modules.classifier_modules import ImageClassifier
from DeepLearnerLib.CONSTANTS import DEFAULT_FILE_PATHS
from DeepLearnerLib.data_utils.GeomCnnDataset import GeomCnnDataModule, GeomCnnDataModuleKFold
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
Expand Down Expand Up @@ -75,11 +80,29 @@ def on_train_epoch_end(self, trainer, pl_module, **kwargs):


def cli_main(args):
# pl.seed_everything(1234)
# -----------
# Data
# -----------
if args["n_folds"] == 1:
data_modules = [
GeomCnnDataModule(
batch_size=args["batch_size"],
num_workers=args["data_workers"],
file_paths=args["file_paths"]
)
]
else:
data_module_generator = GeomCnnDataModuleKFold(
batch_size=args["batch_size"],
num_workers=args["data_workers"],
n_splits=args["n_folds"],
file_paths=args["file_paths"]
)
data_modules = data_module_generator.get_folds()

# ------------
# model
# ------------
# ------------
# model
# ------------
if args["model"] == "eff_bn":
backbone = EfficientNetBN(
model_name="efficientnet-b0",
Expand All @@ -101,33 +124,16 @@ def cli_main(args):
pretrained=True
)
else:
backbone = SimpleCNN()
backbone = SimpleCNN(
in_channels=args["in_channels"],
w=args["w"][0]
)
device = "cuda:0" if torch.cuda.is_available() and args["use_gpu"] else "cpu"
model = ImageClassifier(backbone, learning_rate=args["learning_rate"],
criterion=torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([1.0, args["pos_weight"]])),
device=device,
metrics=["acc", "precision", "recall"])

# -----------
# Data
# -----------
if args["n_folds"] == 1:
data_modules = [
GeomCnnDataModule(
batch_size=args["batch_size"],
num_workers=args["data_workers"],
file_paths=args["file_paths"]
)
]

else:
data_module_generator = GeomCnnDataModuleKFold(
batch_size=args["batch_size"],
num_workers=args["data_workers"],
n_splits=args["n_folds"],
file_paths=args["file_paths"]
)
data_modules = data_module_generator.get_folds()
criterion=torch.nn.CrossEntropyLoss(
weight=torch.FloatTensor([1.0, args["pos_weight"]])),
device=device,
metrics=["acc", "precision", "recall"])

for i in range(args["n_folds"]):
# logger
Expand Down
45 changes: 45 additions & 0 deletions Inference/CheckableComboBox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from qt import QApplication, QComboBox, QMainWindow, QWidget, QVBoxLayout, QStandardItemModel, Qt


# creating checkable combo box class
class CheckableComboBox(QComboBox):
def __init__(self):
super(CheckableComboBox, self).__init__()
self.view().pressed.connect(self.handle_item_pressed)
self.setModel(QStandardItemModel(self))

# when any item get pressed
def handle_item_pressed(self, index):
# getting which item is pressed
item = self.model().itemFromIndex(index)
# make it check if unchecked and vice-versa
if item.checkState() == Qt.Checked:
item.setCheckState(Qt.Unchecked)
else:
item.setCheckState(Qt.Checked)
self.check_items()

# method called by check_items
def item_checked(self, index):
item = self.model().item(index, 0)
return item.checkState() == Qt.Checked

# calling method
def check_items(self):
checkedItems = []
for i in range(self.count):
if self.item_checked(i):
checkedItems.append(i)
self.update_labels(checkedItems)

# method to update the label
def update_labels(self, item_list):
item_text = [self.model().item(i, 0).text().split("-")[0].strip()
for i in range(self.count)]
if len(item_list) > 0:
n = ", ".join([item_text[i] for i in item_list])
item_text_new = [txt + ' - selected items: ' + n for txt in item_text]
else:
item_text_new = item_text
for i in range(self.count):
self.setItemText(i, item_text_new[i])
Loading