Skip to content

Commit d735936

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b083bb8 commit d735936

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

src/spikeinterface/core/tests/test_waveform_tools.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,8 @@ def test_estimate_templates():
229229

230230
# mask with differents sparsity
231231
sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
232-
sparsity_mask[:4, :recording.channel_ids.size//2 -1 ] = False
233-
sparsity_mask[4:, recording.channel_ids.size//2:] = False
234-
232+
sparsity_mask[:4, : recording.channel_ids.size // 2 - 1] = False
233+
sparsity_mask[4:, recording.channel_ids.size // 2 :] = False
235234

236235
for operator in ("average", "median"):
237236
templates_array = estimate_templates(
@@ -245,8 +244,15 @@ def test_estimate_templates():
245244
assert np.any(templates_array != 0)
246245

247246
sparse_templates_array = estimate_templates(
248-
recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator,
249-
return_in_uV=True, sparsity_mask=sparsity_mask, **job_kwargs
247+
recording,
248+
spikes,
249+
sorting.unit_ids,
250+
nbefore,
251+
nafter,
252+
operator=operator,
253+
return_in_uV=True,
254+
sparsity_mask=sparsity_mask,
255+
**job_kwargs,
250256
)
251257
n_chan = np.max(np.sum(sparsity_mask, axis=1))
252258
assert n_chan == sparse_templates_array.shape[2]

src/spikeinterface/core/waveform_tools.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,15 @@ def estimate_templates(
794794

795795
if operator == "average":
796796
templates_array = estimate_templates_with_accumulator(
797-
recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, sparsity_mask=sparsity_mask, job_name=job_name, **job_kwargs
797+
recording,
798+
spikes,
799+
unit_ids,
800+
nbefore,
801+
nafter,
802+
return_in_uV=return_in_uV,
803+
sparsity_mask=sparsity_mask,
804+
job_name=job_name,
805+
**job_kwargs,
798806
)
799807
elif operator == "median":
800808
all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
@@ -894,7 +902,7 @@ def estimate_templates_with_accumulator(
894902
else:
895903
num_chans = int(max(np.sum(sparsity_mask, axis=1))) # This is a numpy scalar, so we cast to int
896904
num_units = len(unit_ids)
897-
905+
898906
shape = (num_worker, num_units, nbefore + nafter, num_chans)
899907

900908
dtype = np.dtype("float32")
@@ -1060,11 +1068,9 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10601068
if waveform_squared_accumulator_per_worker is not None:
10611069
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
10621070

1063-
10641071
else:
10651072
mask = sparsity_mask[unit_index, :]
10661073
wf = wf[:, mask]
10671074
waveform_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf
10681075
if waveform_squared_accumulator_per_worker is not None:
1069-
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :wf.shape[1]] += wf**2
1070-
1076+
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf**2

0 commit comments

Comments
 (0)