Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13525.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :func:`mne.chpi.refit_hpi` did not take ``gof_limit`` into account when fitting HPI order, by `Eric Larson`_
4 changes: 2 additions & 2 deletions doc/sphinxext/directive_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def check_directive_formatting(*args):
# another directive/another directive's content)
if idx == 0:
continue
dir_pattern = r"\.\. [a-zA-Z]+::"
dir_pattern = r"^\s*\.\. \w+::" # line might start with whitespace
head_pattern = r"^[-|=|\^]+$"
directive = re.search(dir_pattern, line)
if directive is not None:
Expand All @@ -84,5 +84,5 @@ def check_directive_formatting(*args):
if bad:
sphinx_logger.warning(
f"{source_type} '{name}' is missing a blank line before the "
f"directive '{directive.group()}'"
f"directive '{directive.group()}' on line {idx + 1}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
=========================================================================

This example runs the analysis described in :footcite:`KingDehaene2014`. It
illustrates how one can
fit a linear classifier to identify a discriminatory topography at a given time
instant and subsequently assess whether this linear model can accurately
predict all of the time samples of a second set of conditions.
illustrates how one can fit a linear classifier to identify a discriminatory
topography at a given time instant and subsequently assess whether this linear
model can accurately predict all of the time samples of a second set of conditions.
"""
# Authors: Jean-Rémi King <[email protected]>
# Alexandre Gramfort <[email protected]>
Expand Down
78 changes: 58 additions & 20 deletions mne/chpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,27 +579,37 @@ def _chpi_objective(x, coil_dev_rrs, coil_head_rrs):
return d.sum()


def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs, *, quat=None):
    """Fit rotation and translation (quaternion) parameters for cHPI coils.

    Parameters
    ----------
    coil_dev_rrs : ndarray, shape (n_coils, 3)
        Coil positions in device coordinates.
    coil_head_rrs : ndarray, shape (n_coils, 3)
        Coil positions in head coordinates.
    quat : ndarray, shape (6,) | None
        If not None, skip the fit and only evaluate the goodness of fit
        of this precomputed quaternion against the given points.

    Returns
    -------
    quat : ndarray, shape (6,)
        The fitted (or passed-through) rotation+translation quaternion.
    gof : float
        Goodness of fit (1.0 is a perfect fit).
    """
    # Normalization constant: squared norm of the demeaned head points,
    # so that gof == 1 - residual / total_variance
    denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0))
    denom *= denom
    # We could try to solve it the analytic way:
    # TODO someday we could choose to weight these points by their goodness
    # of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330
    if quat is None:
        quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0]
    gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom
    return quat, gof


def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix=""):
def _fit_coil_order_dev_head_trans(
dev_pnts, head_pnts, *, bias=True, gofs=None, gof_limit=0.98, prefix=""
):
"""Compute Device to Head transform allowing for permutiatons of points."""
n_coils = len(dev_pnts)
id_quat = np.zeros(6)
best_order = None
best_order = np.full(n_coils, -1, dtype=int)
best_g = -999
best_quat = id_quat
for this_order in itertools.permutations(np.arange(len(head_pnts))):
assert dev_pnts.shape == head_pnts.shape == (n_coils, 3)
gofs = np.ones(n_coils) if gofs is None else gofs
use_mask = _gof_use_mask(gofs, gof_limit=gof_limit)
n_use = int(use_mask.sum()) # explicit int cast for itertools.permutations
dev_pnts_tmp = dev_pnts[use_mask]
# First pass: figure out best order using the good dev points
for this_order in itertools.permutations(np.arange(len(head_pnts)), n_use):
head_pnts_tmp = head_pnts[np.array(this_order)]
this_quat, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp)
this_quat, g = _fit_chpi_quat(dev_pnts_tmp, head_pnts_tmp)
assert np.linalg.det(quat_to_rot(this_quat[:3])) > 0.9999
if bias:
# For symmetrical arrangements, flips can produce roughly
Expand All @@ -612,17 +622,35 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix="")
if check_g > best_g:
out_g = g
best_g = check_g
best_order = np.array(this_order)
best_order[use_mask] = this_order
best_quat = this_quat
del this_order
# Second pass: now fit the remaining (bad) coils using the best order and quat
# from above
missing = np.setdiff1d(np.arange(n_coils), best_order[best_order >= 0])
best_missing_g = -np.inf
for this_order in itertools.permutations(missing):
full_order = best_order.copy()
full_order[~use_mask] = this_order
assert (full_order >= 0).all()
assert np.array_equal(np.sort(full_order), np.arange(n_coils))
head_pnts_tmp = head_pnts[np.array(full_order)]
_, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp, quat=best_quat)
if g > best_missing_g:
best_missing_g = g
best_order[:] = full_order
del this_order
assert np.array_equal(np.sort(best_order), np.arange(n_coils))

# Convert quaternion to transform
dev_head_t = _quat_to_affine(best_quat)
ang, dist = angle_distance_between_rigid(
dev_head_t, angle_units="deg", distance_units="mm"
)
extra = f" using {n_use}/{n_coils} coils" if n_use < n_coils else ""
logger.info(
f"{prefix}Fitted dev_head_t {ang:0.1f}° and {dist:0.1f} mm "
f"from device origin (GOF: {out_g:.3f})"
f"from device origin{extra} (GOF: {out_g:.3f})"
)
return dev_head_t, best_order, out_g

Expand Down Expand Up @@ -1703,7 +1731,8 @@ def refit_hpi(
:func:`~mne.chpi.compute_chpi_locs`.
3. Optionally determine coil digitization order by testing all permutations
for the best goodness of fit between digitized coil locations and
(rigid-transformed) fitted coil locations.
(rigid-transformed) fitted coil locations, choosing the order first based on
those that satisfy ``gof_limit`` then the others.
4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``,
``dist_limit``, and ``use``.
5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries
Expand Down Expand Up @@ -1816,6 +1845,8 @@ def refit_hpi(
fit_dev_head_t, fit_order, _g = _fit_coil_order_dev_head_trans(
hpi_dev,
hpi_head,
gofs=hpi_gofs,
gof_limit=gof_limit,
prefix=" ",
)
else:
Expand All @@ -1824,27 +1855,21 @@ def refit_hpi(

# 4. Subselect usable coils and determine final dev_head_t
if isinstance(use, int) or use is None:
used = np.where(hpi_gofs >= gof_limit)[0]
if len(used) < 3:
gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs)
raise RuntimeError(
f"Only {len(used)} coil{_pl(used)} with goodness of fit >= {gof_limit}"
f", need at least 3 to refit HPI order (got {gofs})."
)
quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used])
use_mask = _gof_use_mask(hpi_gofs, gof_limit=gof_limit)
quat, _g = _fit_chpi_quat(hpi_dev[use_mask], hpi_head[fit_order][use_mask])
fit_dev_head_t = _quat_to_affine(quat)
hpi_head_got = apply_trans(fit_dev_head_t, hpi_dev)
dists = np.linalg.norm(hpi_head_got - hpi_head[fit_order], axis=1)
dist_str = " ".join(f"{dist * 1e3:.1f}" for dist in dists)
logger.info(f" Coil distances after initial fit: {dist_str} mm")
good_dists_idx = np.where(dists[used] <= dist_limit)[0]
good_dists_idx = np.where(dists[use_mask] <= dist_limit)[0]
if not len(good_dists_idx) >= 3:
raise RuntimeError(
f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} have distance "
f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} with distance "
f"<= {dist_limit * 1e3:.1f} mm, need at least 3 to refit HPI order "
f"(got distances: {np.round(1e3 * dists, 1)})."
)
used = used[good_dists_idx]
used = np.where(use_mask)[0][good_dists_idx]
if use is not None:
used = np.sort(used[np.argsort(hpi_gofs[used])[-use:]])
else:
Expand Down Expand Up @@ -1927,6 +1952,19 @@ def refit_hpi(
return info


def _gof_use_mask(hpi_gofs, *, gof_limit):
assert isinstance(hpi_gofs, np.ndarray) and hpi_gofs.ndim == 1
use_mask = hpi_gofs >= gof_limit
n_use = use_mask.sum()
if n_use < 3:
gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs)
raise RuntimeError(
f"Only {n_use} coil{_pl(n_use)} with goodness of fit >= {gof_limit}"
f", need at least 3 to refit HPI order (got {gofs})."
)
return use_mask


def _sorted_hpi_dig(dig, *, kinds=(FIFF.FIFFV_POINT_HPI,)):
return sorted(
# need .get here because the hpi_result["dig_points"] does not set it
Expand Down
4 changes: 2 additions & 2 deletions mne/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# update the checksum in the MNE_DATASETS dict below, and change version
# here: ↓↓↓↓↓↓↓↓
RELEASES = dict(
testing="0.169",
testing="0.170",
misc="0.27",
phantom_kit="0.2",
ucl_opm_auditory="0.2",
Expand Down Expand Up @@ -115,7 +115,7 @@
# Testing and misc are at the top as they're updated most often
MNE_DATASETS["testing"] = dict(
archive_name=f"{TESTING_VERSIONED}.tar.gz",
hash="md5:bb0524db8605e96fde6333893a969766",
hash="md5:ebd873ea89507cf5a75043f56119d22b",
url=(
"https://codeload.github.com/mne-tools/mne-testing-data/"
f"tar.gz/{RELEASES['testing']}"
Expand Down
17 changes: 17 additions & 0 deletions mne/tests/test_chpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds"
ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos"
chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif"
chpi_bad_gof_fname = data_path / "SSS" / "chpi_bad_gof-info.fif"

art_fname = (
data_path
Expand Down Expand Up @@ -1011,3 +1012,19 @@ def test_refit_hpi_locs_problematic():
)
assert 3 < ang < 6
assert 82 < dist < 87


@testing.requires_testing_data
def test_refit_hpi_locs_bad_gof():
    """Test that we can handle bad GOF HPI fits."""
    # gh-13524
    info = read_info(chpi_bad_gof_fname)
    # In the stored file only coils 2-4 were marked as used
    assert_array_equal(info["hpi_results"][-1]["used"], [2, 3, 4])
    info_new = refit_hpi(info.copy(), amplitudes=False, locs=False)
    # After refitting, coil 1 should be recovered as usable as well
    assert_array_equal(info_new["hpi_results"][-1]["used"], [1, 2, 3, 4])
    # The resulting dev_head_t should stay close to the original
    # (within 1 mm translation and 1 degree rotation)
    assert_trans_allclose(
        info["dev_head_t"],
        info_new["dev_head_t"],
        dist_tol=1e-3,
        angle_tol=1,
    )