Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SmileGAN Plugin #191

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from PyQt5.QtGui import *
from matplotlib.backends.backend_qt5 import FigureCanvasQT
from PyQt5 import QtGui, QtCore, QtWidgets, uic
import joblib
import sys, os, time
import seaborn as sns
import numpy as np
import pandas as pd
from NiBAx.core.plotcanvas import PlotCanvas
from NiBAx.core.baseplugin import BasePlugin
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
from SmileGAN.Smile_GAN_clustering import clustering_result

class computeSmileGANs(QtWidgets.QWidget,BasePlugin):

#constructor
def __init__(self):
super(computeSmileGANs,self).__init__()
self.model = []
root = os.path.dirname(__file__)
self.readAdditionalInformation(root)
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self)
self.ui.comboBoxHue = SearchableQComboBox(self.ui)
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue)
self.plotCanvas = PlotCanvas(self.ui.page_2)
self.ui.verticalLayout.addWidget(self.plotCanvas)
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111)
self.SmileGANpatterns = None
self.ui.stackedWidget.setCurrentIndex(0)

# Initialize thread
self.thread = QtCore.QThread()


def getUI(self):
return self.ui


def SetupConnections(self):
#pass
self.ui.load_SmileGAN_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
self.ui.compute_SmileGAN_Btn.clicked.connect(lambda check: self.OnComputePatterns(check))
self.ui.show_SmileGAN_prob_from_data_Btn.clicked.connect(lambda: self.OnShowPatterns())
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotPattern)

self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white")
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
# are present in data frame
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
self.ui.show_SmileGAN_prob_from_data_Btn.setToolTip('The data frame has variables `SmileGAN patterns` so these can be plotted.')
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)

# Allow loading of SmileGAN-* model always, even when residuals are not
# calculated yet
self.ui.load_SmileGAN_model_Btn.setEnabled(True)


def OnLoadSmileGANModel(self):
fileNames, _ = QtWidgets.QFileDialog.getOpenFileNames(None,
'Open SmileGAN model file',
QtCore.QDir().homePath(),
"")
if len(fileNames) > 0:
self.model = fileNames
self.ui.compute_SmileGAN_Btn.setEnabled(True)
model_info = 'File:'
for file in fileNames:
model_info += file + '\n'
self.ui.SmileGAN_model_info.setText(model_info)
else:
return

self.ui.stackedWidget.setCurrentIndex(0)

if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames():
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
"{"
"background-color : rgb(230,255,230);"
"}"
"QPushButton::checked"
"{"
"background-color : rgb(255,230,230);"
"border: none;"
"}"
)
self.ui.compute_SmileGAN_Btn.setEnabled(True)
self.ui.compute_SmileGAN_Btn.setChecked(False)
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.')
else:
self.ui.compute_SmileGAN_Btn.setStyleSheet("background-color: rgb(255,230,230)")
self.ui.compute_SmileGAN_Btn.setEnabled(False)
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.')


print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' +
'Make sure to compute and add harmonized residuals first.')

def OnComputationDone(self, p):
self.SmileGANpatterns = p
self.ui.compute_SmileGAN_Btn.setText('Compute SmileGAN Patterns-*')
if self.SmileGANpatterns.empty:
return
self.ui.compute_SmileGAN_Btn.setChecked(False)
self.ui.stackedWidget.setCurrentIndex(1)
self.ui.comboBoxHue.setVisible(False)
self.plotPattern()

# Activate buttons
self.ui.compute_SmileGAN_Btn.setEnabled(False)
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
self.ui.load_SmileGAN_model_Btn.setEnabled(True)



def OnComputePatterns(self, checked):
# Setup tasks for long running jobs
# Using this example: https://realpython.com/python-pyqt-qthread/
# Disable buttons
if checked is not True:
self.thread.requestInterruption()
else:
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
"{"
"background-color : rgb(230,255,230);"
"}"
"QPushButton::checked"
"{"
"background-color : rgb(255,230,230);"
"}"
)
self.ui.compute_SmileGAN_Btn.setText('Cancel computation')
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
self.ui.load_SmileGAN_model_Btn.setEnabled(False)
self.thread = QtCore.QThread()
self.worker = PatternWorker(self.datamodel.data, self.model)
self.worker.moveToThread(self.thread)
self.thread.started.connect(self.worker.run)
self.worker.done.connect(self.thread.quit)
self.worker.done.connect(self.worker.deleteLater)
self.thread.finished.connect(self.thread.deleteLater)
self.worker.done.connect(lambda p: self.OnComputationDone(p))
self.thread.start()


