Skip to content

Commit 3abb033

Browse files
authored
Merge pull request #4194 from samuelgarcia/estimate_template_sparse
Add sparsity_mask option in estimate_templates_with_accumulator()
2 parents f246b6a + d735936 commit 3abb033

File tree

2 files changed

+66
-15
lines changed

2 files changed

+66
-15
lines changed

src/spikeinterface/core/tests/test_waveform_tools.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -227,16 +227,36 @@ def test_estimate_templates():
227227

228228
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
229229

230+
# mask with different sparsity per unit
231+
sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
232+
sparsity_mask[:4, : recording.channel_ids.size // 2 - 1] = False
233+
sparsity_mask[4:, recording.channel_ids.size // 2 :] = False
234+
230235
for operator in ("average", "median"):
231-
templates = estimate_templates(
236+
templates_array = estimate_templates(
232237
recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator, return_in_uV=True, **job_kwargs
233238
)
234239
# print(templates.shape)
235-
assert templates.shape[0] == sorting.unit_ids.size
236-
assert templates.shape[1] == nbefore + nafter
237-
assert templates.shape[2] == recording.get_num_channels()
238-
239-
assert np.any(templates != 0)
240+
assert templates_array.shape[0] == sorting.unit_ids.size
241+
assert templates_array.shape[1] == nbefore + nafter
242+
assert templates_array.shape[2] == recording.get_num_channels()
243+
244+
assert np.any(templates_array != 0)
245+
246+
sparse_templates_array = estimate_templates(
247+
recording,
248+
spikes,
249+
sorting.unit_ids,
250+
nbefore,
251+
nafter,
252+
operator=operator,
253+
return_in_uV=True,
254+
sparsity_mask=sparsity_mask,
255+
**job_kwargs,
256+
)
257+
n_chan = np.max(np.sum(sparsity_mask, axis=1))
258+
assert n_chan == sparse_templates_array.shape[2]
259+
assert np.any(sparse_templates_array == 0)
240260

241261
# import matplotlib.pyplot as plt
242262
# fig, ax = plt.subplots()
@@ -247,7 +267,7 @@ def test_estimate_templates():
247267

248268

249269
if __name__ == "__main__":
250-
cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
251-
test_waveform_tools(cache_folder)
252-
test_estimate_templates_with_accumulator()
270+
# cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
271+
# test_waveform_tools(cache_folder)
272+
# test_estimate_templates_with_accumulator()
253273
test_estimate_templates()

src/spikeinterface/core/waveform_tools.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -744,12 +744,13 @@ def estimate_templates(
744744
operator: str = "average",
745745
return_scaled=None,
746746
return_in_uV=True,
747+
sparsity_mask=None,
747748
job_name=None,
748749
**job_kwargs,
749750
):
750751
"""
751752
Estimate dense templates with "average" or "median".
752-
If "average" internally estimate_templates_with_accumulator() is used to saved memory/
753+
If "average", internally estimate_templates_with_accumulator() is used to save memory.
753754
754755
Parameters
755756
----------
@@ -770,6 +771,8 @@ def estimate_templates(
770771
return_in_uV : bool, default: True
771772
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
772773
traces are scaled to uV
774+
sparsity_mask: None or array of bool, default: None
775+
If not None, shape must be (len(unit_ids), len(channel_ids))
773776
774777
Returns
775778
-------
@@ -791,7 +794,15 @@ def estimate_templates(
791794

792795
if operator == "average":
793796
templates_array = estimate_templates_with_accumulator(
794-
recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, job_name=job_name, **job_kwargs
797+
recording,
798+
spikes,
799+
unit_ids,
800+
nbefore,
801+
nafter,
802+
return_in_uV=return_in_uV,
803+
sparsity_mask=sparsity_mask,
804+
job_name=job_name,
805+
**job_kwargs,
795806
)
796807
elif operator == "median":
797808
all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
@@ -802,6 +813,7 @@ def estimate_templates(
802813
nafter,
803814
mode="shared_memory",
804815
return_in_uV=return_in_uV,
816+
sparsity_mask=sparsity_mask,
805817
copy=False,
806818
**job_kwargs,
807819
)
@@ -828,6 +840,7 @@ def estimate_templates_with_accumulator(
828840
nafter: int,
829841
return_scaled=None,
830842
return_in_uV=True,
843+
sparsity_mask=None,
831844
job_name=None,
832845
return_std: bool = False,
833846
verbose: bool = False,
@@ -859,6 +872,8 @@ def estimate_templates_with_accumulator(
859872
return_in_uV : bool, default: True
860873
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
861874
traces are scaled to uV
875+
sparsity_mask: None or array of bool, default: None
876+
If not None, shape must be (len(unit_ids), len(channel_ids))
862877
return_std: bool, default: False
863878
If True, the standard deviation is also computed.
864879
@@ -882,10 +897,14 @@ def estimate_templates_with_accumulator(
882897
job_kwargs = fix_job_kwargs(job_kwargs)
883898
num_worker = job_kwargs["n_jobs"]
884899

885-
num_chans = recording.get_num_channels()
900+
if sparsity_mask is None:
901+
num_chans = int(recording.get_num_channels())
902+
else:
903+
num_chans = int(max(np.sum(sparsity_mask, axis=1))) # This is a numpy scalar, so we cast to int
886904
num_units = len(unit_ids)
887905

888906
shape = (num_worker, num_units, nbefore + nafter, num_chans)
907+
889908
dtype = np.dtype("float32")
890909
waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype)
891910
shm_name = shm.name
@@ -909,6 +928,7 @@ def estimate_templates_with_accumulator(
909928
nbefore,
910929
nafter,
911930
return_in_uV,
931+
sparsity_mask,
912932
)
913933

914934
if job_name is None:
@@ -965,13 +985,15 @@ def _init_worker_estimate_templates(
965985
nbefore,
966986
nafter,
967987
return_in_uV,
988+
sparsity_mask,
968989
):
969990
worker_dict = {}
970991
worker_dict["recording"] = recording
971992
worker_dict["spikes"] = spikes
972993
worker_dict["nbefore"] = nbefore
973994
worker_dict["nafter"] = nafter
974995
worker_dict["return_in_uV"] = return_in_uV
996+
worker_dict["sparsity_mask"] = sparsity_mask
975997

976998
from multiprocessing.shared_memory import SharedMemory
977999
import multiprocessing
@@ -1009,6 +1031,7 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10091031
waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None)
10101032
worker_index = worker_dict["worker_index"]
10111033
return_in_uV = worker_dict["return_in_uV"]
1034+
sparsity_mask = worker_dict["sparsity_mask"]
10121035

10131036
seg_size = recording.get_num_samples(segment_index=segment_index)
10141037

@@ -1040,6 +1063,14 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10401063
unit_index = spikes[spike_index]["unit_index"]
10411064
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]
10421065

1043-
waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
1044-
if waveform_squared_accumulator_per_worker is not None:
1045-
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
1066+
if sparsity_mask is None:
1067+
waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
1068+
if waveform_squared_accumulator_per_worker is not None:
1069+
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
1070+
1071+
else:
1072+
mask = sparsity_mask[unit_index, :]
1073+
wf = wf[:, mask]
1074+
waveform_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf
1075+
if waveform_squared_accumulator_per_worker is not None:
1076+
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf**2

0 commit comments

Comments
 (0)