diff --git a/resources/CN-logo.jpg b/resources/CN-logo.jpg new file mode 100644 index 0000000..204a3d5 Binary files /dev/null and b/resources/CN-logo.jpg differ diff --git a/resources/generate_cambridgeneurotech_libray.py b/resources/generate_cambridgeneurotech_libray.py index 84fb039..39e9f25 100644 --- a/resources/generate_cambridgeneurotech_libray.py +++ b/resources/generate_cambridgeneurotech_libray.py @@ -1,339 +1,186 @@ -''' -2021-01-07 CambridgeNeurotech -Original script: - * contact: Tahl Holtzman - * email: info@cambridgeneurotech.com +""" +2025-12-16 CambridgeNeurotech -2021-03-01 -The script have been modified by Smauel Garcia (samuel.garcia@cnrs.fr): - * more pytonic - * improve code readability - * not more channel_device_index the order is the contact index - * simpler function for plotting. +Derive probes to be used with SpikeInterface base on Cambridgeneurotech database at: +https://github.com/cambridge-neurotech/probe_maps -2021-04-02 -Samuel Garcia: - * add "contact_id" one based in Probe. -2023-06-14 -generate new library - -2023-10-30 -Generate new library with some fixes - - -Derive probes to be used with SpikeInterface base on Cambridgeneurotech databases -Probe library to match and add on -https://gin.g-node.org/spikeinterface/probeinterface_library/src/master/cambridgeneurotech - -see repos https://github.com/SpikeInterface/probeinterface - -In the 'Probe Maps 2020Final.xlsx' -''' +The output folder is ready to be used as a probeinterface library and contains: +- one folder per probe +- inside each folder a json file and a figure png file +""" +import argparse +import json +import shutil +from pathlib import Path import pandas as pd import numpy as np import matplotlib.pyplot as plt +from tqdm.auto import tqdm from probeinterface.plotting import plot_probe -from probeinterface import generate_multi_columns_probe, combine_probes, write_probeinterface +from probeinterface import write_probeinterface, Probe -from pathlib import Path -import json -import shutil - - -# work_dir = r"C:\Users\Windows\Dropbox (Scripps Research)\2021-01-SpikeInterface_CambridgeNeurotech" -# work_dir = '.' -# work_dir = '/home/samuel/Documents/SpikeInterface/2021-03-01-probeinterface_CambridgeNeurotech/' -# work_dir = '/home/samuel/Documents/SpikeInterface/2022-05-20-probeinterface_CambridgeNeurotech/' -# work_dir = '/home/samuel/Documents/SpikeInterface/2022-10-18-probeinterface_CambridgeNeurotech/' -# work_dir = '/home/samuel/OwnCloudCNRS/probeinterface/2023-06-14-probeinterface-CambridgeNeurotech/' -# work_dir = '/home/samuel/OwnCloudCNRS/probeinterface/2023-10-30-probeinterface-CambridgeNeurotech/' -work_dir = '/home/samuel/NextcloudCNRS/probeinterface/2025-01-27-probeinterface-CambridgeNeurotech/' - - -library_folder = '/home/samuel/Documents/SpikeInterface/probeinterface_library/cambridgeneurotech/' +cn_logo = Path(__file__).parent / "CN-logo.jpg" -library_folder = Path(library_folder) - -work_dir = Path(work_dir).absolute() - -export_folder = work_dir / 'export_2025_01_27' -probe_map_file = work_dir / 'ProbeMaps_Final2023.xlsx' -probe_info_table_file = work_dir / 'ProbesDataBase_Final2023.csv' +parser = argparse.ArgumentParser(description="Generate CambridgeNeurotech probe library for probeinterface") +parser.add_argument( + "probe_tables_path", + type=str, + help="Path to the folder containing the CambridgeNeurotech probe tables CSV files from https://github.com/cambridge-neurotech/probe_maps", +) +parser.add_argument( + "--output-folder", type=str, default="./cambridgeneurotech", help="Output folder to save the generated probes" +) # graphing parameters -plt.rcParams['pdf.fonttype'] = 42 # to make sure it is recognize as true font in illustrator -plt.rcParams['svg.fonttype'] = 'none' # to make sure it is recognize as true font in illustrator - - -def convert_probe_shape(listCoord): - ''' - This is to convert reference point probe shape inputted in excel - as string 'x y x y x y that outline the shape of one shanck - and can be converted to an array to draw the porbe - ''' - listCoord = [float(s) for s in listCoord.split(' ')] - res = [[listCoord[i], listCoord[i + 1]] for i in range(len(listCoord) - 1)] - res = res[::2] - - return res - -def convert_contact_shape(listCoord): - ''' - This is to convert reference shift in electrodes - ''' - listCoord = [float(s) for s in listCoord.split(' ')] - return listCoord - -def get_contact_order(connector, probe_type): - """ - Get the channel index given a connector and a probe_type. - This will help to re-order the probe contact later on. - """ +plt.rcParams["pdf.fonttype"] = 42 # to make sure it is recognize as true font in illustrator +plt.rcParams["svg.fonttype"] = "none" # to make sure it is recognize as true font in illustrator - # first part of the function to open the proper connector based on connector name - # header [0,1] is used to create a mutliindex - df = pd.read_excel(probe_map_file, sheet_name=connector, header=[0,1]) - - # second part to get the proper channel in the - if probe_type == 'E-1' or probe_type == 'E-2': - probe_type = 'E-1 & E-2' - - if probe_type == 'P-1' or probe_type == 'P-2': - probe_type = 'P-1 & P-2' - - if probe_type == 'H3' or probe_type == 'L3': - probe_type = 'H3 & L3' - - if probe_type == 'H5' or probe_type == 'H9': - probe_type = 'H5 & H9' - - # print(df[probe_type]) - tmpList = [] - for i in df[probe_type].columns: - # print('i', i, len(df[probe_type].columns)) - if len(df[probe_type].columns) == 1: - tmpList = np.flip(df[probe_type].values.astype(int).flatten()) - else: - tmp = df[probe_type][i].values - tmp = tmp[~np.isnan(tmp)].astype(int) # get rid of nan and convert to integer - tmp = np.flip(tmp) # this flips the value to match index that goes from tip to headstage of the probe - # print('tmp', tmp) - tmpList = np.append(tmpList, tmp) - tmpList = tmpList.astype(int) - - # print('tmpList', tmpList) - return tmpList - - -def generate_CN_probe(probe_info, probeIdx): - """ - Generate a mono shank CN probe - """ - if probe_info['part'] == 'Fb' or probe_info['part'] == 'F': - probe = generate_multi_columns_probe( - num_columns=probe_info['electrode_cols_n'], - num_contact_per_column=[int(x) for x in convert_probe_shape(probe_info['electrode_rows_n'])[probeIdx]], - xpitch=float(probe_info['electrodeSpacingWidth_um']), - ypitch=probe_info['electrodeSpacingHeight_um'], - y_shift_per_column=convert_probe_shape(probe_info['electrode_yShiftCol'])[probeIdx], - contact_shapes=probe_info['ElectrodeShape'], - contact_shape_params={'width': probe_info['electrodeWidth_um'], 'height': probe_info['electrodeHeight_um']} - ) - probe.set_planar_contour(convert_probe_shape(probe_info['probeShape'])) - - else: - y_shift_per_column = convert_contact_shape(probe_info['electrode_yShiftCol']) - - if ' ' in probe_info['electrode_rows_n']: - num_contact_per_column = [ int(e) for e in probe_info['electrode_rows_n'].split(' ')] - assert len(y_shift_per_column) == len(num_contact_per_column) - else: - num_contact_per_column = int(probe_info['electrode_rows_n']) - - probe = generate_multi_columns_probe( - num_columns=probe_info['electrode_cols_n'], - num_contact_per_column=num_contact_per_column, - xpitch=float(probe_info['electrodeSpacingWidth_um']), - ypitch=probe_info['electrodeSpacingHeight_um'], - y_shift_per_column=convert_contact_shape(probe_info['electrode_yShiftCol']), - contact_shapes=probe_info['ElectrodeShape'], - contact_shape_params={'width': probe_info['electrodeWidth_um'], 'height': probe_info['electrodeHeight_um']} - ) - probe.set_planar_contour(convert_probe_shape(probe_info['probeShape'])) - - if type(probe_info['electrodesCustomPosition']) == str: - probe._contact_positions = np.array(convert_probe_shape(probe_info['electrodesCustomPosition'])) - - return probe - -def generate_CN_multi_shank(probe_info): - """ - Generate a multi shank probe - """ - sub_probes = [] - for probeIdx in range(probe_info['shanks_n']): - sub_probe = generate_CN_probe(probe_info, probeIdx) - sub_probe.move([probe_info['shankSpacing_um']*probeIdx, 0]) - sub_probes.append(sub_probe) - - multi_shank_probe = combine_probes(sub_probes) - return multi_shank_probe - - -def create_CN_figure(probe_name, probe): +def create_CN_figure(probe): """ Create custom figire for CN with custom colors + logo """ - fig, ax = plt.subplots() - fig.set_size_inches(18.5, 10.5) + if probe.contact_sides is not None: + fig, axs = plt.subplots(ncols=2) + fig.set_size_inches(18.5, 10.5) + else: + fig, ax = plt.subplots() + fig.set_size_inches(18.5, 10.5) + axs = [ax] n = probe.get_contact_count() - plot_probe(probe, ax=ax, - contacts_colors = ['#5bc5f2'] * n, # made change to default color - probe_shape_kwargs = dict(facecolor='#6f6f6e', edgecolor='k', lw=0.5, alpha=0.3), # made change to default color - with_contact_id=True) - - ax.set_xlabel(u'Width (\u03bcm)') #modif to legend - ax.set_ylabel(u'Height (\u03bcm)') #modif to legend - ax.spines['right'].set_visible(False) #remove external axis - ax.spines['top'].set_visible(False) #remove external axis - - ax.set_title('\n' +'CambridgeNeuroTech' +'\n'+ probe.annotations.get('model_name'), fontsize = 24) - - fig.tight_layout() #modif tight layout - - im = plt.imread(work_dir / 'CN_logo-01.jpg') - newax = fig.add_axes([0.8,0.85,0.2,0.1], anchor='NW', zorder=0) + probe_max_height = np.max(probe.contact_positions[:, 1]) + if probe.contact_sides is not None: + for i, side in enumerate(("front", "back")): + ax = axs[i] + plot_probe( + probe, + ax=ax, + contacts_colors=["#5bc5f2"] * n, # made change to default color + probe_shape_kwargs=dict( + facecolor="#6f6f6e", edgecolor="k", lw=0.5, alpha=0.3 + ), # made change to default color + with_contact_id=True, + side=side, + ) + ax.set_title(f"Side: {side}", fontsize=20) + else: + plot_probe( + probe, + ax=axs[0], + contacts_colors=["#5bc5f2"] * n, # made change to default color + probe_shape_kwargs=dict( + facecolor="#6f6f6e", edgecolor="k", lw=0.5, alpha=0.3 + ), # made change to default color + with_contact_id=True, + ) + axs[0].set_title("") + + for ax in axs: + y_min = ax.get_ylim()[0] + y_max = probe_max_height + 200 + ax.set_ylim(y_min, y_max) + ax.set_xlabel("Width (\u03bcm)") # modify to legend + ax.set_ylabel("Height (\u03bcm)") # modify to legend + ax.spines["right"].set_visible(False) # remove external axis + ax.spines["top"].set_visible(False) # remove external axis + + fig.suptitle("\n" + "CambridgeNeuroTech" + "\n" + probe.model_name, fontsize=24) + + fig.tight_layout() + + im = plt.imread(str(cn_logo)) + newax = fig.add_axes([0.8, 0.85, 0.2, 0.1], anchor="NW", zorder=0) newax.imshow(im) - newax.axis('off') + newax.axis("off") return fig -def export_one_probe(probe_name, probe): +def export_one_probe(probe_name, probe, output_folder): """ - Save one probe in "export_folder" + figure. + Save one probe in "output_folder" + figure. """ - probe_folder = export_folder / probe_name + probe_folder = output_folder / probe_name probe_folder.mkdir(exist_ok=True, parents=True) - probe_file = probe_folder / (probe_name + '.json') - figure_file = probe_folder / (probe_name + '.png') + probe_file = probe_folder / (probe_name + ".json") + figure_file = probe_folder / (probe_name + ".png") write_probeinterface(probe_file, probe) - fig = create_CN_figure(probe_name, probe) + fig = create_CN_figure(probe) fig.savefig(figure_file) - # plt.show() - # avoid memory error plt.close(fig) -def generate_all_probes(): - """ - Main function. - Generate all probes. - """ - probe_info_table = pd.read_csv(probe_info_table_file) - #~ print(probe_info_list) +def is_contour_correct(probe): + from shapely.geometry import Point, Polygon - for i, probe_info in probe_info_table.iterrows(): - print(i, probe_info['part']) - - # DEBUG - # if not probe_info['part'] in ('P-1', 'P-2'): - # continue - - # print(probe_info) - - if probe_info['shanks_n'] == 1: - # one shank - probe_unordered = generate_CN_probe(probe_info, 0) - else: - # multi shank - probe_unordered = generate_CN_multi_shank(probe_info) + polygon = Polygon(probe.probe_planar_contour) - # loop over connector case that re order the probe contact index - for connector in list(probe_info[probe_info.index.str.contains('ASSY')].dropna().index): - probe_name = connector+'-'+probe_info['part'] + for i, contact_pos in enumerate(probe.contact_positions): + width = probe.contact_shape_params[i]["width"] + height = probe.contact_shape_params[i]["height"] + points = [ + (contact_pos[0] - width / 2, contact_pos[1] - height / 2), + (contact_pos[0] + width / 2, contact_pos[1] - height / 2), + (contact_pos[0] + width / 2, contact_pos[1] + height / 2), + (contact_pos[0] - width / 2, contact_pos[1] + height / 2), + ] + for point in points: + p = Point(point[0], point[1]) + if not polygon.contains(p): + return False + return True - # DEBUG - # if connector != 'ASSY-1': - # continue +def generate_all_probes(probe_tables_path, output_folder): + sheet_names = list(pd.read_excel(probe_tables_path / "probe_contacts.xlsx", sheet_name=None).keys()) - print(' ', probe_name) - - contact_order = get_contact_order(connector = connector, probe_type = probe_info['part']) - - # print(probe_unordered) - # print(probe_unordered.contact_ids) - # print(contact_order) - # print(probe_unordered.) - # fig, ax = plt.subplots() - # plot_probe(probe_unordered, ax=ax, with_contact_id=True) - # plt.show() - - - - sorted_indices = np.argsort(contact_order) - probe = probe_unordered.get_slice(sorted_indices) - - probe.annotate(model_name=probe_name, manufacturer='cambridgeneurotech') - - # one based in cambridge neurotech - contact_ids = np.arange(sorted_indices.size) + 1 - contact_ids = contact_ids.astype(str) - probe.set_contact_ids(contact_ids) - - export_one_probe(probe_name, probe) - - # break - -def synchronize_library(): - - for source_probe_file in export_folder.glob('**/*.json'): - # print() - print(source_probe_file.stem) - target_probe_file = library_folder / source_probe_file.parent.stem / source_probe_file.name - # print(target_probe_file) - with open(source_probe_file, mode='r')as source: - source_dict = json.load(source) - - with open(target_probe_file, mode='r')as target: - target_dict = json.load(target) - - source_dict.pop('version') - - target_dict.pop('version') - - # this was needed between version 0.2.17 > 0.2.18 - # target_dict["probes"][0]["annotations"].pop("first_index") - - same = source_dict == target_dict - - # copy the json - shutil.copyfile(source_probe_file, target_probe_file) - if not same: - # copy the png - shutil.copyfile(source_probe_file.parent / (source_probe_file.stem + '.png'), - target_probe_file.parent / (target_probe_file.stem + '.png') ) + wrong_contours = [] + sheets_with_issues = [] + for sheet_name in tqdm(sheet_names, "Exporting CN probes"): + contacts = pd.read_excel(probe_tables_path / "probe_contacts.xlsx", sheet_name=sheet_name) + contour = pd.read_excel(probe_tables_path / "probe_contours.xlsx", sheet_name=sheet_name) + if np.all(pd.isna(contacts["contact_sides"])): + contacts.drop(columns="contact_sides", inplace=True) + else: + print(f"Double sided probe: {sheet_name}") + if "z" in contacts.columns: + contacts.drop(columns=["z"], inplace=True) + try: + probe = Probe.from_dataframe(contacts) + probe.manufacturer = "cambridgeneurotech" + probe.model_name = sheet_name + probe.set_planar_contour(contour) + if not is_contour_correct(probe): + wrong_contours.append(sheet_name) + export_one_probe(sheet_name, probe, output_folder) - # library_folder + except Exception as e: + print(f"Problem loading {sheet_name}: {e}") + sheets_with_issues.append(sheet_name) + print("Wrong contours:\n\n", wrong_contours) + print("Sheets with issues:\n\n", sheets_with_issues) -if __name__ == '__main__': - generate_all_probes() - synchronize_library() +if __name__ == "__main__": + args = parser.parse_args() + probe_tables_path = Path(args.probe_tables_path) + output_folder = Path(args.output_folder) + if output_folder.exists(): + shutil.rmtree(output_folder) + output_folder.mkdir(parents=True, exist_ok=True) + generate_all_probes(probe_tables_path, output_folder) diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index 2830ca0..ce535cf 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -102,6 +102,7 @@ def plot_probe( ylims: tuple | None = None, zlims: tuple | None = None, show_channel_on_click: bool = False, + side=None, ): """Plot a Probe object. Generates a 2D or 3D axis, depending on Probe.ndim @@ -138,6 +139,8 @@ def plot_probe( Limits for z dimension show_channel_on_click : bool, default: False If True, the channel information is shown upon click + side : None | "front" | "back" + If the probe is two side, then the side must be given otherwise this raises an error. Returns ------- @@ -148,6 +151,15 @@ def plot_probe( """ import matplotlib.pyplot as plt + if probe.contact_sides is not None: + if side is None or side not in ("front", "back"): + raise ValueError( + "The probe has two side, you must give which one to plot. plot_probe(probe, side='front'|'back')" + ) + mask = probe.contact_sides == side + probe = probe.get_slice(mask) + probe._contact_sides = None + if ax is None: if probe.ndim == 2: fig, ax = plt.subplots() diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index fb7ac24..2572708 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -101,6 +101,7 @@ def __init__( self.probe_planar_contour = None # This handles the shank id per contact + # If None then one shank only self._shank_ids = None # This handles the wiring to device : channel index on device side. @@ -112,6 +113,10 @@ def __init__( # This must be unique at Probe AND ProbeGroup level self._contact_ids = None + # Handle contact side for double face probes + # If None then one face only + self._contact_sides = None + # annotation: a dict that contains all meta information about # the probe (name, manufacturor, date of production, ...) self.annotations = dict() @@ -153,6 +158,10 @@ def contact_ids(self): def shank_ids(self): return self._shank_ids + @property + def contact_sides(self): + return self._contact_sides + @property def name(self): return self.annotations.get("name", None) @@ -237,6 +246,8 @@ def get_title(self) -> str: if self.shank_ids is not None: num_shank = self.get_shank_count() txt += f" - {num_shank}shanks" + if self._contact_sides is not None: + txt += f" - 2 sides" return txt def __repr__(self): @@ -291,7 +302,14 @@ def get_shank_count(self) -> int: return n def set_contacts( - self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None + self, + positions, + shapes="circle", + shape_params={"radius": 10}, + plane_axes=None, + contact_ids=None, + shank_ids=None, + contact_sides=None, ): """Sets contacts to a Probe. @@ -320,16 +338,28 @@ def set_contacts( shank_ids : array[str] | None, default: None Defines the shank ids for the contacts. If None, then these are assigned to a unique Shank. + contact_sides : array[str] | None, default: None + If probe is double sided, defines sides by a vector of ['front' | 'back'] """ positions = np.array(positions) if positions.shape[1] != self.ndim: raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!") - # Check for duplicate positions - unique_positions = np.unique(positions, axis=0) - positions_are_not_unique = unique_positions.shape[0] != positions.shape[0] - if positions_are_not_unique: - _raise_non_unique_positions_error(positions) + if contact_sides is None: + # Check for duplicate positions + unique_positions = np.unique(positions, axis=0) + positions_are_not_unique = unique_positions.shape[0] != positions.shape[0] + if positions_are_not_unique: + _raise_non_unique_positions_error(positions) + else: + # Check for duplicate positions side by side + contact_sides = np.asarray(contact_sides).astype(str) + for side in ("front", "back"): + mask = contact_sides == side + unique_positions = np.unique(positions[mask], axis=0) + positions_are_not_unique = unique_positions.shape[0] != positions[mask].shape[0] + if positions_are_not_unique: + _raise_non_unique_positions_error(positions[mask]) self._contact_positions = positions n = positions.shape[0] @@ -356,6 +386,15 @@ def set_contacts( if self.shank_ids.size != n: raise ValueError(f"shank_ids have wrong size: {self.shanks.ids.size} != {n}") + if contact_sides is None: + self._contact_sides = contact_sides + else: + self._contact_sides = contact_sides + if self._contact_sides.size != n: + raise ValueError(f"contact_sides have wrong size: {self._contact_sides.ids.size} != {n}") + if not np.all(np.isin(self._contact_sides, ["front", "back"])): + raise ValueError(f"contact_sides must 'front' or 'back'") + # shape if isinstance(shapes, str): shapes = [shapes] * n @@ -592,6 +631,13 @@ def __eq__(self, other): ): return False + if self._contact_sides is None: + if other._contact_sides is not None: + return False + else: + if not np.array_equal(self._contact_sides, other._contact_sides): + return False + # Compare contact_annotations dictionaries if self.contact_annotations.keys() != other.contact_annotations.keys(): return False @@ -842,6 +888,7 @@ def rotate_contacts(self, thetas: float | np.array[float] | list[float]): "device_channel_indices", "_contact_ids", "_shank_ids", + "_contact_sides", ] def to_dict(self, array_as_list: bool = False) -> dict: @@ -895,6 +942,9 @@ def from_dict(d: dict) -> "Probe": plane_axes=d["contact_plane_axes"], shapes=d["contact_shapes"], shape_params=d["contact_shape_params"], + contact_ids=d.get("contact_ids", None), + shank_ids=d.get("shank_ids", None), + contact_sides=d.get("contact_sides", None), ) v = d.get("probe_planar_contour", None) @@ -905,14 +955,6 @@ def from_dict(d: dict) -> "Probe": if v is not None: probe.set_device_channel_indices(v) - v = d.get("shank_ids", None) - if v is not None: - probe.set_shank_ids(v) - - v = d.get("contact_ids", None) - if v is not None: - probe.set_contact_ids(v) - if "annotations" in d: probe.annotate(**d["annotations"]) if "contact_annotations" in d: @@ -955,6 +997,7 @@ def to_numpy(self, complete: bool = False) -> np.array: ... ('shank_ids', 'U64'), ('contact_ids', 'U64'), + ('contact_sides', 'U8'), # The rest is added only if `complete=True` ('device_channel_indices', 'int64', optional), @@ -991,6 +1034,11 @@ def to_numpy(self, complete: bool = False) -> np.array: dtype += [(k, "float64")] dtype += [("shank_ids", "U64"), ("contact_ids", "U64")] + if self._contact_sides is not None: + dtype += [ + ("contact_sides", "U8"), + ] + if complete: dtype += [("device_channel_indices", "int64")] dtype += [("si_units", "U64")] @@ -1014,6 +1062,9 @@ def to_numpy(self, complete: bool = False) -> np.array: arr["shank_ids"] = self.shank_ids + if self._contact_sides is not None: + arr["contact_sides"] = self.contact_sides + if self.contact_ids is None: arr["contact_ids"] = [""] * self.get_contact_count() else: @@ -1062,6 +1113,7 @@ def from_numpy(arr: np.ndarray) -> "Probe": "contact_shapes", "shank_ids", "contact_ids", + "contact_sides", "device_channel_indices", "radius", "width", @@ -1118,14 +1170,22 @@ def from_numpy(arr: np.ndarray) -> "Probe": else: plane_axes = None - probe.set_contacts(positions=positions, plane_axes=plane_axes, shapes=shapes, shape_params=shape_params) + shank_ids = arr["shank_ids"] if "shank_ids" in fields else None + contact_sides = arr["contact_sides"] if "contact_sides" in fields else None + + probe.set_contacts( + positions=positions, + plane_axes=plane_axes, + shapes=shapes, + shape_params=shape_params, + shank_ids=shank_ids, + contact_sides=contact_sides, + ) if "device_channel_indices" in fields: dev_channel_indices = arr["device_channel_indices"] if not np.all(dev_channel_indices == -1): probe.set_device_channel_indices(dev_channel_indices) - if "shank_ids" in fields: - probe.set_shank_ids(arr["shank_ids"]) if "contact_ids" in fields: probe.set_contact_ids(arr["contact_ids"]) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index f19c0d7..eaa676f 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -50,7 +50,30 @@ def test_plot_probegroup(): plot_probegroup(probegroup_3d, same_axes=True) +def test_plot_probe_two_side(): + probe = Probe() + probe.set_contacts( + positions=np.array( + [ + [0, 0], + [0, 10], + [0, 20], + [0, 0], + [0, 10], + [0, 20], + ] + ), + shapes="circle", + contact_ids=["F1", "F2", "F3", "B1", "B2", "B3"], + contact_sides=["front", "front", "front", "back", "back", "back"], + ) + + plot_probe(probe, with_contact_id=True, side="front") + plot_probe(probe, with_contact_id=True, side="back") + + if __name__ == "__main__": - test_plot_probe() + # test_plot_probe() # test_plot_probe_group() + test_plot_probe_two_side() plt.show() diff --git a/tests/test_probe.py b/tests/test_probe.py index b20ea0d..48a3b82 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -197,9 +197,42 @@ def test_position_uniqueness(): probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5}) +def test_double_side_probe(): + + probe = Probe() + probe.set_contacts( + positions=np.array( + [ + [0, 0], + [0, 10], + [0, 20], + [0, 0], + [0, 10], + [0, 20], + ] + ), + shapes="circle", + contact_sides=["front", "front", "front", "back", "back", "back"], + ) + print(probe) + + assert "contact_sides" in probe.to_dict() + + probe2 = Probe.from_dict(probe.to_dict()) + assert probe2 == probe + + probe3 = Probe.from_numpy(probe.to_numpy()) + assert probe3 == probe + + probe4 = Probe.from_dataframe(probe.to_dataframe()) + assert probe4 == probe + + if __name__ == "__main__": test_probe() tmp_path = Path("tmp") tmp_path.mkdir(exist_ok=True) test_save_to_zarr(tmp_path) + + test_double_side_probe()