153 changes: 147 additions & 6 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -11,12 +11,14 @@

import warnings
import numpy as np
from collections import namedtuple

from .sortinganalyzer import AnalyzerExtension, register_result_extension
from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
from .template import Templates
from .sorting_tools import random_spikes_selection
from .job_tools import fix_job_kwargs, split_job_kwargs


class ComputeRandomSpikes(AnalyzerExtension):
@@ -752,8 +754,6 @@ class ComputeNoiseLevels(AnalyzerExtension):

Parameters
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object
**kwargs : dict
Additional parameters for the `spikeinterface.get_noise_levels()` function

@@ -770,9 +770,6 @@
need_job_kwargs = True
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, **noise_level_params):
params = noise_level_params.copy()
return params
@@ -814,3 +811,147 @@ def _handle_backward_compatibility_on_load(self):

register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()
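
A brief, hedged usage sketch (not part of the diff) of the entry point registered above; it assumes an existing SortingAnalyzer named `analyzer`, and any extra keyword arguments are forwarded to `spikeinterface.get_noise_levels()`.

noise_levels = compute_noise_levels(analyzer)   # function_factory entry point
ext = analyzer.compute("noise_levels")          # equivalent, via the analyzer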


class BaseSpikeVectorExtension(AnalyzerExtension):
"""
Base class for spikevector-based extension, where the data is a numpy array with the same
length as the spike vector.
"""

extension_name = None # to be defined in subclass
need_recording = True
use_nodepipeline = True
need_job_kwargs = True
need_backward_compatibility_on_load = False
nodepipeline_variables = [] # to be defined in subclass

def _set_params(self, **kwargs):
params = kwargs.copy()
return params

def _run(self, verbose=False, **job_kwargs):
from spikeinterface.core.node_pipeline import run_node_pipeline

# TODO: should we save directly to npy in binary_folder format / or to zarr?
# if self.sorting_analyzer.format == "binary_folder":
# gather_mode = "npy"
# extension_folder = self.sorting_analyzer.folder / "extensions" / self.extension_name
# gather_kwargs = {"folder": extension_folder}
gather_mode = "memory"
gather_kwargs = {}

job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
data = run_node_pipeline(
self.sorting_analyzer.recording,
nodes,
job_kwargs=job_kwargs,
job_name=self.extension_name,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
verbose=verbose,
)
if isinstance(data, tuple):
# this logic enables extensions to optionally compute additional data based on params
assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
else:
data = (data,)
if len(self.nodepipeline_variables) > len(data):
data_names = self.nodepipeline_variables[: len(data)]
else:
data_names = self.nodepipeline_variables
for d, name in zip(data, data_names):
self.data[name] = d

def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
"""
Return extension data. If the extension computes more than one `nodepipeline_variables`,
the `return_data_name` is used to specify which one to return.

Parameters
----------
outputs : "numpy" | "by_unit", default: "numpy"
How to return the data: as a flat numpy array or as nested dicts grouped by segment and unit.
concatenated : bool, default: False
Whether to concatenate the data across segments.
return_data_name : str | None, default: None
The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
the first one is returned.
copy : bool, default: True
Whether to return a copy of the data (only for outputs="numpy")

Returns
-------
numpy.ndarray | dict
The extension data: a flat array aligned with the spike vector for outputs="numpy", or a dict indexed by segment then unit (by unit only when concatenated=True) for outputs="by_unit".
"""
from spikeinterface.core.sorting_tools import spike_vector_to_indices

if len(self.nodepipeline_variables) == 1:
return_data_name = self.nodepipeline_variables[0]
else:
if return_data_name is None:
return_data_name = self.nodepipeline_variables[0]
else:
assert (
return_data_name in self.nodepipeline_variables
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
if outputs == "numpy":
if copy:
return all_data.copy() # return a copy to avoid modification
else:
return all_data
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
data_by_units[segment_index] = {}
for unit_id in unit_ids:
inds = spike_indices[segment_index][unit_id]
data_by_units[segment_index][unit_id] = all_data[inds]

if concatenated:
data_by_units_concatenated = {
unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
for unit_id in unit_ids
}
return data_by_units_concatenated

return data_by_units
else:
raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

spikes = self.sorting_analyzer.sorting.to_spike_vector()
keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)

new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
new_data[data_name] = self.data[data_name][keep_spike_mask]

return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
if keep_mask is None:
new_data[data_name] = self.data[data_name].copy()
else:
new_data[data_name] = self.data[data_name][keep_mask]

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# splitting does not change the per-spike data, only the unit assignment, so the data can be copied as-is
return self.data.copy()
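
To make the `_get_data` contract above concrete, here is a minimal, hedged usage sketch (not part of the diff). It assumes a recent spikeinterface where `generate_ground_truth_recording` and `create_sorting_analyzer` are available and uses the `amplitude_scalings` extension refactored below; the exact compute dependencies may differ.

import spikeinterface.full as si

recording, sorting = si.generate_ground_truth_recording(durations=[10.0], seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")
analyzer.compute(["random_spikes", "templates"])       # "amplitude_scalings" depends on templates

ext = analyzer.compute("amplitude_scalings")           # a BaseSpikeVectorExtension subclass

# outputs="numpy": one value per spike, aligned with the spike vector
scalings = ext.get_data(outputs="numpy")
assert scalings.shape[0] == sorting.to_spike_vector().size

# outputs="by_unit": dict of segment_index -> unit_id -> values for that unit's spikes
by_unit = ext.get_data(outputs="by_unit")

# concatenated=True merges segments into a single dict of unit_id -> values
by_unit_concatenated = ext.get_data(outputs="by_unit", concatenated=True)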
9 changes: 4 additions & 5 deletions src/spikeinterface/core/node_pipeline.py
@@ -317,6 +317,7 @@ def __init__(
self.ms_after = ms_after
self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
self.neighbours_mask = None


class ExtractDenseWaveforms(WaveformsNode):
@@ -356,8 +357,6 @@ def __init__(
ms_after=ms_after,
return_output=return_output,
)
# this is a bad hack to differentiate in the child if the parents is dense or not.
self.neighbours_mask = None
Comment on lines -359 to -360
The comment was good to keep, I think.


def get_trace_margin(self):
return max(self.nbefore, self.nafter)
@@ -573,7 +572,7 @@ def run_node_pipeline(
gather_mode : "memory" | "npy"
How to gather the output of the nodes.
gather_kwargs : dict
OPtions to control the "gather engine". See GatherToMemory or GatherToNpy.
Options to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If only one output node then squeeze the tuple
folder : str | Path | None
@@ -784,7 +783,7 @@ def finalize_buffers(self, squeeze_output=False):

class GatherToNpy:
"""
Gather output of nodes into npy file and then open then as memmap.
Gather output of nodes into npy file and then open them as memmap.


The trick is:
Expand Down Expand Up @@ -891,6 +890,6 @@ def finalize_buffers(self, squeeze_output=False):
return np.load(filename, mmap_mode="r")


class GatherToHdf5:
class GatherToZarr:
amazing!

pass
# For me (Sam) this is not necessary unless someone really, really wants to use
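
As an aside on the `GatherToNpy` docstring fix above, a small numpy-only illustration of the trick it describes follows: stream worker chunks into an .npy file, then reopen it lazily as a read-only memmap. This is a hedged simplification with a known final size and a made-up file name, not the spikeinterface implementation (which patches the header after streaming).

import numpy as np

# Pre-allocate the .npy file on disk
buffer = np.lib.format.open_memmap("gathered_example.npy", mode="w+", dtype="float32", shape=(500,))

offset = 0
for _ in range(5):                                   # pretend these are per-chunk worker outputs
    chunk = np.random.randn(100).astype("float32")
    buffer[offset : offset + chunk.size] = chunk
    offset += chunk.size
buffer.flush()

# Reopen lazily, as finalize_buffers() does with np.load(..., mmap_mode="r")
gathered = np.load("gathered_example.npy", mmap_mode="r")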
95 changes: 8 additions & 87 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -3,18 +3,14 @@
import numpy as np

from spikeinterface.core import ChannelSparsity
from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs
from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore
from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension

from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type

from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension

from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type

from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore


class ComputeAmplitudeScalings(AnalyzerExtension):
class ComputeAmplitudeScalings(BaseSpikeVectorExtension):
"""
Computes the amplitude scalings from a SortingAnalyzer.

@@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension):
multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently.
delta_collision_ms: float, default: 2
The maximum time difference in ms before and after a spike to gather colliding spikes.
load_if_exists : bool, default: False
Whether to load precomputed spike amplitudes, if they already exist.
outputs: "concatenated" | "by_unit", default: "concatenated"
How the output should be returned
{}

Returns
-------
amplitude_scalings: np.array or list of dict
The amplitude scalings.
- If "concatenated" all amplitudes for all spikes and all units are concatenated
- If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units)
"""

extension_name = "amplitude_scalings"
depend_on = ["templates"]
need_recording = True
use_nodepipeline = True
nodepipeline_variables = ["amplitude_scalings", "collision_mask"]
need_job_kwargs = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

self.collisions = None

def _set_params(
self,
@@ -90,46 +66,14 @@ def _set_params(
handle_collisions=True,
delta_collision_ms=2,
):
params = dict(
return super()._set_params(
sparsity=sparsity,
max_dense_channels=max_dense_channels,
ms_before=ms_before,
ms_after=ms_after,
handle_collisions=handle_collisions,
delta_collision_ms=delta_collision_ms,
)
return params

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

spikes = self.sorting_analyzer.sorting.to_spike_vector()
keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)

new_data = dict()
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask]
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask]
return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()

if keep_mask is None:
new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy()
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"].copy()
else:
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_mask]

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
return self.data.copy()

def _get_pipeline_nodes(self):

@@ -141,6 +85,7 @@ def _get_pipeline_nodes(self):
all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV)
nbefore = _get_nbefore(self.sorting_analyzer)
nafter = all_templates.shape[1] - nbefore
templates_ext = self.sorting_analyzer.get_extension("templates")

# if ms_before / ms_after are set in params then the original templates are shorten
if self.params["ms_before"] is not None:
@@ -155,7 +100,7 @@
cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
assert (
cut_out_after <= nafter
), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}"
), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}"
else:
cut_out_after = nafter

@@ -210,30 +155,6 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, amplitude_scalings_node]
return nodes

def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amp_scalings, collision_mask = run_node_pipeline(
self.sorting_analyzer.recording,
nodes,
job_kwargs=job_kwargs,
job_name="amplitude_scalings",
gather_mode="memory",
verbose=verbose,
)
self.data["amplitude_scalings"] = amp_scalings
if self.params["handle_collisions"]:
self.data["collision_mask"] = collision_mask
# TODO: make collisions "global"
# for collision in collisions:
# collisions_dict.update(collision)
# self.collisions = collisions_dict
# # Note: collisions are note in _extension_data because they are not pickable. We only store the indices
# self._extension_data["collisions"] = np.array(list(collisions_dict.keys()))

def _get_data(self):
return self.data[f"amplitude_scalings"]


register_result_extension(ComputeAmplitudeScalings)
compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory()
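
Because `ComputeAmplitudeScalings` now inherits from `BaseSpikeVectorExtension` and declares two `nodepipeline_variables`, the secondary collision output is retrieved through `return_data_name`. A hedged sketch, assuming `analyzer` is a SortingAnalyzer with templates already computed:

ext = analyzer.compute("amplitude_scalings", handle_collisions=True)

scalings = ext.get_data()                                          # first variable: "amplitude_scalings"
collision_mask = ext.get_data(return_data_name="collision_mask")   # per-spike collision flags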