diff --git a/tests/test_connectivity.py b/tests/test_connectivity.py index 3dafa32..997a24d 100644 --- a/tests/test_connectivity.py +++ b/tests/test_connectivity.py @@ -15,9 +15,11 @@ one_to_all_connectivity_pairs, read_connectivity, restrict_forward_to_vertices, + dics_coherence_external ) from conpy.connectivity import _BaseConnectivity, _get_vert_ind_from_label from mne import BiHemiLabel, Label, SourceEstimate +from mne.beamformer import make_dics from mne.datasets import testing from mne.time_frequency import csd_morlet from mne.utils import _TempDir @@ -60,7 +62,7 @@ def _load_restricted_forward(source_vertno1, source_vertno2): return fwd_free, fwd_fixed -def _simulate_data(fwd_fixed, source_vertno1, source_vertno2): +def _simulate_data(fwd_fixed, source_vertno1, source_vertno2, external=False): """Simulate two oscillators on the cortex.""" sfreq = 50.0 # Hz. base_freq = 10 @@ -86,18 +88,32 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2): signal2 += 1e-8 * np.random.randn(len(times)) # Construct a SourceEstimate object + if external: + source_data = signal1[np.newaxis, :] + vertices = [np.array([source_vertno1]), np.array([])] + # Create an info object that holds information about the sensors + info = mne.create_info( + fwd_fixed["info"]["ch_names"] + ["external"], + sfreq, + ch_types=["grad"] * fwd_fixed["info"]["nchan"] + ["misc"], + ) + else: + source_data = np.vstack((signal1[np.newaxis, :], signal2[np.newaxis, :])) + vertices = [np.array([source_vertno1]), np.array([source_vertno2])] + # Create an info object that holds information about the sensors + info = mne.create_info(fwd_fixed["info"]["ch_names"], sfreq, + ch_types="grad") + stc = mne.SourceEstimate( - np.vstack((signal1[np.newaxis, :], signal2[np.newaxis, :])), - vertices=[np.array([source_vertno1]), np.array([source_vertno2])], + source_data, + vertices=vertices, tmin=0, tstep=1 / sfreq, subject="sample", ) - - # Create an info object that holds information about the sensors - info = mne.create_info(fwd_fixed["info"]["ch_names"], sfreq, ch_types="grad") - with info._unlock(): - info.update(fwd_fixed["info"]) # Merge in sensor position information + # Merge in sensor position information + for info_ch, fwd_ch in zip(info["chs"], fwd_fixed["info"]["chs"]): + info_ch.update(fwd_ch) # Simulated sensor data. raw = mne.apply_forward_raw(fwd_fixed, stc, info) @@ -106,6 +122,11 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2): noise = random.randn(*raw._data.shape) * 1e-14 raw._data += noise + if external: + sensor_data = raw.get_data() + sensor_data = np.vstack((sensor_data, np.atleast_2d(signal2))) + raw = mne.io.RawArray(sensor_data, info) + # Define a single epoch epochs = mne.Epochs( raw, @@ -118,7 +139,7 @@ def _simulate_data(fwd_fixed, source_vertno1, source_vertno2): ) # Compute the cross-spectral density matrix - csd = csd_morlet(epochs, frequencies=[10, 20]) + csd = csd_morlet(epochs, picks=['meg', 'misc'], frequencies=[10, 20]) return csd @@ -173,33 +194,33 @@ def _generate_labels(vertices, n_labels): def test_base_connectivity(): """Test construction of BaseConnectivity.""" # Pairs and data shape don't match - baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 4) - assert_array_equal(baseCon.data, [0.5, 0.5, 0.5]) + base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 4) + assert_array_equal(base_con.data, [0.5, 0.5, 0.5]) with pytest.raises(ValueError): - baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1], [2, 3]], 3) + base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1], [2, 3]], 3) # Not enough sources with pytest.raises(ValueError): - baseCon = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 2) + base_con = _BaseConnectivity([0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 2) # Incorrecly shaped source degree with pytest.raises(ValueError): - baseCon = _BaseConnectivity( + base_con = _BaseConnectivity( [0.5, 0.5, 0.5], [[1, 1, 2], [2, 3, 3]], 4, source_degree=([0, 2, 1], [0, 0, 1]), ) - baseCon = _make_base_connectivity() + base_con = _make_base_connectivity() assert_array_equal( - baseCon.source_degree, + base_con.source_degree, np.array([[0, 2, 2, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 2, 0, 0, 0, 0, 0]]), ) # Test properties - assert baseCon.n_connections == 5 + assert base_con.n_connections == 5 state = { "data": np.array([1, 1, 1]), "pairs": [[2, 2, 3], [3, 4, 4]], @@ -207,9 +228,9 @@ def test_base_connectivity(): "subject": None, "directed": False, } - baseCon.__setstate__(state) - assert baseCon.n_sources == 5 - assert_array_equal(baseCon.source_degree[0], [0, 0, 2, 1, 0]) + base_con.__setstate__(state) + assert base_con.n_sources == 5 + assert_array_equal(base_con.source_degree[0], [0, 0, 2, 1, 0]) def test_alltoall_connectivity(): @@ -269,8 +290,8 @@ def test_label_connnectivity(): def test_connectivity_repr(): """Test string representation of connectivity classes.""" - baseCon = _make_base_connectivity() - assert str(baseCon) == ( + base_con = _make_base_connectivity() + assert str(base_con) == ( "<_BaseConnectivity | n_sources=10, n_conns=5," " subject=None>" ) @@ -334,58 +355,58 @@ def test_connectivity_save(): def test_adjacency(): """Test adjacency matrix.""" - baseCon = _make_base_connectivity() - adjmat = baseCon.get_adjacency() - assert adjmat.nnz == 2 * baseCon.n_connections - assert adjmat.shape == (baseCon.n_sources, baseCon.n_sources) + base_con = _make_base_connectivity() + adjmat = base_con.get_adjacency() + assert adjmat.nnz == 2 * base_con.n_connections + assert adjmat.shape == (base_con.n_sources, base_con.n_sources) # Directed - baseCon.directed = True - adjmat = baseCon.get_adjacency() - assert adjmat.nnz == baseCon.n_connections - assert adjmat.shape == (baseCon.n_sources, baseCon.n_sources) + base_con.directed = True + adjmat = base_con.get_adjacency() + assert adjmat.nnz == base_con.n_connections + assert adjmat.shape == (base_con.n_sources, base_con.n_sources) def test_connectivity_threshold(): """Test thresholding function of BaseConnectivity.""" # Criterion = None - baseCon = _make_base_connectivity() - threshCon = baseCon.threshold(2, copy=True) + base_con = _make_base_connectivity() + threshCon = base_con.threshold(2, copy=True) assert threshCon.n_connections == 3 - assert baseCon.n_connections == 5 + assert base_con.n_connections == 5 assert_array_equal(threshCon.data, np.array([3, 4, 5])) - threshCon = baseCon.copy() + threshCon = base_con.copy() threshCon.threshold(2, direction="below", copy=False) assert threshCon.n_connections == 1 assert_array_equal(threshCon.data, np.array([1])) # Incorrect direction with pytest.raises(ValueError): - baseCon.threshold(1, direction="wrong") + base_con.threshold(1, direction="wrong") # Use criterion with pytest.raises(ValueError): - baseCon.threshold(1, crit=np.array([0, 0, 1, 2])) + base_con.threshold(1, crit=np.array([0, 0, 1, 2])) pval = np.array([0.04, 0.7, 0.001, 0.1, 0.06]) - threshCon = baseCon.threshold(0.05, crit=pval, copy=True) + threshCon = base_con.threshold(0.05, crit=pval, copy=True) assert_array_equal(threshCon.data, np.array([2, 4, 5])) def test_compatibility(): """Test _iscombatible function.""" - baseCon = _make_base_connectivity() + base_con = _make_base_connectivity() all_con = _make_alltoall_connectivity() label_con = _make_label_connectivity() # Test BaseConnectivity - assert not baseCon.is_compatible(label_con) - assert not baseCon.is_compatible(all_con) - baseCon2 = _BaseConnectivity( - np.array([6, 7, 8, 9, 10]), baseCon.pairs, baseCon.n_sources + assert not base_con.is_compatible(label_con) + assert not base_con.is_compatible(all_con) + base_con2 = _BaseConnectivity( + np.array([6, 7, 8, 9, 10]), base_con.pairs, base_con.n_sources ) - assert baseCon.is_compatible(baseCon2) + assert base_con.is_compatible(base_con2) # Test VertexConnectivity assert not all_con.is_compatible(label_con) @@ -655,3 +676,73 @@ def test_dics_connectivity(): fwd_tan = forward_to_tangential(fwd) con2 = dics_connectivity(pairs, fwd_tan, csd, reg=1) assert_array_equal(con2.data, con.data) + + +@testing.requires_testing_data +def test_dics_coherence_external(): + """Test dics_coherence_external function.""" + fwd = _load_forward() + fwd = mne.pick_types_forward(fwd, meg="grad", eeg=False) + fwd_fixed = mne.convert_forward_solution(fwd, force_fixed=True) + source_vert = 146374 + sfreq = 50.0 + csd = _simulate_data(fwd_fixed, source_vert, None, external=True) + print(csd.ch_names) + # Create an info object that holds information about the sensors (their + # location, etc.). Make sure to include the external sensor! + info = mne.create_info( + fwd["info"]["ch_names"] + ["external"], + sfreq, + ch_types=["grad"] * fwd["info"]["nchan"] + ["misc"], + ) + # Copy grad positions from the forward solution + for info_ch, fwd_ch in zip(info["chs"], fwd["info"]["chs"]): + info_ch.update(fwd_ch) + + dics = make_dics(info.copy(), fwd, csd.copy(), reg=1, + inversion="single", + pick_ori=None) + dics_matrix = make_dics(info.copy(), fwd, csd.copy(), reg=1, + inversion="matrix", + pick_ori=None) + dics_fixed = make_dics(info.copy(), fwd, csd.copy(), reg=1, + inversion="single", + pick_ori='max-power') + # Tangential source space + with pytest.raises(ValueError): + dics_tangential = dics.copy() + dics_tangential["weights"] = np.delete( + dics_tangential["weights"], + np.arange(0, dics["weights"].shape[1], 3), axis=1) + dics_coherence_external(csd, dics_tangential, info, fwd, + external='external', + pick_ori='max-coherence') + # Max-power ori selected but dics has vector ori + with pytest.raises(ValueError): + dics_coherence_external(csd, dics, info, fwd, external='external', + pick_ori='max-power') + # Max-coherence ori selected but dics has fixed ori + with pytest.raises(ValueError): + dics_coherence_external(csd, dics_fixed, info, fwd, external='external', + pick_ori='max-coherence') + # Max-coherence ori selected but dics was computed using matrix inversion + with pytest.raises(ValueError): + dics_coherence_external(csd, dics_matrix, info, fwd, + external='external', pick_ori='max-coherence') + # Incorrect pick_ori + with pytest.raises(ValueError): + dics_coherence_external(csd, dics, info, fwd, external='external', + pick_ori='normal') + # Max power orientation + coh_stc = dics_coherence_external(csd, dics_fixed, info, fwd, + external='external', pick_ori='max-power') + assert isinstance(coh_stc, SourceEstimate) + assert coh_stc.data.shape == (fwd['nsource'], 2) # 2 frequencies + assert np.max(coh_stc.data) <= 1 and np.min(coh_stc.data) >= 0 + # Max coherence orientation + coh_stc = dics_coherence_external(csd, dics, info, fwd, + external='external', + pick_ori='max-coherence') + assert isinstance(coh_stc, SourceEstimate) + assert coh_stc.data.shape == (fwd['nsource'], 2) # 2 frequencies + assert np.max(coh_stc.data) <= 1 and np.min(coh_stc.data) >= 0