diff --git a/neo/io/__init__.py b/neo/io/__init__.py index de2a7c812..4a1a121db 100644 --- a/neo/io/__init__.py +++ b/neo/io/__init__.py @@ -44,6 +44,7 @@ * :attr:`NeuroScopeIO` * :attr:`NeuroshareIO` * :attr:`NixIO` +* :attr:`NWBIO` * :attr:`OpenEphysIO` * :attr:`OpenEphysBinaryIO` * :attr:`PhyIO` @@ -184,6 +185,10 @@ .. autoattribute:: extensions +.. autoclass:: neo.io.NWBIO + + .. autoattribute:: extensions + .. autoclass:: neo.io.OpenEphysIO .. autoattribute:: extensions @@ -292,6 +297,7 @@ from neo.io.neuroscopeio import NeuroScopeIO from neo.io.nixio import NixIO from neo.io.nixio_fr import NixIO as NixIOFr +from neo.io.nwbio import NWBIO from neo.io.openephysio import OpenEphysIO from neo.io.openephysbinaryio import OpenEphysBinaryIO from neo.io.phyio import PhyIO @@ -340,6 +346,7 @@ NeuroExplorerIO, NeuroScopeIO, NeuroshareIO, + NWBIO, OpenEphysIO, OpenEphysBinaryIO, PhyIO, diff --git a/neo/io/asciisignalio.py b/neo/io/asciisignalio.py index 54ee6b788..addf310fa 100644 --- a/neo/io/asciisignalio.py +++ b/neo/io/asciisignalio.py @@ -195,7 +195,8 @@ def read_segment(self, lazy=False): delimiter=self.delimiter, usecols=self.usecols, skip_header=self.skiprows, - dtype='f') + dtype='f', + invalid_raise=False) if len(sig.shape) == 1: sig = sig[:, np.newaxis] elif self.method == 'csv': diff --git a/neo/io/nwbio.py b/neo/io/nwbio.py new file mode 100644 index 000000000..91889273d --- /dev/null +++ b/neo/io/nwbio.py @@ -0,0 +1,858 @@ +""" +NWBIO +===== + +IO class for reading data from a Neurodata Without Borders (NWB) dataset + +Documentation : https://www.nwb.org/ +Depends on: h5py, nwb, dateutil +Supported: Read, Write +Python API - https://pynwb.readthedocs.io +Sample datasets from CRCNS - https://crcns.org/NWB +Sample datasets from Allen Institute +- http://alleninstitute.github.io/AllenSDK/cell_types.html#neurodata-without-borders +""" + +from __future__ import absolute_import, division + +import json +import logging +import os +from collections import defaultdict +from itertools import chain +from json.decoder import JSONDecodeError + +import numpy as np +import quantities as pq + +from neo.core import (Segment, SpikeTrain, Epoch, Event, AnalogSignal, + IrregularlySampledSignal, Block, ImageSequence) +from neo.io.baseio import BaseIO +from neo.io.proxyobjects import ( + AnalogSignalProxy as BaseAnalogSignalProxy, + EventProxy as BaseEventProxy, + EpochProxy as BaseEpochProxy, + SpikeTrainProxy as BaseSpikeTrainProxy +) + +# PyNWB imports +try: + import pynwb + from pynwb import NWBFile, TimeSeries + from pynwb.base import ProcessingModule + from pynwb.ecephys import ElectricalSeries, Device, EventDetection + from pynwb.behavior import SpatialSeries + from pynwb.misc import AnnotationSeries + from pynwb import image + from pynwb.image import ImageSeries + from pynwb.spec import NWBAttributeSpec, NWBDatasetSpec, NWBGroupSpec, NWBNamespace, \ + NWBNamespaceBuilder + from pynwb.device import Device + # For calcium imaging data + from pynwb.ophys import TwoPhotonSeries, OpticalChannel, ImageSegmentation, Fluorescence + + have_pynwb = True +except ImportError: + have_pynwb = False + +# hdmf imports +try: + from hdmf.spec import (LinkSpec, GroupSpec, DatasetSpec, SpecNamespace, + NamespaceBuilder, AttributeSpec, DtypeSpec, RefSpec) + + have_hdmf = True +except ImportError: + have_hdmf = False +except SyntaxError: + have_hdmf = False + +logger = logging.getLogger("Neo") + +GLOBAL_ANNOTATIONS = ( + "session_start_time", "identifier", "timestamps_reference_time", "experimenter", + 
"experiment_description", "session_id", "institution", "keywords", "notes", + "pharmacology", "protocol", "related_publications", "slices", "source_script", + "source_script_file_name", "data_collection", "surgery", "virus", "stimulus_notes", + "lab", "session_description" +) + +POSSIBLE_JSON_FIELDS = ( + "source_script", "description" +) + +prefix_map = { + 1e9: 'giga', + 1e6: 'mega', + 1e3: 'kilo', + 1: '', + 1e-3: 'milli', + 1e-6: 'micro', + 1e-9: 'nano', + 1e-12: 'pico' +} + + +def try_json_field(content): + """ + Try to interpret a string as JSON data. + + If successful, return the JSON data (dict or list) + If unsuccessful, return the original string + """ + try: + return json.loads(content) + except JSONDecodeError: + return content + + +def get_class(module, name): + """ + Given a module path and a class name, return the class object + """ + module_path = module.split(".") + assert len(module_path) == 2 # todo: handle the general case where this isn't 2 + return getattr(getattr(pynwb, module_path[1]), name) + + +def statistics(block): # todo: move this to be a property of Block + """ + Return simple statistics about a Neo Block. + """ + stats = { + "SpikeTrain": {"count": 0}, + "AnalogSignal": {"count": 0}, + "IrregularlySampledSignal": {"count": 0}, + "Epoch": {"count": 0}, + "Event": {"count": 0}, + } + for segment in block.segments: + stats["SpikeTrain"]["count"] += len(segment.spiketrains) + stats["AnalogSignal"]["count"] += len(segment.analogsignals) + stats["IrregularlySampledSignal"]["count"] += len(segment.irregularlysampledsignals) + stats["Epoch"]["count"] += len(segment.epochs) + stats["Event"]["count"] += len(segment.events) + return stats + + +def get_units_conversion(signal, timeseries_class): + """ + Given a quantity array and a TimeSeries subclass, return + the conversion factor and the expected units + """ + # it would be nice if the expected units was an attribute of the PyNWB class + if "CurrentClamp" in timeseries_class.__name__: + expected_units = pq.volt + elif "VoltageClamp" in timeseries_class.__name__: + expected_units = pq.ampere + else: + # todo: warn that we don't handle this subclass yet + expected_units = signal.units + return float((signal.units / expected_units).simplified.magnitude), expected_units + + +def time_in_seconds(t): + return float(t.rescale("second")) + + +def _decompose_unit(unit): + """ + Given a quantities unit object, return a base unit name and a conversion factor. + + Example: + + >>> _decompose_unit(pq.mV) + ('volt', 0.001) + """ + assert isinstance(unit, pq.quantity.Quantity) + assert unit.magnitude == 1 + conversion = 1.0 + + def _decompose(unit): + dim = unit.dimensionality + if len(dim) != 1: + raise NotImplementedError("Compound units not yet supported") # e.g. volt-metre + uq, n = list(dim.items())[0] + if n != 1: + raise NotImplementedError("Compound units not yet supported") # e.g. 
volt^2 + uq_def = uq.definition + return float(uq_def.magnitude), uq_def + + conv, unit2 = _decompose(unit) + while conv != 1: + conversion *= conv + unit = unit2 + conv, unit2 = _decompose(unit) + return list(unit.dimensionality.keys())[0].name, conversion + + +def _recompose_unit(base_unit_name, conversion): + """ + Given a base unit name and a conversion factor, return a quantities unit object + + Example: + + >>> _recompose_unit("ampere", 1e-9) + UnitCurrent('nanoampere', 0.001 * uA, 'nA') + + """ + unit_name = None + for cf in prefix_map: + # conversion may have a different float precision to the keys in + # prefix_map, so we can't just use `prefix_map[conversion]` + if abs(conversion - cf) / cf < 1e-6: + unit_name = prefix_map[cf] + base_unit_name + if unit_name is None: + raise ValueError(f"Can't handle this conversion factor: {conversion}") + + if unit_name[-1] == "s": # strip trailing 's', e.g. "volts" --> "volt" + unit_name = unit_name[:-1] + try: + return getattr(pq, unit_name) + except AttributeError: + logger.warning(f"Can't handle unit '{unit_name}'. Returning dimensionless") + return pq.dimensionless + + +class NWBIO(BaseIO): + """ + Class for "reading" experimental data from a .nwb file, and "writing" a .nwb file from Neo + """ + supported_objects = [Block, Segment, AnalogSignal, IrregularlySampledSignal, + SpikeTrain, Epoch, Event, ImageSequence] + readable_objects = supported_objects + writeable_objects = supported_objects + + has_header = False + support_lazy = True + + name = 'NeoNWB IO' + description = 'This IO reads/writes experimental data from/to an .nwb dataset' + extensions = ['nwb'] + mode = 'one-file' + + is_readable = True + is_writable = True + is_streameable = False + + def __init__(self, filename, mode='r'): + """ + Arguments: + filename : the filename + """ + if not have_pynwb: + raise Exception("Please install the pynwb package to use NWBIO") + if not have_hdmf: + raise Exception("Please install the hdmf package to use NWBIO") + BaseIO.__init__(self, filename=filename) + self.filename = filename + self.blocks_written = 0 + self.nwb_file_mode = mode + + def read_all_blocks(self, lazy=False, **kwargs): + """ + Load all blocks in the file. + """ + assert self.nwb_file_mode in ('r',) + io = pynwb.NWBHDF5IO(self.filename, mode=self.nwb_file_mode, + load_namespaces=True) # Open a file with NWBHDF5IO + self._file = io.read() + + self.global_block_metadata = {} + for annotation_name in GLOBAL_ANNOTATIONS: + value = getattr(self._file, annotation_name, None) + if value is not None: + if annotation_name in POSSIBLE_JSON_FIELDS: + value = try_json_field(value) + self.global_block_metadata[annotation_name] = value + if "session_description" in self.global_block_metadata: + self.global_block_metadata["description"] = self.global_block_metadata[ + "session_description"] + self.global_block_metadata["file_origin"] = self.filename + if "session_start_time" in self.global_block_metadata: + self.global_block_metadata["rec_datetime"] = self.global_block_metadata[ + "session_start_time"] + if "file_create_date" in self.global_block_metadata: + self.global_block_metadata["file_datetime"] = self.global_block_metadata[ + "file_create_date"] + + self._blocks = {} + self._read_acquisition_group(lazy=lazy) + self._read_stimulus_group(lazy) + self._read_units(lazy=lazy) + self._read_epochs_group(lazy) + + return list(self._blocks.values()) + + def read_block(self, lazy=False, block_index=0, **kargs): + """ + Load the first block in the file. 
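+
+        Example (illustrative usage; "my_data.nwb" is a placeholder filename):
+
+        >>> io = NWBIO("my_data.nwb", mode='r')
+        >>> block = io.read_block(lazy=True)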
+ """ + return self.read_all_blocks(lazy=lazy)[block_index] + + def _get_segment(self, block_name, segment_name): + # If we've already created a Block with the given name return it, + # otherwise create it now and store it in self._blocks. + # If we've already created a Segment in the given block, return it, + # otherwise create it now and return it. + if block_name in self._blocks: + block = self._blocks[block_name] + else: + block = Block(name=block_name, **self.global_block_metadata) + self._blocks[block_name] = block + segment = None + for seg in block.segments: + if segment_name == seg.name: + segment = seg + break + if segment is None: + segment = Segment(name=segment_name) + segment.block = block + block.segments.append(segment) + return segment + + def _read_epochs_group(self, lazy): + if self._file.epochs is not None: + try: + # NWB files created by Neo store the segment, block and epoch names as extra + # columns + segment_names = self._file.epochs.segment[:] + block_names = self._file.epochs.block[:] + epoch_names = self._file.epochs._name[:] + except AttributeError: + epoch_names = None + + if epoch_names is not None: + unique_epoch_names = np.unique(epoch_names) + for epoch_name in unique_epoch_names: + index, = np.where((epoch_names == epoch_name)) + epoch = EpochProxy(self._file.epochs, epoch_name, index) + if not lazy: + epoch = epoch.load() + segment_name = np.unique(segment_names[index]) + block_name = np.unique(block_names[index]) + assert segment_name.size == block_name.size == 1 + segment = self._get_segment(block_name[0], segment_name[0]) + segment.epochs.append(epoch) + epoch.segment = segment + else: + epoch = EpochProxy(self._file.epochs) + if not lazy: + epoch = epoch.load() + segment = self._get_segment("default", "default") + segment.epochs.append(epoch) + epoch.segment = segment + + def _read_timeseries_group(self, group_name, lazy): + group = getattr(self._file, group_name) + for timeseries in group.values(): + try: + # NWB files created by Neo store the segment and block names in the comments field + hierarchy = json.loads(timeseries.comments) + except JSONDecodeError: + # For NWB files created with other applications, we put everything in a single + # segment in a single block + # todo: investigate whether there is a reliable way to create multiple segments, + # e.g. 
using Trial information + block_name = "default" + segment_name = "default" + else: + block_name = hierarchy["block"] + segment_name = hierarchy["segment"] + segment = self._get_segment(block_name, segment_name) + if isinstance(timeseries, AnnotationSeries): + event = EventProxy(timeseries, group_name) + if not lazy: + event = event.load() + segment.events.append(event) + event.segment = segment + elif timeseries.rate: # AnalogSignal + signal = AnalogSignalProxy(timeseries, group_name) + if not lazy: + signal = signal.load() + segment.analogsignals.append(signal) + signal.segment = segment + else: # IrregularlySampledSignal + signal = AnalogSignalProxy(timeseries, group_name) + if not lazy: + signal = signal.load() + segment.irregularlysampledsignals.append(signal) + signal.segment = segment + + def _read_units(self, lazy): + if self._file.units: + for id in range(len(self._file.units)): + try: + # NWB files created by Neo store the segment and block names as extra columns + segment_name = self._file.units.segment[id] + block_name = self._file.units.block[id] + except AttributeError: + # For NWB files created with other applications, we put everything in a single + # segment in a single block + segment_name = "default" + block_name = "default" + segment = self._get_segment(block_name, segment_name) + spiketrain = SpikeTrainProxy(self._file.units, id) + if not lazy: + spiketrain = spiketrain.load() + segment.spiketrains.append(spiketrain) + spiketrain.segment = segment + + def _read_acquisition_group(self, lazy): + self._read_timeseries_group("acquisition", lazy) + + def _read_stimulus_group(self, lazy): + self._read_timeseries_group("stimulus", lazy) + + def write_all_blocks(self, blocks, **kwargs): + """ + Write list of blocks to the file + """ + # todo: allow metadata in NWBFile constructor to be taken from kwargs + annotations = defaultdict(set) + for annotation_name in GLOBAL_ANNOTATIONS: + if annotation_name in kwargs: + annotations[annotation_name] = kwargs[annotation_name] + else: + for block in blocks: + if annotation_name in block.annotations: + try: + annotations[annotation_name].add(block.annotations[annotation_name]) + except TypeError: + if annotation_name in POSSIBLE_JSON_FIELDS: + encoded = json.dumps(block.annotations[annotation_name]) + annotations[annotation_name].add(encoded) + else: + raise + if annotation_name in annotations: + if len(annotations[annotation_name]) > 1: + raise NotImplementedError( + "We don't yet support multiple values for {}".format(annotation_name)) + # take single value from set + annotations[annotation_name], = annotations[annotation_name] + if "identifier" not in annotations: + annotations["identifier"] = self.filename + if "session_description" not in annotations: + annotations["session_description"] = blocks[0].description or self.filename + # todo: concatenate descriptions of multiple blocks if different + if "session_start_time" not in annotations: + raise Exception("Writing to NWB requires an annotation 'session_start_time'") + # todo: handle subject + # todo: store additional Neo annotations somewhere in NWB file + nwbfile = NWBFile(**annotations) + + assert self.nwb_file_mode in ('w',) # possibly expand to 'a'ppend later + if self.nwb_file_mode == "w" and os.path.exists(self.filename): + os.remove(self.filename) + io_nwb = pynwb.NWBHDF5IO(self.filename, mode=self.nwb_file_mode) + + if sum(statistics(block)["SpikeTrain"]["count"] for block in blocks) > 0: + nwbfile.add_unit_column('_name', 'the name attribute of the SpikeTrain') + # 
nwbfile.add_unit_column('_description', + # 'the description attribute of the SpikeTrain') + nwbfile.add_unit_column( + 'segment', 'the name of the Neo Segment to which the SpikeTrain belongs') + nwbfile.add_unit_column( + 'block', 'the name of the Neo Block to which the SpikeTrain belongs') + + if sum(statistics(block)["Epoch"]["count"] for block in blocks) > 0: + nwbfile.add_epoch_column('_name', 'the name attribute of the Epoch') + # nwbfile.add_epoch_column('_description', 'the description attribute of the Epoch') + nwbfile.add_epoch_column( + 'segment', 'the name of the Neo Segment to which the Epoch belongs') + nwbfile.add_epoch_column('block', + 'the name of the Neo Block to which the Epoch belongs') + + for i, block in enumerate(blocks): + self.write_block(nwbfile, block) + io_nwb.write(nwbfile) + io_nwb.close() + + with pynwb.NWBHDF5IO(self.filename, "r") as io_validate: + errors = pynwb.validate(io_validate, namespace="core") + if errors: + raise Exception(f"Errors found when validating {self.filename}") + + def write_block(self, nwbfile, block, **kwargs): + """ + Write a Block to the file + :param block: Block to be written + :param nwbfile: Representation of an NWB file + """ + electrodes = self._write_electrodes(nwbfile, block) + if not block.name: + block.name = "block%d" % self.blocks_written + for i, segment in enumerate(block.segments): + assert segment.block is block + if not segment.name: + segment.name = "%s : segment%d" % (block.name, i) + self._write_segment(nwbfile, segment, electrodes) + self.blocks_written += 1 + + def _write_electrodes(self, nwbfile, block): + # this handles only icephys_electrode for now + electrodes = {} + devices = {} + for segment in block.segments: + for signal in chain(segment.analogsignals, segment.irregularlysampledsignals): + if "nwb_electrode" in signal.annotations: + elec_meta = signal.annotations["nwb_electrode"].copy() + if elec_meta["name"] not in electrodes: + # todo: check for consistency if the name is already there + if elec_meta["device"]["name"] in devices: + device = devices[elec_meta["device"]["name"]] + else: + device = nwbfile.create_device(**elec_meta["device"]) + devices[elec_meta["device"]["name"]] = device + elec_meta.pop("device") + electrodes[elec_meta["name"]] = nwbfile.create_icephys_electrode( + device=device, **elec_meta + ) + return electrodes + + def _write_segment(self, nwbfile, segment, electrodes): + # maybe use NWB trials to store Segment metadata? 
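+        # Mapping used below (mirrored by the _read_* methods above):
+        #   AnalogSignal / IrregularlySampledSignal -> TimeSeries, with the block and
+        #       segment names stored as JSON in the TimeSeries `comments` field
+        #   SpikeTrain -> a row of the NWB Units table ('segment'/'block' extra columns)
+        #   Event -> AnnotationSeries
+        #   Epoch -> rows of the NWB epochs table ('segment'/'block' extra columns)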
+ for i, signal in enumerate( + chain(segment.analogsignals, segment.irregularlysampledsignals)): + assert signal.segment is segment + if not signal.name: + signal.name = "%s : analogsignal%d" % (segment.name, i) + self._write_signal(nwbfile, signal, electrodes) + + for i, train in enumerate(segment.spiketrains): + assert train.segment is segment + if not train.name: + train.name = "%s : spiketrain%d" % (segment.name, i) + self._write_spiketrain(nwbfile, train) + + for i, event in enumerate(segment.events): + assert event.segment is segment + if not event.name: + event.name = "%s : event%d" % (segment.name, i) + self._write_event(nwbfile, event) + + for i, epoch in enumerate(segment.epochs): + if not epoch.name: + epoch.name = "%s : epoch%d" % (segment.name, i) + self._write_epoch(nwbfile, epoch) + + def _write_signal(self, nwbfile, signal, electrodes): + hierarchy = {'block': signal.segment.block.name, 'segment': signal.segment.name} + if "nwb_neurodata_type" in signal.annotations: + timeseries_class = get_class(*signal.annotations["nwb_neurodata_type"]) + else: + timeseries_class = TimeSeries # default + additional_metadata = {name[4:]: value + for name, value in signal.annotations.items() + if name.startswith("nwb:")} + if "nwb_electrode" in signal.annotations: + electrode_name = signal.annotations["nwb_electrode"]["name"] + additional_metadata["electrode"] = electrodes[electrode_name] + if timeseries_class != TimeSeries: + conversion, units = get_units_conversion(signal, timeseries_class) + additional_metadata["conversion"] = conversion + else: + units = signal.units + if isinstance(signal, AnalogSignal): + sampling_rate = signal.sampling_rate.rescale("Hz") + tS = timeseries_class( + name=signal.name, + starting_time=time_in_seconds(signal.t_start), + data=signal, + unit=units.dimensionality.string, + rate=float(sampling_rate), + comments=json.dumps(hierarchy), + **additional_metadata) + # todo: try to add array_annotations via "control" attribute + elif isinstance(signal, IrregularlySampledSignal): + tS = timeseries_class( + name=signal.name, + data=signal, + unit=units.dimensionality.string, + timestamps=signal.times.rescale('second').magnitude, + comments=json.dumps(hierarchy), + **additional_metadata) + else: + raise TypeError( + "signal has type {0}, should be AnalogSignal or IrregularlySampledSignal".format( + signal.__class__.__name__)) + nwb_group = signal.annotations.get("nwb_group", "acquisition") + add_method_map = { + "acquisition": nwbfile.add_acquisition, + "stimulus": nwbfile.add_stimulus + } + if nwb_group in add_method_map: + add_time_series = add_method_map[nwb_group] + else: + raise NotImplementedError("NWB group '{}' not yet supported".format(nwb_group)) + add_time_series(tS) + return tS + + def _write_spiketrain(self, nwbfile, spiketrain): + nwbfile.add_unit(spike_times=spiketrain.rescale('s').magnitude, + obs_intervals=[[float(spiketrain.t_start.rescale('s')), + float(spiketrain.t_stop.rescale('s'))]], + _name=spiketrain.name, + # _description=spiketrain.description, + segment=spiketrain.segment.name, + block=spiketrain.segment.block.name) + # todo: handle annotations (using add_unit_column()?) 
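+        #   (hypothetical sketch, not implemented: declare one column per annotation key
+        #    with nwbfile.add_unit_column(key, ...) before the first add_unit() call, then
+        #    pass the annotation values as extra keyword arguments to add_unit(), as is
+        #    already done for '_name', 'segment' and 'block' above)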
+ # todo: handle Neo Units + # todo: handle spike waveforms, if any (see SpikeEventSeries) + return nwbfile.units + + def _write_event(self, nwbfile, event): + hierarchy = {'block': event.segment.block.name, 'segment': event.segment.name} + tS_evt = AnnotationSeries( + name=event.name, + data=event.labels, + timestamps=event.times.rescale('second').magnitude, + description=event.description or "", + comments=json.dumps(hierarchy)) + nwbfile.add_acquisition(tS_evt) + return tS_evt + + def _write_epoch(self, nwbfile, epoch): + for t_start, duration, label in zip(epoch.rescale('s').magnitude, + epoch.durations.rescale('s').magnitude, + epoch.labels): + nwbfile.add_epoch(t_start, t_start + duration, [label], [], + _name=epoch.name, + segment=epoch.segment.name, + block=epoch.segment.block.name) + return nwbfile.epochs + + +class AnalogSignalProxy(BaseAnalogSignalProxy): + common_metadata_fields = ( + # fields that are the same for all TimeSeries subclasses + "comments", "description", "unit", "starting_time", "timestamps", "rate", + "data", "starting_time_unit", "timestamps_unit", "electrode" + ) + + def __init__(self, timeseries, nwb_group): + self._timeseries = timeseries + self.units = timeseries.unit + if timeseries.conversion: + self.units = _recompose_unit(timeseries.unit, timeseries.conversion) + if timeseries.starting_time is not None: + self.t_start = timeseries.starting_time * pq.s + else: + self.t_start = timeseries.timestamps[0] * pq.s + if timeseries.rate: + self.sampling_rate = timeseries.rate * pq.Hz + else: + self.sampling_rate = None + self.name = timeseries.name + self.annotations = {"nwb_group": nwb_group} + self.description = try_json_field(timeseries.description) + if isinstance(self.description, dict): + self.annotations["notes"] = self.description + if "name" in self.annotations: + self.annotations.pop("name") + self.description = None + self.shape = self._timeseries.data.shape + if len(self.shape) == 1: + self.shape = (self.shape[0], 1) + metadata_fields = list(timeseries.__nwbfields__) + for field_name in self.__class__.common_metadata_fields: # already handled + try: + metadata_fields.remove(field_name) + except ValueError: + pass + for field_name in metadata_fields: + value = getattr(timeseries, field_name) + if value is not None: + self.annotations[f"nwb:{field_name}"] = value + self.annotations["nwb_neurodata_type"] = ( + timeseries.__class__.__module__, + timeseries.__class__.__name__ + ) + if hasattr(timeseries, "electrode"): + # todo: once the Group class is available, we could add electrode metadata + # to a Group containing all signals that share that electrode + # This would reduce the amount of redundancy (repeated metadata in every signal) + electrode_metadata = {"device": {}} + metadata_fields = list(timeseries.electrode.__class__.__nwbfields__) + ["name"] + metadata_fields.remove("device") # needs special handling + for field_name in metadata_fields: + value = getattr(timeseries.electrode, field_name) + if value is not None: + electrode_metadata[field_name] = value + for field_name in timeseries.electrode.device.__class__.__nwbfields__: + value = getattr(timeseries.electrode.device, field_name) + if value is not None: + electrode_metadata["device"][field_name] = value + self.annotations["nwb_electrode"] = electrode_metadata + + def load(self, time_slice=None, strict_slicing=True): + """ + Load AnalogSignalProxy args: + :param time_slice: None or tuple of the time slice expressed with quantities. + None is the entire signal. 
+ :param strict_slicing: True by default. + Control if an error is raised or not when one of the time_slice members + (t_start or t_stop) is outside the real time range of the segment. + """ + i_start, i_stop, sig_t_start = None, None, self.t_start + if time_slice: + if self.sampling_rate is None: + i_start, i_stop = np.searchsorted(self._timeseries.timestamps, time_slice) + else: + i_start, i_stop, sig_t_start = self._time_slice_indices( + time_slice, strict_slicing=strict_slicing) + signal = self._timeseries.data[i_start: i_stop] + if self.sampling_rate is None: + return IrregularlySampledSignal( + self._timeseries.timestamps[i_start:i_stop] * pq.s, + signal, + units=self.units, + t_start=sig_t_start, + sampling_rate=self.sampling_rate, + name=self.name, + description=self.description, + array_annotations=None, + **self.annotations) # todo: timeseries.control / control_description + + else: + return AnalogSignal( + signal, + units=self.units, + t_start=sig_t_start, + sampling_rate=self.sampling_rate, + name=self.name, + description=self.description, + array_annotations=None, + **self.annotations) # todo: timeseries.control / control_description + + +class EventProxy(BaseEventProxy): + + def __init__(self, timeseries, nwb_group): + self._timeseries = timeseries + self.name = timeseries.name + self.annotations = {"nwb_group": nwb_group} + self.description = try_json_field(timeseries.description) + if isinstance(self.description, dict): + self.annotations.update(self.description) + self.description = None + self.shape = self._timeseries.data.shape + + def load(self, time_slice=None, strict_slicing=True): + """ + Load EventProxy args: + :param time_slice: None or tuple of the time slice expressed with quantities. + None is the entire signal. + :param strict_slicing: True by default. + Control if an error is raised or not when one of the time_slice members + (t_start or t_stop) is outside the real time range of the segment. + """ + if time_slice: + raise NotImplementedError("todo") + else: + times = self._timeseries.timestamps[:] + labels = self._timeseries.data[:] + return Event(times * pq.s, + labels=labels, + name=self.name, + description=self.description, + **self.annotations) + + +class EpochProxy(BaseEpochProxy): + + def __init__(self, time_intervals, epoch_name=None, index=None): + """ + :param time_intervals: An epochs table, + which is a specific TimeIntervals table that stores info about long periods + :param epoch_name: (str) + Name of the epoch object + :param index: (np.array, slice) + Slice object or array of bool values masking time_intervals to be used. In case of + an array it has to have the same shape as `time_intervals`. + """ + self._time_intervals = time_intervals + if index is not None: + self._index = index + self.shape = (index.sum(),) + else: + self._index = slice(None) + self.shape = (len(time_intervals),) + self.name = epoch_name + + def load(self, time_slice=None, strict_slicing=True): + """ + Load EpochProxy args: + :param time_slice: None or tuple of the time slice expressed with quantities. + None is all of the intervals. + :param strict_slicing: True by default. + Control if an error is raised or not when one of the time_slice members + (t_start or t_stop) is outside the real time range of the segment. 
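+
+        Example (illustrative; `epochs_table` and `index` are assumed to follow the
+        constructor arguments documented above):
+
+        >>> proxy = EpochProxy(epochs_table, "epoch1", index)
+        >>> epoch = proxy.load()   # returns a neo.core.Epoch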
+ """ + if time_slice: + raise NotImplementedError("todo") + else: + start_times = self._time_intervals.start_time[self._index] + stop_times = self._time_intervals.stop_time[self._index] + durations = stop_times - start_times + labels = self._time_intervals.tags[self._index] + + return Epoch(times=start_times * pq.s, + durations=durations * pq.s, + labels=labels, + name=self.name) + + +class SpikeTrainProxy(BaseSpikeTrainProxy): + + def __init__(self, units_table, id): + """ + :param units_table: A Units table + (see https://pynwb.readthedocs.io/en/stable/pynwb.misc.html#pynwb.misc.Units) + :param id: the cell/unit ID (integer) + """ + self._units_table = units_table + self.id = id + self.units = pq.s + obs_intervals = units_table.get_unit_obs_intervals(id) + if len(obs_intervals) == 0: + t_start, t_stop = None, None + elif len(obs_intervals) == 1: + t_start, t_stop = obs_intervals[0] + else: + raise NotImplementedError("Can't yet handle multiple observation intervals") + self.t_start = t_start * pq.s + self.t_stop = t_stop * pq.s + self.annotations = {"nwb_group": "acquisition"} + try: + # NWB files created by Neo store the name as an extra column + self.name = units_table._name[id] + except AttributeError: + self.name = None + self.shape = None # no way to get this without reading the data + + def load(self, time_slice=None, strict_slicing=True): + """ + Load SpikeTrainProxy args: + :param time_slice: None or tuple of the time slice expressed with quantities. + None is the entire spike train. + :param strict_slicing: True by default. + Control if an error is raised or not when one of the time_slice members + (t_start or t_stop) is outside the real time range of the segment. + """ + interval = None + if time_slice: + interval = (float(t) for t in time_slice) # convert from quantities + spike_times = self._units_table.get_unit_spike_times(self.id, in_interval=interval) + return SpikeTrain( + spike_times * self.units, + self.t_stop, + units=self.units, + # sampling_rate=array(1.) * Hz, + t_start=self.t_start, + # waveforms=None, + # left_sweep=None, + name=self.name, + # file_origin=None, + # description=None, + # array_annotations=None, + **self.annotations) diff --git a/neo/io/proxyobjects.py b/neo/io/proxyobjects.py index fd050abf5..325f1cd54 100644 --- a/neo/io/proxyobjects.py +++ b/neo/io/proxyobjects.py @@ -166,6 +166,44 @@ def t_stop(self): '''Time when signal ends''' return self.t_start + self.duration + def _time_slice_indices(self, time_slice, strict_slicing=True): + """ + Calculate the start and end indices for the slice. 
+ + Also returns t_start + """ + if time_slice is None: + i_start, i_stop = None, None + sig_t_start = self.t_start + else: + sr = self.sampling_rate + t_start, t_stop = time_slice + if t_start is None: + i_start = None + sig_t_start = self.t_start + else: + t_start = ensure_second(t_start) + if strict_slicing: + assert self.t_start <= t_start <= self.t_stop, 't_start is outside' + else: + t_start = max(t_start, self.t_start) + # the i_start is necessary ceil + i_start = int(np.ceil((t_start - self.t_start).magnitude * sr.magnitude)) + # this needed to get the real t_start of the first sample + # because do not necessary match what is demanded + sig_t_start = self.t_start + i_start / sr + + if t_stop is None: + i_stop = None + else: + t_stop = ensure_second(t_stop) + if strict_slicing: + assert self.t_start <= t_stop <= self.t_stop, 't_stop is outside' + else: + t_stop = min(t_stop, self.t_stop) + i_stop = int((t_stop - self.t_start).magnitude * sr.magnitude) + return i_start, i_stop, sig_t_start + def load(self, time_slice=None, strict_slicing=True, channel_indexes=None, magnitude_mode='rescaled'): ''' @@ -210,37 +248,8 @@ def load(self, time_slice=None, strict_slicing=True, else: fixed_chan_indexes = self._inner_stream_channels[channel_indexes] - sr = self.sampling_rate - - if time_slice is None: - i_start, i_stop = None, None - sig_t_start = self.t_start - else: - t_start, t_stop = time_slice - if t_start is None: - i_start = None - sig_t_start = self.t_start - else: - t_start = ensure_second(t_start) - if strict_slicing: - assert self.t_start <= t_start <= self.t_stop, 't_start is outside' - else: - t_start = max(t_start, self.t_start) - # the i_start is ncessary ceil - i_start = int(np.ceil((t_start - self.t_start).magnitude * sr.magnitude)) - # this needed to get the real t_start of the first sample - # because do not necessary match what is demanded - sig_t_start = self.t_start + i_start / sr - - if t_stop is None: - i_stop = None - else: - t_stop = ensure_second(t_stop) - if strict_slicing: - assert self.t_start <= t_stop <= self.t_stop, 't_stop is outside' - else: - t_stop = min(t_stop, self.t_stop) - i_stop = int((t_stop - self.t_start).magnitude * sr.magnitude) + i_start, i_stop, sig_t_start = self._time_slice_indices(time_slice, + strict_slicing=strict_slicing) raw_signal = self._rawio.get_analogsignal_chunk(block_index=self._block_index, seg_index=self._seg_index, i_start=i_start, i_stop=i_stop, diff --git a/neo/test/iotest/test_nwbio.py b/neo/test/iotest/test_nwbio.py new file mode 100644 index 000000000..5e8bc1f2c --- /dev/null +++ b/neo/test/iotest/test_nwbio.py @@ -0,0 +1,254 @@ +# +""" +Tests of neo.io.nwbio +""" + +from __future__ import unicode_literals, print_function, division, absolute_import + +import os +import unittest +from datetime import datetime + +try: + from urllib.request import urlretrieve +except ImportError: + from urllib import urlretrieve +from neo.test.iotest.common_io_test import BaseTestIO +from neo.core import AnalogSignal, SpikeTrain, Event, Epoch, IrregularlySampledSignal, Segment, \ + Block + +try: + import pynwb + from neo.io.nwbio import NWBIO + + HAVE_PYNWB = True +except (ImportError, SyntaxError): + NWBIO = None + HAVE_PYNWB = False +import quantities as pq +import numpy as np +from numpy.testing import assert_array_equal, assert_allclose + + +@unittest.skipUnless(HAVE_PYNWB, "requires pynwb") +class TestNWBIO(BaseTestIO, unittest.TestCase): + ioclass = NWBIO + entities_to_download = ["nwb"] + entities_to_test = [ + # Files from Allen 
Institute: + "nwb/H19.29.141.11.21.01.nwb", # 7 MB + ] + + def test_roundtrip(self): + + annotations = { + "session_start_time": datetime.now() + } + # Define Neo blocks + bl0 = Block(name='First block', **annotations) + bl1 = Block(name='Second block', **annotations) + bl2 = Block(name='Third block', **annotations) + original_blocks = [bl0, bl1, bl2] + + num_seg = 4 # number of segments + num_chan = 3 # number of channels + + for blk in original_blocks: + + for ind in range(num_seg): # number of Segments + seg = Segment(index=ind) + seg.block = blk + blk.segments.append(seg) + + for seg in blk.segments: # AnalogSignal objects + + # 3 Neo AnalogSignals + a = AnalogSignal(np.random.randn(44, num_chan) * pq.nA, + sampling_rate=10 * pq.kHz, + t_start=50 * pq.ms) + b = AnalogSignal(np.random.randn(64, num_chan) * pq.mV, + sampling_rate=8 * pq.kHz, + t_start=40 * pq.ms) + c = AnalogSignal(np.random.randn(33, num_chan) * pq.uA, + sampling_rate=10 * pq.kHz, + t_start=120 * pq.ms) + + # 2 Neo IrregularlySampledSignals + d = IrregularlySampledSignal(np.arange(7.0) * pq.ms, + np.random.randn(7, num_chan) * pq.mV) + + # 2 Neo SpikeTrains + train = SpikeTrain(times=[1, 2, 3] * pq.s, t_start=1.0, t_stop=10.0) + train2 = SpikeTrain(times=[4, 5, 6] * pq.s, t_stop=10.0) + # todo: add waveforms + + # 1 Neo Event + evt = Event(times=np.arange(0, 30, 10) * pq.ms, + labels=np.array(['ev0', 'ev1', 'ev2'])) + + # 2 Neo Epochs + epc = Epoch(times=np.arange(0, 30, 10) * pq.s, + durations=[10, 5, 7] * pq.ms, + labels=np.array(['btn0', 'btn1', 'btn2'])) + + epc2 = Epoch(times=np.arange(10, 40, 10) * pq.s, + durations=[9, 3, 8] * pq.ms, + labels=np.array(['btn3', 'btn4', 'btn5'])) + + seg.spiketrains.append(train) + seg.spiketrains.append(train2) + + seg.epochs.append(epc) + seg.epochs.append(epc2) + + seg.analogsignals.append(a) + seg.analogsignals.append(b) + seg.analogsignals.append(c) + seg.irregularlysampledsignals.append(d) + seg.events.append(evt) + a.segment = seg + b.segment = seg + c.segment = seg + d.segment = seg + evt.segment = seg + train.segment = seg + train2.segment = seg + epc.segment = seg + epc2.segment = seg + + # write to file + test_file_name = "test_round_trip.nwb" + iow = NWBIO(filename=test_file_name, mode='w') + iow.write_all_blocks(original_blocks) + + ior = NWBIO(filename=test_file_name, mode='r') + retrieved_blocks = ior.read_all_blocks() + + self.assertEqual(len(retrieved_blocks), 3) + self.assertEqual(len(retrieved_blocks[2].segments), num_seg) + + original_signal_22b = original_blocks[2].segments[2].analogsignals[1] + retrieved_signal_22b = retrieved_blocks[2].segments[2].analogsignals[1] + for attr_name in ("name", "units", "sampling_rate", "t_start"): + retrieved_attribute = getattr(retrieved_signal_22b, attr_name) + original_attribute = getattr(original_signal_22b, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_signal_22b.magnitude, original_signal_22b.magnitude) + + original_issignal_22d = original_blocks[2].segments[2].irregularlysampledsignals[0] + retrieved_issignal_22d = retrieved_blocks[2].segments[2].irregularlysampledsignals[0] + for attr_name in ("name", "units", "t_start"): + retrieved_attribute = getattr(retrieved_issignal_22d, attr_name) + original_attribute = getattr(original_issignal_22d, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_issignal_22d.times.rescale('ms').magnitude, + original_issignal_22d.times.rescale('ms').magnitude) + 
assert_array_equal(retrieved_issignal_22d.magnitude, original_issignal_22d.magnitude) + + original_event_11 = original_blocks[1].segments[1].events[0] + retrieved_event_11 = retrieved_blocks[1].segments[1].events[0] + for attr_name in ("name",): + retrieved_attribute = getattr(retrieved_event_11, attr_name) + original_attribute = getattr(original_event_11, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_event_11.rescale('ms').magnitude, + original_event_11.rescale('ms').magnitude) + assert_array_equal(retrieved_event_11.labels, original_event_11.labels) + + original_spiketrain_131 = original_blocks[1].segments[1].spiketrains[1] + retrieved_spiketrain_131 = retrieved_blocks[1].segments[1].spiketrains[1] + for attr_name in ("name", "t_start", "t_stop"): + retrieved_attribute = getattr(retrieved_spiketrain_131, attr_name) + original_attribute = getattr(original_spiketrain_131, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_spiketrain_131.times.rescale('ms').magnitude, + original_spiketrain_131.times.rescale('ms').magnitude) + + original_epoch_11 = original_blocks[1].segments[1].epochs[0] + retrieved_epoch_11 = retrieved_blocks[1].segments[1].epochs[0] + for attr_name in ("name",): + retrieved_attribute = getattr(retrieved_epoch_11, attr_name) + original_attribute = getattr(original_epoch_11, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_epoch_11.rescale('ms').magnitude, + original_epoch_11.rescale('ms').magnitude) + assert_allclose(retrieved_epoch_11.durations.rescale('ms').magnitude, + original_epoch_11.durations.rescale('ms').magnitude) + assert_array_equal(retrieved_epoch_11.labels, original_epoch_11.labels) + os.remove(test_file_name) + + def test_roundtrip_with_annotations(self): + # test with NWB-specific annotations + + original_block = Block(name="experiment", session_start_time=datetime.now()) + segment = Segment(name="session 1") + original_block.segments.append(segment) + segment.block = original_block + + electrode_annotations = { + "name": "electrode #1", + "description": "intracellular electrode", + "device": { + "name": "electrode #1" + } + } + stimulus_annotations = { + "nwb_group": "stimulus", + "nwb_neurodata_type": ("pynwb.icephys", "CurrentClampStimulusSeries"), + "nwb_electrode": electrode_annotations, + "nwb:sweep_number": 1, + "nwb:gain": 1.0 + } + response_annotations = { + "nwb_group": "acquisition", + "nwb_neurodata_type": ("pynwb.icephys", "CurrentClampSeries"), + "nwb_electrode": electrode_annotations, + "nwb:sweep_number": 1, + "nwb:gain": 1.0, + "nwb:bias_current": 1e-12, + "nwb:bridge_balance": 70e6, + "nwb:capacitance_compensation": 1e-12 + } + stimulus = AnalogSignal(np.random.randn(100, 1) * pq.nA, + sampling_rate=5 * pq.kHz, + t_start=50 * pq.ms, + name="stimulus", + **stimulus_annotations) + response = AnalogSignal(np.random.randn(100, 1) * pq.mV, + sampling_rate=5 * pq.kHz, + t_start=50 * pq.ms, + name="response", + **response_annotations) + segment.analogsignals = [stimulus, response] + stimulus.segment = response.segment = segment + + test_file_name = "test_round_trip_with_annotations.nwb" + iow = NWBIO(filename=test_file_name, mode='w') + iow.write_all_blocks([original_block]) + + nwbfile = pynwb.NWBHDF5IO(test_file_name, mode="r").read() + + self.assertIsInstance(nwbfile.acquisition["response"], pynwb.icephys.CurrentClampSeries) + self.assertIsInstance(nwbfile.stimulus["stimulus"], + 
pynwb.icephys.CurrentClampStimulusSeries) + self.assertEqual(nwbfile.acquisition["response"].bridge_balance, + response_annotations["nwb:bridge_balance"]) + + ior = NWBIO(filename=test_file_name, mode='r') + retrieved_block = ior.read_all_blocks()[0] + + original_response = original_block.segments[0].filter(name="response")[0] + retrieved_response = retrieved_block.segments[0].filter(name="response")[0] + for attr_name in ("name", "units", "sampling_rate", "t_start"): + retrieved_attribute = getattr(retrieved_response, attr_name) + original_attribute = getattr(original_response, attr_name) + self.assertEqual(retrieved_attribute, original_attribute) + assert_array_equal(retrieved_response.magnitude, original_response.magnitude) + + os.remove(test_file_name) + + +if __name__ == "__main__": + if HAVE_PYNWB: + print("pynwb.__version__ = ", pynwb.__version__) + unittest.main() diff --git a/requirements_testing.txt b/requirements_testing.txt index 24536ae8b..75002ef73 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -13,3 +13,4 @@ coverage coveralls pillow sonpy +pynwb
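
A minimal usage sketch of the new NWBIO class (assumptions: "my_session.nwb" is a placeholder
path; writing requires a "session_start_time" annotation, as enforced in write_all_blocks):

    from datetime import datetime
    import numpy as np
    import quantities as pq
    from neo.core import Block, Segment, AnalogSignal
    from neo.io.nwbio import NWBIO

    # build a minimal Block; the 'session_start_time' annotation is required for writing
    block = Block(name="example", session_start_time=datetime.now())
    segment = Segment(name="trial 1")
    segment.block = block
    block.segments.append(segment)
    signal = AnalogSignal(np.random.randn(1000, 1) * pq.mV,
                          sampling_rate=10 * pq.kHz, t_start=0 * pq.ms)
    signal.segment = segment
    segment.analogsignals.append(signal)

    # write all blocks to an NWB file, then read them back
    NWBIO("my_session.nwb", mode="w").write_all_blocks([block])
    blocks = NWBIO("my_session.nwb", mode="r").read_all_blocks(lazy=False)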