def plotPattern(self):
# Plot data
if self.ui.stackedWidget.currentIndex() == 0:
return
self.plotCanvas.axes.clear()

sns.countplot(x='Pattern', data=self.SmileGANpatterns,
ax=self.plotCanvas.axes)

sns.despine(ax=self.plotCanvas.axes, trim=True)
self.plotCanvas.axes.set(ylabel='Count', xlabel='Patterns')
self.plotCanvas.axes.get_figure().set_tight_layout(True)
self.plotCanvas.draw()


def OnAddToDataFrame(self):
print('Adding SmileGAN patterns to data frame...')
for col in self.SmileGANpatterns.columns:
self.datamodel.data.loc[:,'SmileGAN_'+col] = self.SmileGANpatterns[col]
self.datamodel.data_changed.emit()
self.OnShowPatterns()


def OnShowPatterns(self):
self.ui.stackedWidget.setCurrentIndex(1)
self.plotPattern()


def OnDataChanged(self):
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
# are present in data frame
self.ui.stackedWidget.setCurrentIndex(0)
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)


class PatternWorker(QtCore.QObject):

done = QtCore.pyqtSignal(pd.DataFrame)
progress = QtCore.pyqtSignal(str, int)

#constructor
def __init__(self, data, model_list):
super(PatternWorker, self).__init__()
self.data = data
self.model = model_list

def run(self):
train_data = self.data[['participant_id']+[ name for name in self.data.columns if ('H_MUSE_Volume' in name and int(name[14:])<300)] ]
covariate = self.data[['participant_id','Age','Sex']]
covariate['Sex'] = covariate['Sex'].map({'M':1,'F':0})
train_data['diagnosis'] = 1
covariate['diagnosis'] = 1
cluster_label, cluster_prob, _, _ = clustering_result(self.model, 'highest_matching_clustering', train_data, covariate)
p = pd.DataFrame(data = cluster_prob, columns = ['P'+str(_) for _ in range(1,cluster_prob.shape[1]+1)])
p['Pattern'] = cluster_label

# Emit the result
self.done.emit(p)
137 changes: 137 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.ui
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>Form</class>
<widget class="QWidget" name="Form">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>891</width>
<height>695</height>
</rect>
</property>
<property name="windowTitle">
<string>Compute SPAREs</string>
</property>
<layout class="QGridLayout" name="gridLayout">
<item row="0" column="0">
<widget class="QStackedWidget" name="stackedWidget">
<property name="styleSheet">
<string notr="true"/>
</property>
<property name="currentIndex">
<number>0</number>
</property>
<widget class="QWidget" name="page">
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<spacer name="verticalSpacer_2">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>338</width>
<height>241</height>
</size>
</property>
</spacer>
</item>
<item>
<widget class="QPushButton" name="show_SmileGAN_prob_from_data_Btn">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Show SmileGAN patterns from data</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="load_SmileGAN_model_Btn">
<property name="text">
<string>Load SmileGAN Model</string>
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="SmileGAN_model_info">
<property name="text">
<string>No SmileGAN model loaded</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="compute_SmileGAN_Btn">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Compute SmileGAN patterns</string>
</property>
<property name="checkable">
<bool>true</bool>
</property>
<property name="checked">
<bool>false</bool>
</property>
</widget>
</item>
<item>
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>338</width>
<height>241</height>
</size>
</property>
</spacer>
</item>
<item>
<widget class="Line" name="line">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
</layout>
</widget>
<widget class="QWidget" name="page_2">
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<layout class="QHBoxLayout" name="horizontalLayout_3">
<item>
<widget class="QLabel" name="label">
<property name="text">
<string>SmileGAN</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="load_other_model_Btn">
<property name="text">
<string>Load other model</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="add_to_dataframe_Btn">
<property name="text">
<string>Add to DataFrame</string>
</property>
</widget>
</item>
</layout>
</item>
</layout>
</widget>
</widget>
</item>
</layout>
</widget>
<resources/>
<connections/>
</ui>
13 changes: 13 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.yapsy-plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[Core]
Name = SmileGAN
Module = SmileGAN

[Documentation]
Author = Zhijian Yang
Version = 0.1
Website =
Description = Compute SmileGAN patterns from existing model.

[Tab]
#tab position starts from 0
Position = 5
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ six==1.16.0
statsmodels==0.13.0
wheel>=0.37.1
Yapsy==1.12.2
SmileGAN==0.1.2