Skip to content

Commit

Permalink
Head plots added
Browse files Browse the repository at this point in the history
  • Loading branch information
esantamariavazquez committed Feb 15, 2023
1 parent 46cb5e4 commit ecc66a0
Show file tree
Hide file tree
Showing 6 changed files with 615 additions and 110 deletions.
9 changes: 7 additions & 2 deletions medusa/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def __init__(self, key, info, value_type=None, value=None):
Tree item key
info: str
Information about this item
value_type: str ['string'|'number'|'boolean'|'dict'|'list'], optional
value_type: str ['string'|'number'|'integer'|'boolean'|'dict'|'list'], optional
Type of the data stored in attribute value. Leave to None if the
item is going to be a tree.
value: str, int, float, bool, dict or list, optional
Expand Down Expand Up @@ -314,6 +314,11 @@ def set_data(self, value_type, value):
assert isinstance(value, int) or isinstance(value, float), \
'Parameter value must be of types %s or %s' % \
(int, float)
elif t == 'integer':
if value is not None:
assert isinstance(value, int), \
'Parameter value must be of types %s or %s' % \
(int, float)
elif t == 'boolean':
if value is not None:
assert isinstance(value, bool), \
Expand Down Expand Up @@ -346,7 +351,7 @@ def set_data(self, value_type, value):
def add_item(self, item):
"""Adds tree item to the tree. Use this function to build a custom tree.
Take into account that if this function is used, attributes value and
type will be set to None and 'tree', respectively.
type will be set to None.
Parameters
----------
Expand Down
223 changes: 120 additions & 103 deletions medusa/plots/brain_plots.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
"""
Created on Mon Mar 21 15:34:18 2022
@author: Diego Marcos-Martínez
"""
# Python modules
import argparse
import os
Expand All @@ -20,7 +15,8 @@
from medusa.meeg import EEGChannelSet


class TridimentionalBrain():
class TridimentionalBrain:

def __init__(self, bg_color='black', text_color='white', translucent=None,
subplots=None, models=None, names=None):

Expand All @@ -47,6 +43,7 @@ def __init__(self, bg_color='black', text_color='white', translucent=None,
self.lines_cmap = None
self.n_subplots = None

# Check subplots
self.__check_subplot()

# Set canvas and view
Expand Down Expand Up @@ -83,7 +80,7 @@ def __set_alpha(self):
try:
if self.translucent is None:
self.translucent = np.ones(self.n_subplots, dtype=bool)
elif isinstance(self.translucent,bool):
elif isinstance(self.translucent, bool):
self.translucent = [self.translucent]
else:
assert len(self.translucent) == self.n_subplots
Expand Down Expand Up @@ -176,18 +173,18 @@ def __set_brain_visual(self):
color=(1, 1, 1, alpha))
mesh.shading_filter.shininess = 1e+2
self.brain_visuals.append(mesh)
self.attach_headlight(mesh, self.views[i], self.canvas)
self.__attach_headlight(mesh, self.views[i], self.canvas)
except Exception as ex:
print(ex)

def add_brains(self):
def __add_brains(self):
try:
for view_idx in range(len(self.views)):
self.views[view_idx].add(self.brain_visuals[view_idx])
except Exception as ex:
print(ex)

def attach_headlight(self, mesh, view, canvas):
def __attach_headlight(self, mesh, view, canvas):
"""This function sets the initial light direction """
light_dir = (1, 0, 1, 1)
mesh.shading_filter.light_dir = light_dir[:3]
Expand All @@ -200,11 +197,11 @@ def on_transform_change(event):
transform = view.camera.transform
mesh.shading_filter.light_dir = transform.map(initial_light_dir)[:3]

def set_markers(self, locs, sub_plot):
def __set_markers(self, locs, sub_plot):
try:
if self.markers is None:
self.__init_markers()
markers = Markers(light_color=self.text_color,size = 100)
markers = Markers(light_color=self.text_color, size=100)
markers.set_data(locs)
self.markers[sub_plot[0]][sub_plot[1]] = markers

Expand All @@ -213,7 +210,7 @@ def set_markers(self, locs, sub_plot):
except Exception as ex:
print(ex)

def set_labels(self, labels, locs, sub_plot):
def __set_labels(self, labels, locs, sub_plot):
try:
if self.labels_text is None:
self.__init_labels()
Expand Down Expand Up @@ -257,19 +254,82 @@ def __set_conn_color(self, sub_plot, threshold, clim):
if clim is None:
clim = []
clim.append(np.round(np.min(self.connections_values[sub_plot[0]]
[sub_plot[1]]),decimals=2))
[sub_plot[1]]), decimals=2))
clim.append(np.round(np.max(self.connections_values[sub_plot[0]]
[sub_plot[1]]),decimals = 2))
[sub_plot[1]]), decimals=2))

color = self.lines_cmap.map((self.connections_values[sub_plot[0]]
[sub_plot[1]] - clim[0])/(clim[1] - clim[0]))
color = self.lines_cmap.map(
(self.connections_values[sub_plot[0]][sub_plot[1]] - clim[0]) /
(clim[1] - clim[0]))
return color, clim
except Exception as ex:
print(ex)

def __calculate_subplot_idx(self, sub_plot):
try:
_view_idx = sub_plot[1] + self.subplots[1] * sub_plot[0]
return _view_idx
except Exception as ex:
print(ex)

def __initialize_connections(self):
try:
subplots = self.subplots
if self.n_subplots is None:
subplots = (1, 1)
self.connections_coords_invariant = np.empty(
shape=subplots + (0,)).tolist()
self.connections_coords_mutable = np.empty(
shape=subplots + (0,)).tolist()
self.connections_values = np.empty(shape=subplots + (0,)).tolist()
self.lines = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

def __init_markers(self):
try:
subplots = self.subplots
self.markers = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

def __init_labels(self):
try:
subplots = self.subplots
self.labels_text = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

@staticmethod
def __extract_conn_values(adj_mat):
try:
values = np.triu(adj_mat, 1)
values = values[np.where(values != 0)]
return values
except Exception as ex:
print(ex)

@staticmethod
def __set_connections_coords(locs, connections_values):
try:
connections_coords = np.empty((len(connections_values), 2, 3))
value_idx = 0
for i in range(len(locs)):
for j in range(i + 1, len(locs)):
connections_coords[value_idx, 0, 0] = locs[i, 0]
connections_coords[value_idx, 0, 1] = locs[i, 1]
connections_coords[value_idx, 0, 2] = locs[i, 2]
connections_coords[value_idx, 1, 0] = locs[j, 0]
connections_coords[value_idx, 1, 1] = locs[j, 1]
connections_coords[value_idx, 1, 2] = locs[j, 2]
value_idx += 1
return connections_coords
except Exception as ex:
print(ex)

def set_connections(self, adj_mat, locs, sub_plot=None, threshold=0.5,
plot_markers=True, labels=None, plot_labels=False,
cmap='seismic', clim = None, cbar = False):
cmap='seismic', clim=None, cbar=False):
try:
if sub_plot is None and self.n_subplots == 1:
sub_plot = (0, 0)
Expand All @@ -285,34 +345,30 @@ def set_connections(self, adj_mat, locs, sub_plot=None, threshold=0.5,
"is None")
else:
assert all(isinstance(elem, str) for elem in labels)
self.set_labels(labels, locs, sub_plot)
self.__set_labels(labels, locs, sub_plot)

if plot_markers:
self.set_markers(locs, sub_plot)

self.__set_markers(locs, sub_plot)

# Extract connectivity values
self.connections_values[sub_plot[0]][
sub_plot[1]] = self.__extract_conn_values(adj_mat)

# Set connections matrix
self.connections_coords_invariant[sub_plot[0]][
sub_plot[1]] = self.__set_connections_coords(locs,
self.connections_values
[sub_plot[0]][
sub_plot[
1]])
sub_plot[1]] = self.__set_connections_coords(
locs, self.connections_values
[sub_plot[0]][sub_plot[1]])

# Get color map and color connections
self.lines_cmap = get_colormap(cmap)
color, clim = self.__set_conn_color(sub_plot, threshold, clim)


_view_idx = self.__calculate_subplot_idx(sub_plot)
self.lines[sub_plot[0]][sub_plot[1]] = scene.Line(antialias=True,
parent=self.views[
_view_idx].scene,
)
self.lines[sub_plot[0]][sub_plot[1]] = scene.Line(
antialias=True,
parent=self.views[_view_idx].scene,)
self.lines[sub_plot[0]][sub_plot[1]].set_data(
pos=self.connections_coords_mutable[sub_plot[0]][sub_plot[1]],
color=color, width=4)
Expand Down Expand Up @@ -342,89 +398,37 @@ def update_connections(self, adj_mat, sub_plot, threshold):

self.lines[sub_plot[0]][sub_plot[1]].parent = None

self.lines[sub_plot[0]][sub_plot[1]] = scene.Line(antialias=True,
parent=self.views[
_view_idx].scene)
self.lines[sub_plot[0]][sub_plot[1]] = scene.Line(
antialias=True, parent=self.views[_view_idx].scene)
self.lines[sub_plot[0]][sub_plot[1]].set_data(
pos=self.connections_coords_mutable[sub_plot[0]][sub_plot[1]],
color=color, width=4, )
color=color, width=4,)
self.canvas.update()
except Exception as ex:
print(ex)

def __calculate_subplot_idx(self, sub_plot):
try:
_view_idx = sub_plot[1] + self.subplots[1] * sub_plot[0]
return _view_idx
except Exception as ex:
print(ex)

def __initialize_connections(self):
try:
subplots = self.subplots
if self.n_subplots is None:
subplots = (1, 1)
self.connections_coords_invariant = np.empty(
shape=subplots + (0,)).tolist()
self.connections_coords_mutable = np.empty(
shape=subplots + (0,)).tolist()
self.connections_values = np.empty(shape=subplots + (0,)).tolist()
self.lines = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

def __init_markers(self):
try:
subplots = self.subplots
self.markers = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

def __init_labels(self):
try:
subplots = self.subplots
self.labels_text = np.empty(shape=subplots + (0,)).tolist()
except Exception as ex:
print(ex)

@staticmethod
def __extract_conn_values(adj_mat):
try:
values = np.triu(adj_mat, 1)
values = values[np.where(values != 0)]
return values
except Exception as ex:
print(ex)

@staticmethod
def __set_connections_coords(locs, connections_values):
try:
connections_coords = np.empty((len(connections_values), 2, 3))
value_idx = 0
for i in range(len(locs)):
for j in range(i + 1, len(locs)):
connections_coords[value_idx, 0, 0] = locs[i, 0]
connections_coords[value_idx, 0, 1] = locs[i, 1]
connections_coords[value_idx, 0, 2] = locs[i, 2]
connections_coords[value_idx, 1, 0] = locs[j, 0]
connections_coords[value_idx, 1, 1] = locs[j, 1]
connections_coords[value_idx, 1, 2] = locs[j, 2]
value_idx += 1
return connections_coords
except Exception as ex:
print(ex)

def set_activation_map(self, adj_mat, locs, sub_plot=None, threshold=0.5,
plot_markers=True, labels=None, plot_labels=False,
cmap='seismic', clim=None, cbar=False):
"""Use this function to plot an activation map over the brain's
surface
"""

pass


if __name__ == '__main__':

from vispy.app import use_app

app = use_app("pyqt5")
app.create()
# Set canvas
triplot = TridimentionalBrain( bg_color='white',
text_color='black',translucent=[True,True],subplots=(1,2), names=['Prueba1','Prueba2'])
triplot = TridimentionalBrain(bg_color='white',
text_color='black',
translucent=[True, True],
subplots=(1, 2),
names=['Prueba1', 'Prueba2'])

# Define channel set and its coord
channel_set = EEGChannelSet(dim='3D', coord_system='cartesian')
Expand All @@ -437,10 +441,23 @@ def __set_connections_coords(locs, connections_values):
adj_mat = np.random.randn(len(channel_set.channels),
len(channel_set.channels))
adj_mat = 2 * adj_mat - 1
triplot.set_connections(adj_mat, channel_coord, threshold=[0.6], plot_labels=True,
labels=channel_set.l_cha, plot_markers=True,cmap='Spectral',sub_plot=(0,0))
triplot.set_connections(adj_mat, channel_coord, threshold=[0.6], plot_labels=False,
labels=channel_set.l_cha, plot_markers=True,cmap='Spectral',sub_plot=(0,1))
triplot.add_brains()
triplot.set_connections(adj_mat,
channel_coord,
threshold=[0.6],
plot_labels=True,
labels=channel_set.l_cha,
plot_markers=True,
cmap='Spectral',
sub_plot=(0, 0))

triplot.set_connections(adj_mat,
channel_coord,
threshold=[0.6],
plot_labels=False,
labels=channel_set.l_cha,
plot_markers=True,
cmap='Spectral',
sub_plot=(0, 1))
triplot.__add_brains()
triplot.canvas.show()
app.run()
Loading

0 comments on commit ecc66a0

Please sign in to comment.