Skip to content

Commit

Permalink
Add SmileGAN Plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijian-yang committed Jul 26, 2022
1 parent 9dac338 commit a25a75f
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 0 deletions.
215 changes: 215 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from PyQt5.QtGui import *
from matplotlib.backends.backend_qt5 import FigureCanvasQT
from PyQt5 import QtGui, QtCore, QtWidgets, uic
import joblib
import sys, os, time
import seaborn as sns
import numpy as np
import pandas as pd
from NiBAx.core.plotcanvas import PlotCanvas
from NiBAx.core.baseplugin import BasePlugin
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
from SmileGAN.Smile_GAN_clustering import clustering_result

class computeSmileGANs(QtWidgets.QWidget,BasePlugin):

#constructor
def __init__(self):
super(computeSmileGANs,self).__init__()
self.model = []
root = os.path.dirname(__file__)
self.readAdditionalInformation(root)
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self)
self.ui.comboBoxHue = SearchableQComboBox(self.ui)
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue)
self.plotCanvas = PlotCanvas(self.ui.page_2)
self.ui.verticalLayout.addWidget(self.plotCanvas)
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111)
self.SmileGANpatterns = None
self.ui.stackedWidget.setCurrentIndex(0)

# Initialize thread
self.thread = QtCore.QThread()


def getUI(self):
return self.ui


def SetupConnections(self):
#pass
self.ui.load_SmileGAN_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
self.ui.compute_SmileGAN_Btn.clicked.connect(lambda check: self.OnComputePatterns(check))
self.ui.show_SmileGAN_prob_from_data_Btn.clicked.connect(lambda: self.OnShowPatterns())
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotPattern)

self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white")
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
# are present in data frame
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
self.ui.show_SmileGAN_prob_from_data_Btn.setToolTip('The data frame has variables `SmileGAN patterns` so these can be plotted.')
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)

# Allow loading of SmileGAN-* model always, even when residuals are not
# calculated yet
self.ui.load_SmileGAN_model_Btn.setEnabled(True)


def OnLoadSmileGANModel(self):
fileNames, _ = QtWidgets.QFileDialog.getOpenFileNames(None,
'Open SmileGAN model file',
QtCore.QDir().homePath(),
"")
if len(fileNames) > 0:
self.model = fileNames
self.ui.compute_SmileGAN_Btn.setEnabled(True)
model_info = 'File:'
for file in fileNames:
model_info += file + '\n'
self.ui.SmileGAN_model_info.setText(model_info)
else:
return

self.ui.stackedWidget.setCurrentIndex(0)

if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames():
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
"{"
"background-color : rgb(230,255,230);"
"}"
"QPushButton::checked"
"{"
"background-color : rgb(255,230,230);"
"border: none;"
"}"
)
self.ui.compute_SmileGAN_Btn.setEnabled(True)
self.ui.compute_SmileGAN_Btn.setChecked(False)
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.')
else:
self.ui.compute_SmileGAN_Btn.setStyleSheet("background-color: rgb(255,230,230)")
self.ui.compute_SmileGAN_Btn.setEnabled(False)
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.')


print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' +
'Make sure to compute and add harmonized residuals first.')

def OnComputationDone(self, p):
self.SmileGANpatterns = p
self.ui.compute_SmileGAN_Btn.setText('Compute SmileGAN Patterns-*')
if self.SmileGANpatterns.empty:
return
self.ui.compute_SmileGAN_Btn.setChecked(False)
self.ui.stackedWidget.setCurrentIndex(1)
self.ui.comboBoxHue.setVisible(False)
self.plotPattern()

# Activate buttons
self.ui.compute_SmileGAN_Btn.setEnabled(False)
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
self.ui.load_SmileGAN_model_Btn.setEnabled(True)



def OnComputePatterns(self, checked):
# Setup tasks for long running jobs
# Using this example: https://realpython.com/python-pyqt-qthread/
# Disable buttons
if checked is not True:
self.thread.requestInterruption()
else:
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
"{"
"background-color : rgb(230,255,230);"
"}"
"QPushButton::checked"
"{"
"background-color : rgb(255,230,230);"
"}"
)
self.ui.compute_SmileGAN_Btn.setText('Cancel computation')
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
self.ui.load_SmileGAN_model_Btn.setEnabled(False)
self.thread = QtCore.QThread()
self.worker = PatternWorker(self.datamodel.data, self.model)
self.worker.moveToThread(self.thread)
self.thread.started.connect(self.worker.run)
self.worker.done.connect(self.thread.quit)
self.worker.done.connect(self.worker.deleteLater)
self.thread.finished.connect(self.thread.deleteLater)
self.worker.done.connect(lambda p: self.OnComputationDone(p))
self.thread.start()


