@@ -744,12 +744,13 @@ def estimate_templates(
744744 operator : str = "average" ,
745745 return_scaled = None ,
746746 return_in_uV = True ,
747+ sparsity_mask = None ,
747748 job_name = None ,
748749 ** job_kwargs ,
749750):
750751 """
751752 Estimate dense templates with "average" or "median".
752- If "average" internally estimate_templates_with_accumulator() is used to saved memory/
753+ If "average" internally estimate_templates_with_accumulator() is used to saved memory.
753754
754755 Parameters
755756 ----------
@@ -770,6 +771,8 @@ def estimate_templates(
770771 return_in_uV : bool, default: True
771772 If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
772773 traces are scaled to uV
774+ sparsity_mask: None or array of bool, default: None
775+ If not None shape must be must be (len(unit_ids), len(channel_ids))
773776
774777 Returns
775778 -------
@@ -791,7 +794,15 @@ def estimate_templates(
791794
792795 if operator == "average" :
793796 templates_array = estimate_templates_with_accumulator (
794- recording , spikes , unit_ids , nbefore , nafter , return_in_uV = return_in_uV , job_name = job_name , ** job_kwargs
797+ recording ,
798+ spikes ,
799+ unit_ids ,
800+ nbefore ,
801+ nafter ,
802+ return_in_uV = return_in_uV ,
803+ sparsity_mask = sparsity_mask ,
804+ job_name = job_name ,
805+ ** job_kwargs ,
795806 )
796807 elif operator == "median" :
797808 all_waveforms , wf_array_info = extract_waveforms_to_single_buffer (
@@ -802,6 +813,7 @@ def estimate_templates(
802813 nafter ,
803814 mode = "shared_memory" ,
804815 return_in_uV = return_in_uV ,
816+ sparsity_mask = sparsity_mask ,
805817 copy = False ,
806818 ** job_kwargs ,
807819 )
@@ -828,6 +840,7 @@ def estimate_templates_with_accumulator(
828840 nafter : int ,
829841 return_scaled = None ,
830842 return_in_uV = True ,
843+ sparsity_mask = None ,
831844 job_name = None ,
832845 return_std : bool = False ,
833846 verbose : bool = False ,
@@ -859,6 +872,8 @@ def estimate_templates_with_accumulator(
859872 return_in_uV : bool, default: True
860873 If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
861874 traces are scaled to uV
875+ sparsity_mask: None or array of bool, default: None
876+ If not None shape must be must be (len(unit_ids), len(channel_ids))
862877 return_std: bool, default: False
863878 If True, the standard deviation is also computed.
864879
@@ -882,10 +897,14 @@ def estimate_templates_with_accumulator(
882897 job_kwargs = fix_job_kwargs (job_kwargs )
883898 num_worker = job_kwargs ["n_jobs" ]
884899
885- num_chans = recording .get_num_channels ()
900+ if sparsity_mask is None :
901+ num_chans = int (recording .get_num_channels ())
902+ else :
903+ num_chans = int (max (np .sum (sparsity_mask , axis = 1 ))) # This is a numpy scalar, so we cast to int
886904 num_units = len (unit_ids )
887905
888906 shape = (num_worker , num_units , nbefore + nafter , num_chans )
907+
889908 dtype = np .dtype ("float32" )
890909 waveform_accumulator_per_worker , shm = make_shared_array (shape , dtype )
891910 shm_name = shm .name
@@ -909,6 +928,7 @@ def estimate_templates_with_accumulator(
909928 nbefore ,
910929 nafter ,
911930 return_in_uV ,
931+ sparsity_mask ,
912932 )
913933
914934 if job_name is None :
@@ -965,13 +985,15 @@ def _init_worker_estimate_templates(
965985 nbefore ,
966986 nafter ,
967987 return_in_uV ,
988+ sparsity_mask ,
968989):
969990 worker_dict = {}
970991 worker_dict ["recording" ] = recording
971992 worker_dict ["spikes" ] = spikes
972993 worker_dict ["nbefore" ] = nbefore
973994 worker_dict ["nafter" ] = nafter
974995 worker_dict ["return_in_uV" ] = return_in_uV
996+ worker_dict ["sparsity_mask" ] = sparsity_mask
975997
976998 from multiprocessing .shared_memory import SharedMemory
977999 import multiprocessing
@@ -1009,6 +1031,7 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10091031 waveform_squared_accumulator_per_worker = worker_dict .get ("waveform_squared_accumulator_per_worker" , None )
10101032 worker_index = worker_dict ["worker_index" ]
10111033 return_in_uV = worker_dict ["return_in_uV" ]
1034+ sparsity_mask = worker_dict ["sparsity_mask" ]
10121035
10131036 seg_size = recording .get_num_samples (segment_index = segment_index )
10141037
@@ -1040,6 +1063,14 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10401063 unit_index = spikes [spike_index ]["unit_index" ]
10411064 wf = traces [sample_index - start - nbefore : sample_index - start + nafter , :]
10421065
1043- waveform_accumulator_per_worker [worker_index , unit_index , :, :] += wf
1044- if waveform_squared_accumulator_per_worker is not None :
1045- waveform_squared_accumulator_per_worker [worker_index , unit_index , :, :] += wf ** 2
1066+ if sparsity_mask is None :
1067+ waveform_accumulator_per_worker [worker_index , unit_index , :, :] += wf
1068+ if waveform_squared_accumulator_per_worker is not None :
1069+ waveform_squared_accumulator_per_worker [worker_index , unit_index , :, :] += wf ** 2
1070+
1071+ else :
1072+ mask = sparsity_mask [unit_index , :]
1073+ wf = wf [:, mask ]
1074+ waveform_accumulator_per_worker [worker_index , unit_index , :, : wf .shape [1 ]] += wf
1075+ if waveform_squared_accumulator_per_worker is not None :
1076+ waveform_squared_accumulator_per_worker [worker_index , unit_index , :, : wf .shape [1 ]] += wf ** 2
0 commit comments