diff --git a/doc/changes/dev/13525.bugfix.rst b/doc/changes/dev/13525.bugfix.rst new file mode 100644 index 00000000000..8477178380a --- /dev/null +++ b/doc/changes/dev/13525.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.chpi.refit_hpi` did not take ``gof_limit`` into account when fitting HPI order, by `Eric Larson`_ diff --git a/doc/sphinxext/directive_formatting.py b/doc/sphinxext/directive_formatting.py index a3090ab4c90..4c65f653d4a 100644 --- a/doc/sphinxext/directive_formatting.py +++ b/doc/sphinxext/directive_formatting.py @@ -60,7 +60,7 @@ def check_directive_formatting(*args): # another directive/another directive's content) if idx == 0: continue - dir_pattern = r"\.\. [a-zA-Z]+::" + dir_pattern = r"^\s*\.\. \w+::" # line might start with whitespace head_pattern = r"^[-|=|\^]+$" directive = re.search(dir_pattern, line) if directive is not None: @@ -84,5 +84,5 @@ def check_directive_formatting(*args): if bad: sphinx_logger.warning( f"{source_type} '{name}' is missing a blank line before the " - f"directive '{directive.group()}'" + f"directive '{directive.group()}' on line {idx + 1}" ) diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py index e71112e8375..cc9e62f06cf 100644 --- a/examples/decoding/decoding_time_generalization_conditions.py +++ b/examples/decoding/decoding_time_generalization_conditions.py @@ -6,10 +6,9 @@ ========================================================================= This example runs the analysis described in :footcite:`KingDehaene2014`. It -illustrates how one can -fit a linear classifier to identify a discriminatory topography at a given time -instant and subsequently assess whether this linear model can accurately -predict all of the time samples of a second set of conditions. +illustrates how one can fit a linear classifier to identify a discriminatory +topography at a given time instant and subsequently assess whether this linear +model can accurately predict all of the time samples of a second set of conditions. """ # Authors: Jean-Rémi King # Alexandre Gramfort diff --git a/mne/chpi.py b/mne/chpi.py index 711474338c9..cc921a9843e 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -579,27 +579,37 @@ def _chpi_objective(x, coil_dev_rrs, coil_head_rrs): return d.sum() -def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs): +def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs, *, quat=None): """Fit rotation and translation (quaternion) parameters for cHPI coils.""" denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0)) denom *= denom # We could try to solve it the analytic way: # TODO someday we could choose to weight these points by their goodness # of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330 - quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] + if quat is None: + quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom return quat, gof -def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix=""): +def _fit_coil_order_dev_head_trans( + dev_pnts, head_pnts, *, bias=True, gofs=None, gof_limit=0.98, prefix="" +): """Compute Device to Head transform allowing for permutiatons of points.""" + n_coils = len(dev_pnts) id_quat = np.zeros(6) - best_order = None + best_order = np.full(n_coils, -1, dtype=int) best_g = -999 best_quat = id_quat - for this_order in itertools.permutations(np.arange(len(head_pnts))): + assert dev_pnts.shape == head_pnts.shape == (n_coils, 3) + gofs = np.ones(n_coils) if gofs is None else gofs + use_mask = _gof_use_mask(gofs, gof_limit=gof_limit) + n_use = int(use_mask.sum()) # explicit int cast for itertools.permutations + dev_pnts_tmp = dev_pnts[use_mask] + # First pass: figure out best order using the good dev points + for this_order in itertools.permutations(np.arange(len(head_pnts)), n_use): head_pnts_tmp = head_pnts[np.array(this_order)] - this_quat, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp) + this_quat, g = _fit_chpi_quat(dev_pnts_tmp, head_pnts_tmp) assert np.linalg.det(quat_to_rot(this_quat[:3])) > 0.9999 if bias: # For symmetrical arrangements, flips can produce roughly @@ -612,17 +622,35 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix="") if check_g > best_g: out_g = g best_g = check_g - best_order = np.array(this_order) + best_order[use_mask] = this_order best_quat = this_quat + del this_order + # Second pass: now fit the remaining (bad) coils using the best order and quat + # from above + missing = np.setdiff1d(np.arange(n_coils), best_order[best_order >= 0]) + best_missing_g = -np.inf + for this_order in itertools.permutations(missing): + full_order = best_order.copy() + full_order[~use_mask] = this_order + assert (full_order >= 0).all() + assert np.array_equal(np.sort(full_order), np.arange(n_coils)) + head_pnts_tmp = head_pnts[np.array(full_order)] + _, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp, quat=best_quat) + if g > best_missing_g: + best_missing_g = g + best_order[:] = full_order + del this_order + assert np.array_equal(np.sort(best_order), np.arange(n_coils)) # Convert Quaterion to transform dev_head_t = _quat_to_affine(best_quat) ang, dist = angle_distance_between_rigid( dev_head_t, angle_units="deg", distance_units="mm" ) + extra = f" using {n_use}/{n_coils} coils" if n_use < n_coils else "" logger.info( f"{prefix}Fitted dev_head_t {ang:0.1f}° and {dist:0.1f} mm " - f"from device origin (GOF: {out_g:.3f})" + f"from device origin{extra} (GOF: {out_g:.3f})" ) return dev_head_t, best_order, out_g @@ -1703,7 +1731,8 @@ def refit_hpi( :func:`~mne.chpi.compute_chpi_locs`. 3. Optionally determine coil digitization order by testing all permutations for the best goodness of fit between digitized coil locations and - (rigid-transformed) fitted coil locations. + (rigid-transformed) fitted coil locations, choosing the order first based on + those that satisfy ``gof_limit`` then the others. 4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``, ``dist_limit``, and ``use``. 5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries @@ -1816,6 +1845,8 @@ def refit_hpi( fit_dev_head_t, fit_order, _g = _fit_coil_order_dev_head_trans( hpi_dev, hpi_head, + gofs=hpi_gofs, + gof_limit=gof_limit, prefix=" ", ) else: @@ -1824,27 +1855,21 @@ def refit_hpi( # 4. Subselect usable coils and determine final dev_head_t if isinstance(use, int) or use is None: - used = np.where(hpi_gofs >= gof_limit)[0] - if len(used) < 3: - gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs) - raise RuntimeError( - f"Only {len(used)} coil{_pl(used)} with goodness of fit >= {gof_limit}" - f", need at least 3 to refit HPI order (got {gofs})." - ) - quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used]) + use_mask = _gof_use_mask(hpi_gofs, gof_limit=gof_limit) + quat, _g = _fit_chpi_quat(hpi_dev[use_mask], hpi_head[fit_order][use_mask]) fit_dev_head_t = _quat_to_affine(quat) hpi_head_got = apply_trans(fit_dev_head_t, hpi_dev) dists = np.linalg.norm(hpi_head_got - hpi_head[fit_order], axis=1) dist_str = " ".join(f"{dist * 1e3:.1f}" for dist in dists) logger.info(f" Coil distances after initial fit: {dist_str} mm") - good_dists_idx = np.where(dists[used] <= dist_limit)[0] + good_dists_idx = np.where(dists[use_mask] <= dist_limit)[0] if not len(good_dists_idx) >= 3: raise RuntimeError( - f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} have distance " + f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} with distance " f"<= {dist_limit * 1e3:.1f} mm, need at least 3 to refit HPI order " f"(got distances: {np.round(1e3 * dists, 1)})." ) - used = used[good_dists_idx] + used = np.where(use_mask)[0][good_dists_idx] if use is not None: used = np.sort(used[np.argsort(hpi_gofs[used])[-use:]]) else: @@ -1927,6 +1952,19 @@ def refit_hpi( return info +def _gof_use_mask(hpi_gofs, *, gof_limit): + assert isinstance(hpi_gofs, np.ndarray) and hpi_gofs.ndim == 1 + use_mask = hpi_gofs >= gof_limit + n_use = use_mask.sum() + if n_use < 3: + gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs) + raise RuntimeError( + f"Only {n_use} coil{_pl(n_use)} with goodness of fit >= {gof_limit}" + f", need at least 3 to refit HPI order (got {gofs})." + ) + return use_mask + + def _sorted_hpi_dig(dig, *, kinds=(FIFF.FIFFV_POINT_HPI,)): return sorted( # need .get here because the hpi_result["dig_points"] does not set it diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 23c1cf9e78b..ca65910dda6 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.169", + testing="0.170", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:bb0524db8605e96fde6333893a969766", + hash="md5:ebd873ea89507cf5a75043f56119d22b", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 0ba13f8c708..ec0d9c3c70f 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -73,6 +73,7 @@ ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds" ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos" chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif" +chpi_bad_gof_fname = data_path / "SSS" / "chpi_bad_gof-info.fif" art_fname = ( data_path @@ -1011,3 +1012,19 @@ def test_refit_hpi_locs_problematic(): ) assert 3 < ang < 6 assert 82 < dist < 87 + + +@testing.requires_testing_data +def test_refit_hpi_locs_bad_gof(): + """Test that we can handle bad GOF HPI fits.""" + # gh-13524 + info = read_info(chpi_bad_gof_fname) + assert_array_equal(info["hpi_results"][-1]["used"], [2, 3, 4]) + info_new = refit_hpi(info.copy(), amplitudes=False, locs=False) + assert_array_equal(info_new["hpi_results"][-1]["used"], [1, 2, 3, 4]) + assert_trans_allclose( + info["dev_head_t"], + info_new["dev_head_t"], + dist_tol=1e-3, + angle_tol=1, + )