def plotPattern(self):
# Plot data
if self.ui.stackedWidget.currentIndex() == 0:
return
self.plotCanvas.axes.clear()

sns.countplot(x='Pattern', data=self.SmileGANpatterns,
ax=self.plotCanvas.axes)

sns.despine(ax=self.plotCanvas.axes, trim=True)
self.plotCanvas.axes.set(ylabel='Count', xlabel='Patterns')
self.plotCanvas.axes.get_figure().set_tight_layout(True)
self.plotCanvas.draw()


def OnAddToDataFrame(self):
print('Adding SmileGAN patterns to data frame...')
for col in self.SmileGANpatterns.columns:
self.datamodel.data.loc[:,'SmileGAN_'+col] = self.SmileGANpatterns[col]
self.datamodel.data_changed.emit()
self.OnShowPatterns()


def OnShowPatterns(self):
self.ui.stackedWidget.setCurrentIndex(1)
self.plotPattern()


def OnDataChanged(self):
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
# are present in data frame
self.ui.stackedWidget.setCurrentIndex(0)
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
else:
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)


class PatternWorker(QtCore.QObject):

done = QtCore.pyqtSignal(pd.DataFrame)
progress = QtCore.pyqtSignal(str, int)

#constructor
def __init__(self, data, model_list):
super(PatternWorker, self).__init__()
self.data = data
self.model = model_list

def run(self):
train_data = self.data[['participant_id']+[ name for name in self.data.columns if ('H_MUSE_Volume' in name and int(name[14:])<300)] ]
covariate = self.data[['participant_id','Age','Sex']]
covariate['Sex'] = covariate['Sex'].map({'M':1,'F':0})
train_data['diagnosis'] = 1
covariate['diagnosis'] = 1
cluster_label, cluster_prob, _, _ = clustering_result(self.model, 'highest_matching_clustering', train_data, covariate)
p = pd.DataFrame(data = cluster_prob, columns = ['P'+str(_) for _ in range(1,cluster_prob.shape[1]+1)])
p['Pattern'] = cluster_label

# Emit the result
self.done.emit(p)
137 changes: 137 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.ui
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>Form</class>
<widget class="QWidget" name="Form">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>891</width>
<height>695</height>
</rect>
</property>
<property name="windowTitle">
<string>Compute SPAREs</string>
</property>
<layout class="QGridLayout" name="gridLayout">
<item row="0" column="0">
<widget class="QStackedWidget" name="stackedWidget">
<property name="styleSheet">
<string notr="true"/>
</property>
<property name="currentIndex">
<number>0</number>
</property>
<widget class="QWidget" name="page">
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<spacer name="verticalSpacer_2">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>338</width>
<height>241</height>
</size>
</property>
</spacer>
</item>
<item>
<widget class="QPushButton" name="show_SmileGAN_prob_from_data_Btn">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Show SmileGAN patterns from data</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="load_SmileGAN_model_Btn">
<property name="text">
<string>Load SmileGAN Model</string>
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="SmileGAN_model_info">
<property name="text">
<string>No SmileGAN model loaded</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="compute_SmileGAN_Btn">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Compute SmileGAN patterns</string>
</property>
<property name="checkable">
<bool>true</bool>
</property>
<property name="checked">
<bool>false</bool>
</property>
</widget>
</item>
<item>
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>338</width>
<height>241</height>
</size>
</property>
</spacer>
</item>
<item>
<widget class="Line" name="line">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
</layout>
</widget>
<widget class="QWidget" name="page_2">
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<layout class="QHBoxLayout" name="horizontalLayout_3">
<item>
<widget class="QLabel" name="label">
<property name="text">
<string>SmileGAN</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="load_other_model_Btn">
<property name="text">
<string>Load other model</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="add_to_dataframe_Btn">
<property name="text">
<string>Add to DataFrame</string>
</property>
</widget>
</item>
</layout>
</item>
</layout>
</widget>
</widget>
</item>
</layout>
</widget>
<resources/>
<connections/>
</ui>
13 changes: 13 additions & 0 deletions NiBAx/plugins/SmileGAN/SmileGAN.yapsy-plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[Core]
Name = SmileGAN
Module = SmileGAN

[Documentation]
Author = Zhijian Yang
Version = 0.1
Website =
Description = Compute SmileGAN patterns from existing model.

[Tab]
#tab position starts from 0
Position = 5
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ six==1.16.0
statsmodels==0.13.0
wheel>=0.37.1
Yapsy==1.12.2
SmileGAN==0.1.2

0 comments on commit a25a75f

Please sign in to comment.