diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index ec8b3ca107..2d5b17f1d3 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -11,12 +11,14 @@
 import warnings
 import numpy as np
+from collections import namedtuple
-from .sortinganalyzer import AnalyzerExtension, register_result_extension
+from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
 from .template import Templates
 from .sorting_tools import random_spikes_selection
+from .job_tools import fix_job_kwargs, split_job_kwargs
 class ComputeRandomSpikes(AnalyzerExtension):
@@ -752,8 +754,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
     Parameters
     ----------
-    sorting_analyzer : SortingAnalyzer
-        A SortingAnalyzer object
     **kwargs : dict
         Additional parameters for the `spikeinterface.get_noise_levels()` function
@@ -770,9 +770,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
     need_job_kwargs = True
     need_backward_compatibility_on_load = True
-    def __init__(self, sorting_analyzer):
-        AnalyzerExtension.__init__(self, sorting_analyzer)
-
     def _set_params(self, **noise_level_params):
         params = noise_level_params.copy()
         return params
@@ -814,3 +811,147 @@ def _handle_backward_compatibility_on_load(self):
 register_result_extension(ComputeNoiseLevels)
 compute_noise_levels = ComputeNoiseLevels.function_factory()
+
+
+class BaseSpikeVectorExtension(AnalyzerExtension):
+    """
+    Base class for spike-vector-based extensions, where the data is a numpy array with the same
+    length as the spike vector.
+    """
+
+    extension_name = None  # to be defined in subclass
+    need_recording = True
+    use_nodepipeline = True
+    need_job_kwargs = True
+    need_backward_compatibility_on_load = False
+    nodepipeline_variables = []  # to be defined in subclass
+
+    def _set_params(self, **kwargs):
+        params = kwargs.copy()
+        return params
+
+    def _run(self, verbose=False, **job_kwargs):
+        from spikeinterface.core.node_pipeline import run_node_pipeline
+
+        # TODO: should we save directly to npy in binary_folder format / or to zarr?
+        # if self.sorting_analyzer.format == "binary_folder":
+        #     gather_mode = "npy"
+        #     extension_folder = self.sorting_analyzer.folder / "extensions" / self.extension_name
+        #     gather_kwargs = {"folder": extension_folder}
+        gather_mode = "memory"
+        gather_kwargs = {}
+
+        job_kwargs = fix_job_kwargs(job_kwargs)
+        nodes = self.get_pipeline_nodes()
+        data = run_node_pipeline(
+            self.sorting_analyzer.recording,
+            nodes,
+            job_kwargs=job_kwargs,
+            job_name=self.extension_name,
+            gather_mode=gather_mode,
+            gather_kwargs=gather_kwargs,
+            verbose=verbose,
+        )
+        if isinstance(data, tuple):
+            # this logic enables extensions to optionally compute additional data based on params
+            assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
+        else:
+            data = (data,)
+        if len(self.nodepipeline_variables) > len(data):
+            data_names = self.nodepipeline_variables[: len(data)]
+        else:
+            data_names = self.nodepipeline_variables
+        for d, name in zip(data, data_names):
+            self.data[name] = d
+
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+        """
+        Return extension data.
+        If the extension computes more than one of its `nodepipeline_variables`,
+        `return_data_name` specifies which one to return.
+
+        Parameters
+        ----------
+        outputs : "numpy" | "by_unit", default: "numpy"
+            How to return the data: as a flat array ("numpy") or split by segment and unit ("by_unit").
+        concatenated : bool, default: False
+            Whether to concatenate the data across segments (only for outputs="by_unit").
+        return_data_name : str | None, default: None
+            The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
+            the first one is returned.
+        copy : bool, default: True
+            Whether to return a copy of the data (only for outputs="numpy")
+
+        Returns
+        -------
+        numpy.ndarray | dict
+            The extension data: a flat array with one value per spike in the spike vector if
+            outputs="numpy", or a dict {segment_index: {unit_id: array}} if outputs="by_unit"
+            ({unit_id: array} if concatenated=True).
+        """
+        from spikeinterface.core.sorting_tools import spike_vector_to_indices
+
+        if len(self.nodepipeline_variables) == 1:
+            return_data_name = self.nodepipeline_variables[0]
+        else:
+            if return_data_name is None:
+                return_data_name = self.nodepipeline_variables[0]
+            else:
+                assert (
+                    return_data_name in self.nodepipeline_variables
+                ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"
+
+        all_data = self.data[return_data_name]
+        if outputs == "numpy":
+            if copy:
+                return all_data.copy()  # return a copy to avoid modification
+            else:
+                return all_data
+        elif outputs == "by_unit":
+            unit_ids = self.sorting_analyzer.unit_ids
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
+            data_by_units = {}
+            for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
+                data_by_units[segment_index] = {}
+                for unit_id in unit_ids:
+                    inds = spike_indices[segment_index][unit_id]
+                    data_by_units[segment_index][unit_id] = all_data[inds]
+
+            if concatenated:
+                data_by_units_concatenated = {
+                    unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
+                    for unit_id in unit_ids
+                }
+                return data_by_units_concatenated
+
+            return data_by_units
+        else:
+            raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")
+
+    def _select_extension_data(self, unit_ids):
+        keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
+
+        spikes = self.sorting_analyzer.sorting.to_spike_vector()
+        keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)
+
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                new_data[data_name] = self.data[data_name][keep_spike_mask]
+
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
+    ):
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                if keep_mask is None:
+                    new_data[data_name] = self.data[data_name].copy()
+                else:
+                    new_data[data_name] = self.data[data_name][keep_mask]
+
+        return new_data
+
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # splitting units does not change the spike vector, so per-spike data can be returned as is
+        return self.data.copy()
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 1cec886d95..71654a67b4 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -317,6 +317,7 @@ def __init__(
         self.ms_after = ms_after
         self.nbefore = int(ms_before * recording.get_sampling_frequency() /
1000.0) self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0) + self.neighbours_mask = None class ExtractDenseWaveforms(WaveformsNode): @@ -356,8 +357,6 @@ def __init__( ms_after=ms_after, return_output=return_output, ) - # this is a bad hack to differentiate in the child if the parents is dense or not. - self.neighbours_mask = None def get_trace_margin(self): return max(self.nbefore, self.nafter) @@ -573,7 +572,7 @@ def run_node_pipeline( gather_mode : "memory" | "npy" How to gather the output of the nodes. gather_kwargs : dict - OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. + Options to control the "gather engine". See GatherToMemory or GatherToNpy. squeeze_output : bool, default True If only one output node then squeeze the tuple folder : str | Path | None @@ -784,7 +783,7 @@ def finalize_buffers(self, squeeze_output=False): class GatherToNpy: """ - Gather output of nodes into npy file and then open then as memmap. + Gather output of nodes into npy file and then open them as memmap. The trick is: @@ -891,6 +890,6 @@ def finalize_buffers(self, squeeze_output=False): return np.load(filename, mmap_mode="r") -class GatherToHdf5: +class GatherToZarr: pass # Fot me (sam) this is not necessary unless someone realy really want to use diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index ce8194f530..8f3ffe0617 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,18 +3,14 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs +from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type - -from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore - - -class ComputeAmplitudeScalings(AnalyzerExtension): +class ComputeAmplitudeScalings(BaseSpikeVectorExtension): """ Computes the amplitude scalings from a SortingAnalyzer. @@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension): multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. delta_collision_ms: float, default: 2 The maximum time difference in ms before and after a spike to gather colliding spikes. - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} - - Returns - ------- - amplitude_scalings: np.array or list of dict - The amplitude scalings. 
- - If "concatenated" all amplitudes for all spikes and all units are concatenated - - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ extension_name = "amplitude_scalings" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitude_scalings", "collision_mask"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self.collisions = None def _set_params( self, @@ -90,7 +66,7 @@ def _set_params( handle_collisions=True, delta_collision_ms=2, ): - params = dict( + return super()._set_params( sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, @@ -98,38 +74,6 @@ def _set_params( handle_collisions=handle_collisions, delta_collision_ms=delta_collision_ms, ) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy() - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"].copy() - else: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_mask] - - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - return self.data.copy() def _get_pipeline_nodes(self): @@ -141,6 +85,7 @@ def _get_pipeline_nodes(self): all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV) nbefore = _get_nbefore(self.sorting_analyzer) nafter = all_templates.shape[1] - nbefore + templates_ext = self.sorting_analyzer.get_extension("templates") # if ms_before / ms_after are set in params then the original templates are shorten if self.params["ms_before"] is not None: @@ -155,7 +100,7 @@ def _get_pipeline_nodes(self): cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) assert ( cut_out_after <= nafter - ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}" + ), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}" else: cut_out_after = nafter @@ -210,30 +155,6 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, amplitude_scalings_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amp_scalings, collision_mask = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="amplitude_scalings", - gather_mode="memory", - verbose=verbose, - ) - self.data["amplitude_scalings"] = amp_scalings - if self.params["handle_collisions"]: - self.data["collision_mask"] = collision_mask - 
# TODO: make collisions "global" - # for collision in collisions: - # collisions_dict.update(collision) - # self.collisions = collisions_dict - # # Note: collisions are note in _extension_data because they are not pickable. We only store the indices - # self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - - def _get_data(self): - return self.data[f"amplitude_scalings"] - register_result_extension(ComputeAmplitudeScalings) compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory() diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index fee5cb4c6f..ce3d1cd4a9 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -36,8 +36,6 @@ class ComputeCorrelograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer_or_sorting : SortingAnalyzer | Sorting - A SortingAnalyzer or Sorting object window_ms : float, default: 50.0 The window around the spike to compute the correlation in ms. For example, if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. @@ -90,9 +88,6 @@ class ComputeCorrelograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -669,9 +664,6 @@ class ComputeACG3D(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params( self, window_ms: float = 50.0, diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index bb571b9326..a4111472c2 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -19,8 +19,6 @@ class ComputeISIHistograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object window_ms : float, default: 50 The window in ms bin_ms : float, default: 1 @@ -42,9 +40,6 @@ class ComputeISIHistograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index f5c1a74848..8c79e17e42 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -27,8 +27,6 @@ class ComputePrincipalComponents(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object n_components : int, default: 5 Number of components fo PCA mode : "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" @@ -71,9 +69,6 @@ class ComputePrincipalComponents(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params( self, n_components=5, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py 
b/src/spikeinterface/postprocessing/spike_amplitudes.py index 959103d922..993d1a105d 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -2,18 +2,14 @@ import numpy as np -from spikeinterface.core.job_tools import fix_job_kwargs - +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift - -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type -from spikeinterface.core.sorting_tools import spike_vector_to_indices +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -class ComputeSpikeAmplitudes(AnalyzerExtension): +class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): """ - AnalyzerExtension Computes the spike amplitudes. Needs "templates" to be computed first. @@ -21,63 +17,18 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute extremum channel used to retrieve spike amplitudes. - - Returns - ------- - spike_amplitudes: np.array - All amplitudes for all spikes and all units are concatenated (along time, like in spike vector) - """ extension_name = "spike_amplitudes" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitudes"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self._all_spikes = None def _set_params(self, peak_sign="neg"): - params = dict(peak_sign=peak_sign) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitudes"] = self.data["amplitudes"][keep_spike_mask] - - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitudes"] = self.data["amplitudes"].copy() - else: - new_data["amplitudes"] = self.data["amplitudes"][keep_mask] - - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() + return super()._set_params(peak_sign=peak_sign) def _get_pipeline_nodes(self): - recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -102,50 +53,8 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, spike_amplitudes_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amps = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_amplitudes", - gather_mode="memory", - verbose=False, - ) - self.data["amplitudes"] = amps - - def _get_data(self, 
outputs="numpy", concatenated=False): - all_amplitudes = self.data["amplitudes"] - if outputs == "numpy": - return all_amplitudes - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - amplitudes_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - amplitudes_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] - - if concatenated: - amplitudes_by_units_concatenated = { - unit_id: np.concatenate( - [amps_in_segment[unit_id] for amps_in_segment in amplitudes_by_units.values()] - ) - for unit_id in unit_ids - } - return amplitudes_by_units_concatenated - - return amplitudes_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") - register_result_extension(ComputeSpikeAmplitudes) - compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d7c7045f5a..a43f2bb93e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -2,21 +2,20 @@ import numpy as np -from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.core.job_tools import _shared_job_kwargs_doc +from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline +from spikeinterface.core.node_pipeline import SpikeRetriever +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -class ComputeSpikeLocations(AnalyzerExtension): +class ComputeSpikeLocations(BaseSpikeVectorExtension): """ Localize spikes in 2D or 3D with several methods given the template. Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 @@ -37,9 +36,6 @@ class ComputeSpikeLocations(AnalyzerExtension): The localization method to use method_kwargs : dict, default: dict() Other kwargs depending on the method. 
- outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - {} Returns ------- @@ -49,13 +45,7 @@ class ComputeSpikeLocations(AnalyzerExtension): extension_name = "spike_locations" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["spike_locations"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params( self, @@ -72,40 +62,13 @@ def _set_params( ) if spike_retriver_kwargs is not None: spike_retriver_kwargs_.update(spike_retriver_kwargs) - params = dict( + return super()._set_params( ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs_, method=method, method_kwargs=method_kwargs, ) - return params - - def _select_extension_data(self, unit_ids): - old_unit_ids = self.sorting_analyzer.unit_ids - unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spikes = self.sorting_analyzer.sorting.to_spike_vector() - - spike_mask = np.isin(spikes["unit_index"], unit_inds) - new_spike_locations = self.data["spike_locations"][spike_mask] - return dict(spike_locations=new_spike_locations) - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - - if keep_mask is None: - new_spike_locations = self.data["spike_locations"].copy() - else: - new_spike_locations = self.data["spike_locations"][keep_mask] - - ### In theory here, we should recompute the locations since the peak positions - ### in a merged could be different. Should be discussed - return dict(spike_locations=new_spike_locations) - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes @@ -133,49 +96,6 @@ def _get_pipeline_nodes(self): ) return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - spike_locations = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_locations", - gather_mode="memory", - verbose=verbose, - ) - self.data["spike_locations"] = spike_locations - - def _get_data(self, outputs="numpy", concatenated=False): - all_spike_locations = self.data["spike_locations"] - if outputs == "numpy": - return all_spike_locations - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - spike_locations_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - spike_locations_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - spike_locations_by_units[segment_index][unit_id] = all_spike_locations[inds] - - if concatenated: - locations_by_units_concatenated = { - unit_id: np.concatenate( - [locs_in_segment[unit_id] for locs_in_segment in spike_locations_by_units.values()] - ) - for unit_id in unit_ids - } - return locations_by_units_concatenated - - return spike_locations_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs})") - - 
-ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) register_result_extension(ComputeSpikeLocations) compute_spike_locations = ComputeSpikeLocations.function_factory() diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index b6f054552d..328de2afce 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -51,9 +51,6 @@ class ComputeTemplateSimilarity(AnalyzerExtension): need_job_kwargs = False need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _handle_backward_compatibility_on_load(self): if "max_lag_ms" not in self.params: # make compatible analyzer created between february 24 and july 24 diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index ea297f7b6c..930c9e5438 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -24,8 +24,6 @@ class ComputeUnitLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} @@ -44,9 +42,6 @@ class ComputeUnitLocations(AnalyzerExtension): need_job_kwargs = False need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _handle_backward_compatibility_on_load(self): if "method_kwargs" in self.params: # make compatible analyzer created between february 24 and july 24 diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 5d338a990b..518ee4ed10 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -31,8 +31,6 @@ class ComputeQualityMetrics(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. metric_params : dict of dicts or None
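
A minimal usage sketch of the shared `get_data()` behavior that `BaseSpikeVectorExtension` gives the refactored extensions (`spike_amplitudes`, `amplitude_scalings`, `spike_locations`). It assumes the standard `SortingAnalyzer` workflow; `generate_ground_truth_recording` and `create_sorting_analyzer` are existing `spikeinterface.core` helpers and are not part of this diff:

```python
# Hypothetical usage sketch; the helpers below exist in spikeinterface.core and are not added here.
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

# toy recording/sorting pair, only to make the sketch self-contained
recording, sorting = generate_ground_truth_recording(durations=[10.0], seed=0)
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

# spike-vector extensions depend on "templates"
analyzer.compute(["random_spikes", "waveforms", "templates"])
analyzer.compute("spike_amplitudes")
analyzer.compute("amplitude_scalings", handle_collisions=True)

# flat numpy array, one value per entry of the spike vector
amps = analyzer.get_extension("spike_amplitudes").get_data(outputs="numpy")

# dict of {segment_index: {unit_id: array}}
amps_by_unit = analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit")

# extensions with several nodepipeline_variables: select the variable explicitly
scalings_ext = analyzer.get_extension("amplitude_scalings")
collision_mask = scalings_ext.get_data(return_data_name="collision_mask")
```

With the shared base class, the `numpy` / `by_unit` / `concatenated` output handling lives in one place instead of being re-implemented by each extension.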