-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
zhijian-yang
committed
Jul 26, 2022
1 parent
9dac338
commit a25a75f
Showing
5 changed files
with
366 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
from PyQt5.QtGui import * | ||
from matplotlib.backends.backend_qt5 import FigureCanvasQT | ||
from PyQt5 import QtGui, QtCore, QtWidgets, uic | ||
import joblib | ||
import sys, os, time | ||
import seaborn as sns | ||
import numpy as np | ||
import pandas as pd | ||
from NiBAx.core.plotcanvas import PlotCanvas | ||
from NiBAx.core.baseplugin import BasePlugin | ||
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox | ||
from SmileGAN.Smile_GAN_clustering import clustering_result | ||
|
||
class computeSmileGANs(QtWidgets.QWidget,BasePlugin): | ||
|
||
#constructor | ||
def __init__(self): | ||
super(computeSmileGANs,self).__init__() | ||
self.model = [] | ||
root = os.path.dirname(__file__) | ||
self.readAdditionalInformation(root) | ||
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self) | ||
self.ui.comboBoxHue = SearchableQComboBox(self.ui) | ||
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue) | ||
self.plotCanvas = PlotCanvas(self.ui.page_2) | ||
self.ui.verticalLayout.addWidget(self.plotCanvas) | ||
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111) | ||
self.SmileGANpatterns = None | ||
self.ui.stackedWidget.setCurrentIndex(0) | ||
|
||
# Initialize thread | ||
self.thread = QtCore.QThread() | ||
|
||
|
||
def getUI(self): | ||
return self.ui | ||
|
||
|
||
def SetupConnections(self): | ||
#pass | ||
self.ui.load_SmileGAN_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel()) | ||
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel()) | ||
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame()) | ||
self.ui.compute_SmileGAN_Btn.clicked.connect(lambda check: self.OnComputePatterns(check)) | ||
self.ui.show_SmileGAN_prob_from_data_Btn.clicked.connect(lambda: self.OnShowPatterns()) | ||
self.datamodel.data_changed.connect(lambda: self.OnDataChanged()) | ||
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotPattern) | ||
|
||
self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white") | ||
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column | ||
# are present in data frame | ||
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()): | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)") | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True) | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setToolTip('The data frame has variables `SmileGAN patterns` so these can be plotted.') | ||
else: | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False) | ||
|
||
# Allow loading of SmileGAN-* model always, even when residuals are not | ||
# calculated yet | ||
self.ui.load_SmileGAN_model_Btn.setEnabled(True) | ||
|
||
|
||
def OnLoadSmileGANModel(self): | ||
fileNames, _ = QtWidgets.QFileDialog.getOpenFileNames(None, | ||
'Open SmileGAN model file', | ||
QtCore.QDir().homePath(), | ||
"") | ||
if len(fileNames) > 0: | ||
self.model = fileNames | ||
self.ui.compute_SmileGAN_Btn.setEnabled(True) | ||
model_info = 'File:' | ||
for file in fileNames: | ||
model_info += file + '\n' | ||
self.ui.SmileGAN_model_info.setText(model_info) | ||
else: | ||
return | ||
|
||
self.ui.stackedWidget.setCurrentIndex(0) | ||
|
||
if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames(): | ||
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton" | ||
"{" | ||
"background-color : rgb(230,255,230);" | ||
"}" | ||
"QPushButton::checked" | ||
"{" | ||
"background-color : rgb(255,230,230);" | ||
"border: none;" | ||
"}" | ||
) | ||
self.ui.compute_SmileGAN_Btn.setEnabled(True) | ||
self.ui.compute_SmileGAN_Btn.setChecked(False) | ||
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.') | ||
else: | ||
self.ui.compute_SmileGAN_Btn.setStyleSheet("background-color: rgb(255,230,230)") | ||
self.ui.compute_SmileGAN_Btn.setEnabled(False) | ||
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.') | ||
|
||
|
||
print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' + | ||
'Make sure to compute and add harmonized residuals first.') | ||
|
||
def OnComputationDone(self, p): | ||
self.SmileGANpatterns = p | ||
self.ui.compute_SmileGAN_Btn.setText('Compute SmileGAN Patterns-*') | ||
if self.SmileGANpatterns.empty: | ||
return | ||
self.ui.compute_SmileGAN_Btn.setChecked(False) | ||
self.ui.stackedWidget.setCurrentIndex(1) | ||
self.ui.comboBoxHue.setVisible(False) | ||
self.plotPattern() | ||
|
||
# Activate buttons | ||
self.ui.compute_SmileGAN_Btn.setEnabled(False) | ||
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()): | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True) | ||
else: | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False) | ||
self.ui.load_SmileGAN_model_Btn.setEnabled(True) | ||
|
||
|
||
|
||
def OnComputePatterns(self, checked): | ||
# Setup tasks for long running jobs | ||
# Using this example: https://realpython.com/python-pyqt-qthread/ | ||
# Disable buttons | ||
if checked is not True: | ||
self.thread.requestInterruption() | ||
else: | ||
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton" | ||
"{" | ||
"background-color : rgb(230,255,230);" | ||
"}" | ||
"QPushButton::checked" | ||
"{" | ||
"background-color : rgb(255,230,230);" | ||
"}" | ||
) | ||
self.ui.compute_SmileGAN_Btn.setText('Cancel computation') | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False) | ||
self.ui.load_SmileGAN_model_Btn.setEnabled(False) | ||
self.thread = QtCore.QThread() | ||
self.worker = PatternWorker(self.datamodel.data, self.model) | ||
self.worker.moveToThread(self.thread) | ||
self.thread.started.connect(self.worker.run) | ||
self.worker.done.connect(self.thread.quit) | ||
self.worker.done.connect(self.worker.deleteLater) | ||
self.thread.finished.connect(self.thread.deleteLater) | ||
self.worker.done.connect(lambda p: self.OnComputationDone(p)) | ||
self.thread.start() | ||
|
||
|
||
def plotPattern(self): | ||
# Plot data | ||
if self.ui.stackedWidget.currentIndex() == 0: | ||
return | ||
self.plotCanvas.axes.clear() | ||
|
||
sns.countplot(x='Pattern', data=self.SmileGANpatterns, | ||
ax=self.plotCanvas.axes) | ||
|
||
sns.despine(ax=self.plotCanvas.axes, trim=True) | ||
self.plotCanvas.axes.set(ylabel='Count', xlabel='Patterns') | ||
self.plotCanvas.axes.get_figure().set_tight_layout(True) | ||
self.plotCanvas.draw() | ||
|
||
|
||
def OnAddToDataFrame(self): | ||
print('Adding SmileGAN patterns to data frame...') | ||
for col in self.SmileGANpatterns.columns: | ||
self.datamodel.data.loc[:,'SmileGAN_'+col] = self.SmileGANpatterns[col] | ||
self.datamodel.data_changed.emit() | ||
self.OnShowPatterns() | ||
|
||
|
||
def OnShowPatterns(self): | ||
self.ui.stackedWidget.setCurrentIndex(1) | ||
self.plotPattern() | ||
|
||
|
||
def OnDataChanged(self): | ||
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column | ||
# are present in data frame | ||
self.ui.stackedWidget.setCurrentIndex(0) | ||
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()): | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True) | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)") | ||
else: | ||
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False) | ||
|
||
|
||
class PatternWorker(QtCore.QObject): | ||
|
||
done = QtCore.pyqtSignal(pd.DataFrame) | ||
progress = QtCore.pyqtSignal(str, int) | ||
|
||
#constructor | ||
def __init__(self, data, model_list): | ||
super(PatternWorker, self).__init__() | ||
self.data = data | ||
self.model = model_list | ||
|
||
def run(self): | ||
train_data = self.data[['participant_id']+[ name for name in self.data.columns if ('H_MUSE_Volume' in name and int(name[14:])<300)] ] | ||
covariate = self.data[['participant_id','Age','Sex']] | ||
covariate['Sex'] = covariate['Sex'].map({'M':1,'F':0}) | ||
train_data['diagnosis'] = 1 | ||
covariate['diagnosis'] = 1 | ||
cluster_label, cluster_prob, _, _ = clustering_result(self.model, 'highest_matching_clustering', train_data, covariate) | ||
p = pd.DataFrame(data = cluster_prob, columns = ['P'+str(_) for _ in range(1,cluster_prob.shape[1]+1)]) | ||
p['Pattern'] = cluster_label | ||
|
||
# Emit the result | ||
self.done.emit(p) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<ui version="4.0"> | ||
<class>Form</class> | ||
<widget class="QWidget" name="Form"> | ||
<property name="geometry"> | ||
<rect> | ||
<x>0</x> | ||
<y>0</y> | ||
<width>891</width> | ||
<height>695</height> | ||
</rect> | ||
</property> | ||
<property name="windowTitle"> | ||
<string>Compute SPAREs</string> | ||
</property> | ||
<layout class="QGridLayout" name="gridLayout"> | ||
<item row="0" column="0"> | ||
<widget class="QStackedWidget" name="stackedWidget"> | ||
<property name="styleSheet"> | ||
<string notr="true"/> | ||
</property> | ||
<property name="currentIndex"> | ||
<number>0</number> | ||
</property> | ||
<widget class="QWidget" name="page"> | ||
<layout class="QVBoxLayout" name="verticalLayout_2"> | ||
<item> | ||
<spacer name="verticalSpacer_2"> | ||
<property name="orientation"> | ||
<enum>Qt::Vertical</enum> | ||
</property> | ||
<property name="sizeHint" stdset="0"> | ||
<size> | ||
<width>338</width> | ||
<height>241</height> | ||
</size> | ||
</property> | ||
</spacer> | ||
</item> | ||
<item> | ||
<widget class="QPushButton" name="show_SmileGAN_prob_from_data_Btn"> | ||
<property name="enabled"> | ||
<bool>false</bool> | ||
</property> | ||
<property name="text"> | ||
<string>Show SmileGAN patterns from data</string> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<widget class="QPushButton" name="load_SmileGAN_model_Btn"> | ||
<property name="text"> | ||
<string>Load SmileGAN Model</string> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<widget class="QLabel" name="SmileGAN_model_info"> | ||
<property name="text"> | ||
<string>No SmileGAN model loaded</string> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<widget class="QPushButton" name="compute_SmileGAN_Btn"> | ||
<property name="enabled"> | ||
<bool>false</bool> | ||
</property> | ||
<property name="text"> | ||
<string>Compute SmileGAN patterns</string> | ||
</property> | ||
<property name="checkable"> | ||
<bool>true</bool> | ||
</property> | ||
<property name="checked"> | ||
<bool>false</bool> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<spacer name="verticalSpacer"> | ||
<property name="orientation"> | ||
<enum>Qt::Vertical</enum> | ||
</property> | ||
<property name="sizeHint" stdset="0"> | ||
<size> | ||
<width>338</width> | ||
<height>241</height> | ||
</size> | ||
</property> | ||
</spacer> | ||
</item> | ||
<item> | ||
<widget class="Line" name="line"> | ||
<property name="orientation"> | ||
<enum>Qt::Horizontal</enum> | ||
</property> | ||
</widget> | ||
</item> | ||
</layout> | ||
</widget> | ||
<widget class="QWidget" name="page_2"> | ||
<layout class="QVBoxLayout" name="verticalLayout"> | ||
<item> | ||
<layout class="QHBoxLayout" name="horizontalLayout_3"> | ||
<item> | ||
<widget class="QLabel" name="label"> | ||
<property name="text"> | ||
<string>SmileGAN</string> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<widget class="QPushButton" name="load_other_model_Btn"> | ||
<property name="text"> | ||
<string>Load other model</string> | ||
</property> | ||
</widget> | ||
</item> | ||
<item> | ||
<widget class="QPushButton" name="add_to_dataframe_Btn"> | ||
<property name="text"> | ||
<string>Add to DataFrame</string> | ||
</property> | ||
</widget> | ||
</item> | ||
</layout> | ||
</item> | ||
</layout> | ||
</widget> | ||
</widget> | ||
</item> | ||
</layout> | ||
</widget> | ||
<resources/> | ||
<connections/> | ||
</ui> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[Core] | ||
Name = SmileGAN | ||
Module = SmileGAN | ||
|
||
[Documentation] | ||
Author = Zhijian Yang | ||
Version = 0.1 | ||
Website = | ||
Description = Compute SmileGAN patterns from existing model. | ||
|
||
[Tab] | ||
#tab position starts from 0 | ||
Position = 5 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,4 @@ six==1.16.0 | |
statsmodels==0.13.0 | ||
wheel>=0.37.1 | ||
Yapsy==1.12.2 | ||
SmileGAN==0.1.2 |