diff --git a/conpy/forward.py b/conpy/forward.py index 9cf8264..71f2788 100644 --- a/conpy/forward.py +++ b/conpy/forward.py @@ -84,6 +84,7 @@ def select_vertices_in_sensor_range( "You need to specify an Info object with " "information about the channels." ) + n_src = len(src) # Load the head<->MRI transform if necessary if src[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI: @@ -149,11 +150,10 @@ def select_vertices_in_sensor_range( if indices: return np.flatnonzero(src_sel) else: - n_lh_verts = src[0]["nuse"] - lh_sel, rh_sel = src_sel[:n_lh_verts], src_sel[n_lh_verts:] - vert_lh = src[0]["vertno"][lh_sel] - vert_rh = src[1]["vertno"][rh_sel] - return [vert_lh, vert_rh] + n_verts = np.cumsum([0] + [s["nuse"] for s in src]) + sel = [src_sel[n_verts[i]:n_verts[i+1]] for i in range(n_src)] + verts = [src[i]["vertno"][sel[i]] for i in range(n_src)] + return verts @verbose @@ -201,8 +201,9 @@ def restrict_forward_to_vertices( else: fwd_out = fwd - lh_vertno, rh_vertno = [src["vertno"] for src in fwd["src"]] - + n_src = len(fwd["src"]) + vertno = [s["vertno"] for s in fwd["src"]] + n_vertno = [len(hemi_vertno) for hemi_vertno in vertno] if isinstance(vertno_or_idx[0], int): logger.info("Interpreting given vertno_or_idx as vertex indices.") vertno_or_idx = np.asarray(vertno_or_idx) @@ -210,25 +211,28 @@ def restrict_forward_to_vertices( # Make sure the vertices are in sequential order fwd_idx = np.sort(vertno_or_idx) - n_vert_lh = len(lh_vertno) - sel_lh_idx = vertno_or_idx[fwd_idx < n_vert_lh] - sel_rh_idx = vertno_or_idx[fwd_idx >= n_vert_lh] - n_vert_lh - sel_lh_vertno = lh_vertno[sel_lh_idx] - sel_rh_vertno = rh_vertno[sel_rh_idx] + vert_idx = np.cumsum([0] + n_vertno) + sel_idx = [ + vertno_or_idx[(fwd_idx >= vert_idx[i]) + & (fwd_idx < vert_idx[i+1])] - vert_idx[i] + for i in range(n_src)] + sel_vertno = [hemi_vertno[sel] for hemi_vertno, sel in zip(vertno, sel_idx)] else: logger.info("Interpreting given vertno_or_idx as vertex numbers.") # Make sure vertno_or_idx is sorted vertno_or_idx = [np.sort(v) for v in vertno_or_idx] + sel_vertno = vertno_or_idx - sel_lh_vertno, sel_rh_vertno = vertno_or_idx - src_lh_idx = _find_indices_1d(lh_vertno, sel_lh_vertno, check_vertno) - src_rh_idx = _find_indices_1d(rh_vertno, sel_rh_vertno, check_vertno) - fwd_idx = np.hstack((src_lh_idx, src_rh_idx + len(lh_vertno))) + src_idx = [ + _find_indices_1d(hemi_vertno, sel, check_vertno) + sum(n_vertno[:i]) + for i, (hemi_vertno, sel) in enumerate(zip(vertno, sel_vertno)) + ] + fwd_idx = np.hstack(src_idx) logger.info( "Restricting forward solution to %d out of %d vertices." - % (len(fwd_idx), len(lh_vertno) + len(rh_vertno)) + % (len(fwd_idx), sum(n_vertno)) ) n_orient = fwd["sol"]["ncol"] // fwd["nsource"] @@ -260,7 +264,7 @@ def _reshape_select(X, dim3, sel): # Restrict the SourceSpaces inside the forward operator fwd_out["src"] = restrict_src_to_vertices( fwd_out["src"], - [sel_lh_vertno, sel_rh_vertno], + sel_vertno, check_vertno=False, verbose=False, ) @@ -307,37 +311,40 @@ def restrict_src_to_vertices( else: src_out = src + n_src = len(src) if vertno_or_idx: if isinstance(vertno_or_idx[0], int): logger.info("Interpreting given vertno_or_idx as vertex indices.") vertno_or_idx = np.asarray(vertno_or_idx) - n_vert_lh = src[0]["nuse"] - ind_lh = vertno_or_idx[vertno_or_idx < n_vert_lh] - ind_rh = vertno_or_idx[vertno_or_idx >= n_vert_lh] - n_vert_lh - vert_no_lh = src[0]["vertno"][ind_lh] - vert_no_rh = src[1]["vertno"][ind_rh] + vert_idx = np.cumsum([0] + [s["nuse"] for s in src]) + ind = [ + vertno_or_idx[ + (vertno_or_idx >= vert_idx[i]) & (vertno_or_idx < vert_idx[i+1]) + ] - vert_idx[i] for i in range(n_src) + ] + vertno = [s["vertno"][inds] for s, inds in zip(src, ind)] else: logger.info("Interpreting given vertno_or_idx as vertex numbers.") - vert_no_lh, vert_no_rh = vertno_or_idx + vertno = vertno_or_idx if check_vertno: - if not ( - np.all(np.isin(vert_no_lh, src[0]["vertno"])) - and np.all(np.isin(vert_no_rh, src[1]["vertno"])) - ): - raise ValueError( - "One or more vertices were not present in SourceSpaces." - ) + for s, verts in zip(src, vertno): + if not np.all(np.isin(verts, s["vertno"])): + raise ValueError( + "One or more vertices were not present in SourceSpaces." + ) else: # Empty list - vert_no_lh, vert_no_rh = [], [] + vertno = [[] for i in range(n_src)] + nuse = sum([s["nuse"] for s in src]) + n_vertno = sum([len(verts) for verts in vertno]) logger.info( "Restricting source space to %d out of %d vertices." - % (len(vert_no_lh) + len(vert_no_rh), src[0]["nuse"] + src[1]["nuse"]) + % (n_vertno, nuse) ) - for hemi, verts in zip(src_out, (vert_no_lh, vert_no_rh)): + for hemi, verts in zip(src_out, vertno): # Ensure vertices are in sequential order verts = np.sort(verts) diff --git a/tests/test_forward.py b/tests/test_forward.py index 340bb3c..28ddf5c 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -33,6 +33,24 @@ def src(): ) +@pytest.fixture +def vol_src(): + """Load a volume source space.""" + path = mne.datasets.sample.data_path() + return mne.read_source_spaces( + op.join(path, "subjects", "sample", "bem", "volume-7mm-src.fif") + ) + + +@pytest.fixture +def vol_fwd(): + """Load a volume forward solution.""" + path = mne.datasets.sample.data_path() + return mne.read_forward_solution( + op.join(path, "MEG", "sample", "sample_audvis-meg-vol-7-fwd.fif") + ) + + def _trans(): path = mne.datasets.sample.data_path() return op.join(path, "MEG", "sample", "sample_audvis_raw-trans.fif") @@ -213,6 +231,24 @@ def test_select_vertices_in_sensor_range(fwd, src): assert_array_equal(verts2[1], np.array([2159])) +def test_select_vertices_in_sensor_range_volume(vol_fwd): + """Test selecting vertices in sensor range with volumetric source space.""" + fwd_r = restrict_forward_to_vertices(vol_fwd, ([[1273, 1312]])) + assert_array_equal(fwd_r["src"][0]["vertno"], np.array([1273, 1312])) + + verts = select_vertices_in_sensor_range(fwd_r, 0.08) + assert_array_equal(verts[0], np.array([1273])) + + # Test indices + verts = select_vertices_in_sensor_range(fwd_r, 0.08, indices=True) + assert_array_equal(verts, np.array([0])) + + # Test restricting + fwd_rs = restrict_forward_to_sensor_range(fwd_r, 0.08) + assert_array_equal(fwd_rs["src"][0]["vertno"], np.array([1273])) + assert len(fwd_rs["src"]) == 1 # No second source space + + # FIXME: disabled until we can make a proper test # def test_radial_coord_system(): # """Test making a radial coordinate system."""