Skip to content

Commit f58b48c

Browse files
authored
Deal with batch_downsampling in kilosort4.1.2 (#4206)
1 parent c701463 commit f58b48c

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

.github/scripts/test_kilosort4_ci.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@
116116
PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20})
117117
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset")
118118

119+
if parse(kilosort.__version__) >= parse("4.1.2"):
120+
PARAMS_TO_TEST_DICT.update({"batch_downsampling": 2})
121+
PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling")
122+
123+
PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2})
124+
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_init_seed")
125+
119126

120127
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
121128

@@ -328,6 +335,8 @@ def test_binary_filtered_arguments(self):
328335
"scale",
329336
"file_object",
330337
]
338+
if parse(kilosort.__version__) >= parse("4.1.2"):
339+
expected_arguments += ["batch_downsampling"]
331340

332341
self._check_arguments(BinaryFiltered, expected_arguments)
333342

@@ -351,6 +360,12 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp
351360
"""
352361
recording, paths = recording_and_paths
353362
param_key = parameter
363+
364+
# Non-default batch_downsampling fails for short recordings, as there aren't
365+
# enough batches. Since we test on a 5s recording, we skip it.
366+
if param_key == "batch_downsampling":
367+
return
368+
354369
param_value = PARAMS_TO_TEST_DICT[param_key]
355370

356371
# Setup parameters for KS4 and run it natively

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,29 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
307307
if version.parse(ks_version) >= version.parse("4.0.34"):
308308
ops = ops[0]
309309

310-
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
311-
get_run_parameters(ops)
312-
)
310+
(
311+
n_chan_bin,
312+
fs,
313+
NT,
314+
nt,
315+
twav_min,
316+
chan_map,
317+
dtype,
318+
do_CAR,
319+
invert,
320+
_,
321+
_,
322+
tmin,
323+
tmax,
324+
artifact,
325+
_,
326+
_,
327+
*possibly_batch_downsampling,
328+
) = get_run_parameters(ops)
329+
330+
batch_downsample_dict = {}
331+
if len(possibly_batch_downsampling) > 0:
332+
batch_downsample_dict["batch_downsampling"] = possibly_batch_downsampling[0]
313333

314334
# Set preprocessing and drift correction parameters
315335
if not params["skip_kilosort_preprocessing"]:
@@ -334,6 +354,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
334354
tmax=tmax,
335355
artifact_threshold=artifact,
336356
file_object=file_object,
357+
**batch_downsample_dict,
337358
)
338359
ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None)
339360
ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels()))

0 commit comments

Comments
 (0)