Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 41 additions & 34 deletions conpy/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def select_vertices_in_sensor_range(
"You need to specify an Info object with "
"information about the channels."
)
n_src = len(src)

# Load the head<->MRI transform if necessary
if src[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI:
Expand Down Expand Up @@ -149,11 +150,10 @@ def select_vertices_in_sensor_range(
if indices:
return np.flatnonzero(src_sel)
else:
n_lh_verts = src[0]["nuse"]
lh_sel, rh_sel = src_sel[:n_lh_verts], src_sel[n_lh_verts:]
vert_lh = src[0]["vertno"][lh_sel]
vert_rh = src[1]["vertno"][rh_sel]
return [vert_lh, vert_rh]
n_verts = np.cumsum([0] + [s["nuse"] for s in src])
sel = [src_sel[n_verts[i]:n_verts[i+1]] for i in range(n_src)]
verts = [src[i]["vertno"][sel[i]] for i in range(n_src)]
return verts


@verbose
Expand Down Expand Up @@ -201,34 +201,38 @@ def restrict_forward_to_vertices(
else:
fwd_out = fwd

lh_vertno, rh_vertno = [src["vertno"] for src in fwd["src"]]

n_src = len(fwd["src"])
vertno = [s["vertno"] for s in fwd["src"]]
n_vertno = [len(hemi_vertno) for hemi_vertno in vertno]
if isinstance(vertno_or_idx[0], int):
logger.info("Interpreting given vertno_or_idx as vertex indices.")
vertno_or_idx = np.asarray(vertno_or_idx)

# Make sure the vertices are in sequential order
fwd_idx = np.sort(vertno_or_idx)

n_vert_lh = len(lh_vertno)
sel_lh_idx = vertno_or_idx[fwd_idx < n_vert_lh]
sel_rh_idx = vertno_or_idx[fwd_idx >= n_vert_lh] - n_vert_lh
sel_lh_vertno = lh_vertno[sel_lh_idx]
sel_rh_vertno = rh_vertno[sel_rh_idx]
vert_idx = np.cumsum([0] + n_vertno)
sel_idx = [
vertno_or_idx[(fwd_idx >= vert_idx[i])
& (fwd_idx < vert_idx[i+1])] - vert_idx[i]
for i in range(n_src)]
sel_vertno = [hemi_vertno[sel] for hemi_vertno, sel in zip(vertno, sel_idx)]
else:
logger.info("Interpreting given vertno_or_idx as vertex numbers.")

# Make sure vertno_or_idx is sorted
vertno_or_idx = [np.sort(v) for v in vertno_or_idx]
sel_vertno = vertno_or_idx

sel_lh_vertno, sel_rh_vertno = vertno_or_idx
src_lh_idx = _find_indices_1d(lh_vertno, sel_lh_vertno, check_vertno)
src_rh_idx = _find_indices_1d(rh_vertno, sel_rh_vertno, check_vertno)
fwd_idx = np.hstack((src_lh_idx, src_rh_idx + len(lh_vertno)))
src_idx = [
_find_indices_1d(hemi_vertno, sel, check_vertno) + sum(n_vertno[:i])
for i, (hemi_vertno, sel) in enumerate(zip(vertno, sel_vertno))
]
fwd_idx = np.hstack(src_idx)

logger.info(
"Restricting forward solution to %d out of %d vertices."
% (len(fwd_idx), len(lh_vertno) + len(rh_vertno))
% (len(fwd_idx), sum(n_vertno))
)

n_orient = fwd["sol"]["ncol"] // fwd["nsource"]
Expand Down Expand Up @@ -260,7 +264,7 @@ def _reshape_select(X, dim3, sel):
# Restrict the SourceSpaces inside the forward operator
fwd_out["src"] = restrict_src_to_vertices(
fwd_out["src"],
[sel_lh_vertno, sel_rh_vertno],
sel_vertno,
check_vertno=False,
verbose=False,
)
Expand Down Expand Up @@ -307,37 +311,40 @@ def restrict_src_to_vertices(
else:
src_out = src

n_src = len(src)
if vertno_or_idx:
if isinstance(vertno_or_idx[0], int):
logger.info("Interpreting given vertno_or_idx as vertex indices.")
vertno_or_idx = np.asarray(vertno_or_idx)
n_vert_lh = src[0]["nuse"]
ind_lh = vertno_or_idx[vertno_or_idx < n_vert_lh]
ind_rh = vertno_or_idx[vertno_or_idx >= n_vert_lh] - n_vert_lh
vert_no_lh = src[0]["vertno"][ind_lh]
vert_no_rh = src[1]["vertno"][ind_rh]
vert_idx = np.cumsum([0] + [s["nuse"] for s in src])
ind = [
vertno_or_idx[
(vertno_or_idx >= vert_idx[i]) & (vertno_or_idx < vert_idx[i+1])
] - vert_idx[i] for i in range(n_src)
]
vertno = [s["vertno"][inds] for s, inds in zip(src, ind)]
else:
logger.info("Interpreting given vertno_or_idx as vertex numbers.")
vert_no_lh, vert_no_rh = vertno_or_idx
vertno = vertno_or_idx
if check_vertno:
if not (
np.all(np.isin(vert_no_lh, src[0]["vertno"]))
and np.all(np.isin(vert_no_rh, src[1]["vertno"]))
):
raise ValueError(
"One or more vertices were not present in SourceSpaces."
)
for s, verts in zip(src, vertno):
if not np.all(np.isin(verts, s["vertno"])):
raise ValueError(
"One or more vertices were not present in SourceSpaces."
)

else:
# Empty list
vert_no_lh, vert_no_rh = [], []
vertno = [[] for i in range(n_src)]

nuse = sum([s["nuse"] for s in src])
n_vertno = sum([len(verts) for verts in vertno])
logger.info(
"Restricting source space to %d out of %d vertices."
% (len(vert_no_lh) + len(vert_no_rh), src[0]["nuse"] + src[1]["nuse"])
% (n_vertno, nuse)
)

for hemi, verts in zip(src_out, (vert_no_lh, vert_no_rh)):
for hemi, verts in zip(src_out, vertno):
# Ensure vertices are in sequential order
verts = np.sort(verts)

Expand Down
36 changes: 36 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@ def src():
)


@pytest.fixture
def vol_src():
"""Load a volume source space."""
path = mne.datasets.sample.data_path()
return mne.read_source_spaces(
op.join(path, "subjects", "sample", "bem", "volume-7mm-src.fif")
)


@pytest.fixture
def vol_fwd():
"""Load a volume forward solution."""
path = mne.datasets.sample.data_path()
return mne.read_forward_solution(
op.join(path, "MEG", "sample", "sample_audvis-meg-vol-7-fwd.fif")
)


def _trans():
path = mne.datasets.sample.data_path()
return op.join(path, "MEG", "sample", "sample_audvis_raw-trans.fif")
Expand Down Expand Up @@ -213,6 +231,24 @@ def test_select_vertices_in_sensor_range(fwd, src):
assert_array_equal(verts2[1], np.array([2159]))


def test_select_vertices_in_sensor_range_volume(vol_fwd):
"""Test selecting vertices in sensor range with volumetric source space."""
fwd_r = restrict_forward_to_vertices(vol_fwd, ([[1273, 1312]]))
assert_array_equal(fwd_r["src"][0]["vertno"], np.array([1273, 1312]))

verts = select_vertices_in_sensor_range(fwd_r, 0.08)
assert_array_equal(verts[0], np.array([1273]))

# Test indices
verts = select_vertices_in_sensor_range(fwd_r, 0.08, indices=True)
assert_array_equal(verts, np.array([0]))

# Test restricting
fwd_rs = restrict_forward_to_sensor_range(fwd_r, 0.08)
assert_array_equal(fwd_rs["src"][0]["vertno"], np.array([1273]))
assert len(fwd_rs["src"]) == 1 # No second source space


# FIXME: disabled until we can make a proper test
# def test_radial_coord_system():
# """Test making a radial coordinate system."""
Expand Down