diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index cc660fe4986..b5460195f46 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -553,6 +553,8 @@ def plot_alignment( interaction="terrain", sensor_colors=None, *, + hpi_colors="auto", + hpi_labels=False, sensor_scales=None, verbose=None, ): @@ -644,7 +646,18 @@ def plot_alignment( %(sensor_colors)s .. versionchanged:: 1.6 - Support for passing a ``dict`` was added. + Support for passing a ``dict`` was added. + hpi_colors : 'auto' | list | dict + Colors for HPI coils when ``dig=True``. + ``'auto'`` (default): use official MEGIN/Elekta cable colors + (1=red, 2=blue, 3=green, 4=yellow, 5=magenta, 6=cyan). + Can also be a list of colors or ``{ident: color}`` dict. + + .. versionadded:: 1.11 + hpi_labels : bool + If ``True``, show the HPI coil ident number as 3D text above each coil. + + .. versionadded:: 1.11 %(sensor_scales)s .. versionadded:: 1.9 @@ -900,7 +913,9 @@ def plot_alignment( _check_option("dig", dig, (True, False, "fiducials")) if dig: if dig is True: - _plot_hpi_coils(renderer, info, to_cf_t) + _plot_hpi_coils( + renderer, info, to_cf_t, hpi_colors=hpi_colors, hpi_labels=hpi_labels + ) _plot_head_shape_points(renderer, info, to_cf_t) _plot_head_fiducials(renderer, info, to_cf_t, fid_colors) @@ -1292,34 +1307,90 @@ def _plot_hpi_coils( surf=None, check_inside=None, nearest=None, + hpi_colors="auto", + hpi_labels=False, ): + from matplotlib.colors import to_rgba + defaults = DEFAULTS["coreg"] scale = defaults["hpi_scale"] if scale is None else scale - hpi_loc = np.array( - [ - d["r"] - for d in (info["dig"] or []) - if ( - d["kind"] == FIFF.FIFFV_POINT_HPI - and d["coord_frame"] == FIFF.FIFFV_COORD_HEAD - ) + + hpi_digs = [ + d + for d in (info["dig"] or []) + if ( + d["kind"] == FIFF.FIFFV_POINT_HPI + and d["coord_frame"] == FIFF.FIFFV_COORD_HEAD + ) + ] + if not hpi_digs: + return [] + + hpi_idents = [d["ident"] for d in hpi_digs] + hpi_locs = apply_trans(to_cf_t["head"], [d["r"] for d in hpi_digs]) + + if hpi_colors == "auto": + # MEGIN/Elekta HPI coil cable colors(MNE community convention from user reports) + # 1 = red, 2 = blue, 3 = green, 4 = yellow, 5 = magenta, 6 = cyan + # Coil 1 is confirmed as "red" in Elekta TRIUX manual + # Full mapping is standard practice in MNE; no official 6-color list. + megin_colors = { + 1: "red", + 2: "blue", + 3: "green", + 4: "yellow", + 5: "magenta", + 6: "cyan", + } + colors = [ + megin_colors.get(ident, defaults["hpi_color"]) for ident in hpi_idents ] - ) - hpi_loc = apply_trans(to_cf_t["head"], hpi_loc) - actor, _ = _plot_glyphs( - renderer=renderer, - loc=hpi_loc, - color=defaults["hpi_color"], - scale=scale, - opacity=opacity, - orient_glyphs=orient_glyphs, - scale_by_distance=scale_by_distance, - surf=surf, - backface_culling=True, - check_inside=check_inside, - nearest=nearest, - ) - return actor + elif isinstance(hpi_colors, dict): + colors = [hpi_colors.get(ident, defaults["hpi_color"]) for ident in hpi_idents] + elif isinstance(hpi_colors, (list, tuple)): + if len(hpi_colors) != len(hpi_digs): + raise ValueError( + f"""hpi_colors list length + {len(hpi_colors)} != number of HPI coils {len(hpi_digs)} + """ + ) + colors = hpi_colors + else: + colors = [hpi_colors] * len(hpi_digs) + + actors = [] + + for loc, color, ident in zip(hpi_locs, colors, hpi_idents): + color_rgba = to_rgba(color) + + result = _plot_glyphs( + renderer=renderer, + loc=np.array([loc]), + color=color_rgba, + scale=scale, + opacity=opacity, + orient_glyphs=orient_glyphs, + scale_by_distance=scale_by_distance, + surf=surf, + backface_culling=True, + check_inside=check_inside, + nearest=nearest, + ) + + actors.append(result) + + if hpi_labels: + offset = np.array([0, 0, scale * 1.3]) + renderer.text3d( + x=loc[0], + y=loc[1], + z=loc[2] + offset[2], + text=str(ident), + scale=scale * 0.7, + color=color_rgba, + ) + + return actors def _get_nearest(nearest, check_inside, project_to_trans, proj_rr): diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index ab24e6a70db..d6c87a35de1 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -899,6 +899,45 @@ def test_plot_alignment_basic(tmp_path, renderer, mixed_fwd_cov_evoked): ) +@testing.requires_testing_data +def test_plot_alignment_hpi_colors_and_labels(renderer): + """Test hpi_colors and hpi_labels parameters.""" + info = read_info(evoked_fname) + fig = plot_alignment( + info=info, + dig=True, + surfaces=[], + coord_frame="head", + meg=[], + eeg=[], + ecog=False, + seeg=False, + fnirs=False, + dbs=False, + show_axes=False, + hpi_colors="auto", + hpi_labels=False, + ) + _assert_n_actors(fig, renderer, 7) + + fig = plot_alignment( + info=info, + dig=True, + surfaces=[], + coord_frame="head", + meg=[], + eeg=[], + ecog=False, + seeg=False, + fnirs=False, + dbs=False, + show_axes=False, + hpi_colors="auto", + hpi_labels=True, + ) + _assert_n_actors(fig, renderer, 11) + + @testing.requires_testing_data def test_plot_alignment_fnirs(renderer, tmp_path): """Test fNIRS plotting."""