177 changes: 134 additions & 43 deletions tests/test_connectivity.py
@@ -15,9 +15,11 @@
one_to_all_connectivity_pairs,
read_connectivity,
restrict_forward_to_vertices,
dics_coherence_external
)
from conpy.connectivity import _BaseConnectivity, _get_vert_ind_from_label
from mne import BiHemiLabel, Label, SourceEstimate
from mne.beamformer import make_dics
from mne.datasets import testing
from mne.time_frequency import csd_morlet
from mne.utils import _TempDir
@@ -60,7 +62,7 @@ def _load_restricted_forward(source_vertno1, source_vertno2):
return fwd_free, fwd_fixed


def _simulate_data(fwd_fixed, source_vertno1, source_vertno2):
def _simulate_data(fwd_fixed, source_vertno1, source_vertno2, external=False):
"""Simulate two oscillators on the cortex."""
sfreq = 50.0 # Hz.
base_freq = 10
@@ -86,18 +88,32 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2):
signal2 += 1e-8 * np.random.randn(len(times))

# Construct a SourceEstimate object
if external:
source_data = signal1[np.newaxis, :]
vertices = [np.array([source_vertno1]), np.array([])]
# Create an info object that holds information about the sensors
info = mne.create_info(
fwd_fixed["info"]["ch_names"] + ["external"],
sfreq,
ch_types=["grad"] * fwd_fixed["info"]["nchan"] + ["misc"],
)
else:
source_data = np.vstack((signal1[np.newaxis, :], signal2[np.newaxis, :]))
vertices = [np.array([source_vertno1]), np.array([source_vertno2])]
# Create an info object that holds information about the sensors
info = mne.create_info(fwd_fixed["info"]["ch_names"], sfreq,
ch_types="grad")

stc = mne.SourceEstimate(
np.vstack((signal1[np.newaxis, :], signal2[np.newaxis, :])),
vertices=[np.array([source_vertno1]), np.array([source_vertno2])],
source_data,
vertices=vertices,
tmin=0,
tstep=1 / sfreq,
subject="sample",
)

# Create an info object that holds information about the sensors
info = mne.create_info(fwd_fixed["info"]["ch_names"], sfreq, ch_types="grad")
with info._unlock():
info.update(fwd_fixed["info"]) # Merge in sensor position information
# Merge in sensor position information
for info_ch, fwd_ch in zip(info["chs"], fwd_fixed["info"]["chs"]):
info_ch.update(fwd_ch)

# Simulated sensor data.
raw = mne.apply_forward_raw(fwd_fixed, stc, info)
@@ -106,6 +122,11 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2):
noise = random.randn(*raw._data.shape) * 1e-14
raw._data += noise

if external:
sensor_data = raw.get_data()
sensor_data = np.vstack((sensor_data, np.atleast_2d(signal2)))
raw = mne.io.RawArray(sensor_data, info)

# Define a single epoch
epochs = mne.Epochs(
raw,
@@ -118,7 +139,7 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2):
)

# Compute the cross-spectral density matrix
csd = csd_morlet(epochs, frequencies=[10, 20])
csd = csd_morlet(epochs, picks=['meg', 'misc'], frequencies=[10, 20])

return csd

@@ -173,43 +194,43 @@ def _generate_labels(vertices, n_labels):
def test_base_connectivity():
"""Test construction of BaseConnectivity."""
# Pairs and data shape don't match
baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 4)
assert_array_equal(baseCon.data, [0.5, 0.5, 0.5])
base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 4)
assert_array_equal(base_con.data, [0.5, 0.5, 0.5])

with pytest.raises(ValueError):
baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1], [2, 3]], 3)
base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1], [2, 3]], 3)

# Not enough sources
with pytest.raises(ValueError):
baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 2)
base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 2)

# Incorrectly shaped source degree
with pytest.raises(ValueError):
baseCon = _BaseConnectivity(
base_con = _BaseConnectivity(
[0.5, 0.5, 0.5],
[[1, 1, 2], [2, 3, 3]],
4,
source_degree=([0, 2, 1], [0, 0, 1]),
)

baseCon = _make_base_connectivity()
base_con = _make_base_connectivity()
assert_array_equal(
baseCon.source_degree,
base_con.source_degree,
np.array([[0, 2, 2, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 2, 0, 0, 0, 0, 0]]),
)

# Test properties
assert baseCon.n_connections == 5
assert base_con.n_connections == 5
state = {
"data": np.array([1, 1, 1]),
"pairs": [[2, 2, 3], [3, 4, 4]],
"n_sources": 5,
"subject": None,
"directed": False,
}
baseCon.__setstate__(state)
assert baseCon.n_sources == 5
assert_array_equal(baseCon.source_degree[0], [0, 0, 2, 1, 0])
base_con.__setstate__(state)
assert base_con.n_sources == 5
assert_array_equal(base_con.source_degree[0], [0, 0, 2, 1, 0])


def test_alltoall_connectivity():
@@ -269,8 +290,8 @@ def test_label_connnectivity():

def test_connectivity_repr():
"""Test string representation of connectivity classes."""
baseCon = _make_base_connectivity()
assert str(baseCon) == (
base_con = _make_base_connectivity()
assert str(base_con) == (
"<_BaseConnectivity | n_sources=10, n_conns=5," " subject=None>"
)

@@ -334,58 +355,58 @@ def test_connectivity_save():

def test_adjacency():
"""Test adjacency matrix."""
baseCon = _make_base_connectivity()
adjmat = baseCon.get_adjacency()
assert adjmat.nnz == 2 * baseCon.n_connections
assert adjmat.shape == (baseCon.n_sources, baseCon.n_sources)
base_con = _make_base_connectivity()
adjmat = base_con.get_adjacency()
assert adjmat.nnz == 2 * base_con.n_connections
assert adjmat.shape == (base_con.n_sources, base_con.n_sources)

# Directed
baseCon.directed = True
adjmat = baseCon.get_adjacency()
assert adjmat.nnz == baseCon.n_connections
assert adjmat.shape == (baseCon.n_sources, baseCon.n_sources)
base_con.directed = True
adjmat = base_con.get_adjacency()
assert adjmat.nnz == base_con.n_connections
assert adjmat.shape == (base_con.n_sources, base_con.n_sources)


def test_connectivity_threshold():
"""Test thresholding function of BaseConnectivity."""
# Criterion = None
baseCon = _make_base_connectivity()
threshCon = baseCon.threshold(2, copy=True)
base_con = _make_base_connectivity()
threshCon = base_con.threshold(2, copy=True)
assert threshCon.n_connections == 3
assert baseCon.n_connections == 5
assert base_con.n_connections == 5
assert_array_equal(threshCon.data, np.array([3, 4, 5]))

threshCon = baseCon.copy()
threshCon = base_con.copy()
threshCon.threshold(2, direction="below", copy=False)
assert threshCon.n_connections == 1
assert_array_equal(threshCon.data, np.array([1]))

# Incorrect direction
with pytest.raises(ValueError):
baseCon.threshold(1, direction="wrong")
base_con.threshold(1, direction="wrong")

# Use criterion
with pytest.raises(ValueError):
baseCon.threshold(1, crit=np.array([0, 0, 1, 2]))
base_con.threshold(1, crit=np.array([0, 0, 1, 2]))

pval = np.array([0.04, 0.7, 0.001, 0.1, 0.06])
threshCon = baseCon.threshold(0.05, crit=pval, copy=True)
threshCon = base_con.threshold(0.05, crit=pval, copy=True)
assert_array_equal(threshCon.data, np.array([2, 4, 5]))


def test_compatibility():
"""Test _iscombatible function."""
baseCon = _make_base_connectivity()
base_con = _make_base_connectivity()
all_con = _make_alltoall_connectivity()
label_con = _make_label_connectivity()

# Test BaseConnectivity
assert not baseCon.is_compatible(label_con)
assert not baseCon.is_compatible(all_con)
baseCon2 = _BaseConnectivity(
np.array([6, 7, 8, 9, 10]), baseCon.pairs, baseCon.n_sources
assert not base_con.is_compatible(label_con)
assert not base_con.is_compatible(all_con)
base_con2 = _BaseConnectivity(
np.array([6, 7, 8, 9, 10]), base_con.pairs, base_con.n_sources
)
assert baseCon.is_compatible(baseCon2)
assert base_con.is_compatible(base_con2)

# Test VertexConnectivity
assert not all_con.is_compatible(label_con)
@@ -655,3 +676,73 @@ def test_dics_connectivity():
fwd_tan = forward_to_tangential(fwd)
con2 = dics_connectivity(pairs, fwd_tan, csd, reg=1)
assert_array_equal(con2.data, con.data)


@testing.requires_testing_data
def test_dics_coherence_external():
"""Test dics_coherence_external function."""
fwd = _load_forward()
fwd = mne.pick_types_forward(fwd, meg="grad", eeg=False)
fwd_fixed = mne.convert_forward_solution(fwd, force_fixed=True)
source_vert = 146374
sfreq = 50.0
csd = _simulate_data(fwd_fixed, source_vert, None, external=True)
print(csd.ch_names)
# Create an info object that holds information about the sensors (their
# location, etc.). Make sure to include the external sensor!
info = mne.create_info(
fwd["info"]["ch_names"] + ["external"],
sfreq,
ch_types=["grad"] * fwd["info"]["nchan"] + ["misc"],
)
# Copy grad positions from the forward solution
for info_ch, fwd_ch in zip(info["chs"], fwd["info"]["chs"]):
info_ch.update(fwd_ch)

dics = make_dics(info.copy(), fwd, csd.copy(), reg=1,
inversion="single",
pick_ori=None)
dics_matrix = make_dics(info.copy(), fwd, csd.copy(), reg=1,
inversion="matrix",
pick_ori=None)
dics_fixed = make_dics(info.copy(), fwd, csd.copy(), reg=1,
inversion="single",
pick_ori='max-power')
# Tangential source space
with pytest.raises(ValueError):
dics_tangential = dics.copy()
dics_tangential["weights"] = np.delete(
dics_tangential["weights"],
np.arange(0, dics["weights"].shape[1], 3), axis=1)
dics_coherence_external(csd, dics_tangential, info, fwd,
external='external',
pick_ori='max-coherence')
# Max-power ori selected but dics has vector ori
with pytest.raises(ValueError):
dics_coherence_external(csd, dics, info, fwd, external='external',
pick_ori='max-power')
# Max-coherence ori selected but dics has fixed ori
with pytest.raises(ValueError):
dics_coherence_external(csd, dics_fixed, info, fwd, external='external',
pick_ori='max-coherence')
# Max-coherence ori selected but dics was computed using matrix inversion
with pytest.raises(ValueError):
dics_coherence_external(csd, dics_matrix, info, fwd,
external='external', pick_ori='max-coherence')
# Incorrect pick_ori
with pytest.raises(ValueError):
dics_coherence_external(csd, dics, info, fwd, external='external',
pick_ori='normal')
# Max power orientation
coh_stc = dics_coherence_external(csd, dics_fixed, info, fwd,
external='external', pick_ori='max-power')
assert isinstance(coh_stc, SourceEstimate)
assert coh_stc.data.shape == (fwd['nsource'], 2) # 2 frequencies
assert np.max(coh_stc.data) <= 1 and np.min(coh_stc.data) >= 0
# Max coherence orientation
coh_stc = dics_coherence_external(csd, dics, info, fwd,
external='external',
pick_ori='max-coherence')
assert isinstance(coh_stc, SourceEstimate)
assert coh_stc.data.shape == (fwd['nsource'], 2) # 2 frequencies
assert np.max(coh_stc.data) <= 1 and np.min(coh_stc.data) >= 0
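

For orientation, the call chain exercised by test_dics_coherence_external above can be sketched outside the test harness as follows. This is a minimal sketch, not code from the PR: it assumes an epochs object holding gradiometer data plus one 'misc' channel named 'external', a matching forward solution fwd, and that dics_coherence_external is importable from conpy as in the test module's import block; the names epochs, fwd_grad, and coh_stc are illustrative.

# Minimal usage sketch mirroring test_dics_coherence_external (assumed
# inputs: `epochs` with gradiometers plus one 'misc' channel named
# 'external', and a forward solution `fwd`).
import mne
from mne.beamformer import make_dics
from mne.time_frequency import csd_morlet
from conpy import dics_coherence_external  # added by this PR

# Restrict the forward solution to gradiometers, as the test does.
fwd_grad = mne.pick_types_forward(fwd, meg="grad", eeg=False)

# Cross-spectral density over the MEG channels and the external channel.
csd = csd_morlet(epochs, picks=["meg", "misc"], frequencies=[10, 20])

# Free-orientation DICS filters; the test requires inversion="single"
# when pick_ori='max-coherence' is requested downstream.
dics = make_dics(epochs.info, fwd_grad, csd, reg=1, inversion="single",
                 pick_ori=None)

# Coherence between each cortical source and the external channel,
# returned as a SourceEstimate with one column per CSD frequency and
# values bounded between 0 and 1.
coh_stc = dics_coherence_external(csd, dics, epochs.info, fwd_grad,
                                  external="external",
                                  pick_ori="max-coherence")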