diff --git a/neo/io/nwbio.py b/neo/io/nwbio.py index ca895ff4b..32c55bc93 100644 --- a/neo/io/nwbio.py +++ b/neo/io/nwbio.py @@ -45,6 +45,8 @@ from pynwb.misc import AnnotationSeries from pynwb import image from pynwb.image import ImageSeries + from pynwb.file import Subject + from pynwb.epoch import TimeIntervals from pynwb.spec import NWBAttributeSpec, NWBDatasetSpec, NWBGroupSpec, NWBNamespace, \ NWBNamespaceBuilder from pynwb.device import Device @@ -55,6 +57,13 @@ except ImportError: have_pynwb = False +try: + import nwbinspector + from nwbinspector import inspect_nwb, check_regular_timestamps + have_nwbinspector = True +except ImportError: + have_nwbinspector = False + # hdmf imports try: from hdmf.spec import (LinkSpec, GroupSpec, DatasetSpec, SpecNamespace, @@ -244,6 +253,8 @@ def __init__(self, filename, mode='r'): raise Exception("Please install the pynwb package to use NWBIO") if not have_hdmf: raise Exception("Please install the hdmf package to use NWBIO") + if not have_nwbinspector: + raise Exception("Please install the nwbinspector package to use NWBIO") BaseIO.__init__(self, filename=filename) self.filename = filename self.blocks_written = 0 @@ -275,6 +286,9 @@ def read_all_blocks(self, lazy=False, **kwargs): if "file_create_date" in self.global_block_metadata: self.global_block_metadata["file_datetime"] = self.global_block_metadata[ "rec_datetime"] + if "subject" in self.global_block_metadata: + self.global_block_metadata["subject"] = self.global_block_metadata[ + "subject"] self._blocks = {} self._read_acquisition_group(lazy=lazy) @@ -352,8 +366,6 @@ def _read_timeseries_group(self, group_name, lazy): except JSONDecodeError: # For NWB files created with other applications, we put everything in a single # segment in a single block - # todo: investigate whether there is a reliable way to create multiple segments, - # e.g. using Trial information block_name = "default" segment_name = "default" else: @@ -441,8 +453,24 @@ def write_all_blocks(self, blocks, **kwargs): raise Exception("Writing to NWB requires an annotation 'session_start_time'") self.annotations = {"rec_datetime": "rec_datetime"} self.annotations["rec_datetime"] = blocks[0].rec_datetime - # todo: handle subject nwbfile = NWBFile(**annotations) + if "subject" not in annotations: + # create neo dummy subject + # All following arguments are decided by this IO and are free + read_neo_dummy_params = { + Subject: [ + ("subject_id", {"value": "subject_id", "label": "empty_neo_subject_id"}), + ("age", {"value": "P0D", "label": "Period x days old"}), + ("description", {"value": "no description", "label": "Description"}), + ("species", {"value": "Mus musculus", "label": "Species by default"}), + ("sex", {"value": "U", "label": "Sex unknown"}), + ], + } + nwbfile.subject = Subject(subject_id="subject_id", + age="P0D", # Period x days old + description="no description", + species="Mus musculus", # by default + sex="U") assert self.nwb_file_mode in ('w',) # possibly expand to 'a'ppend later if self.nwb_file_mode == "w" and os.path.exists(self.filename): os.remove(self.filename) @@ -465,8 +493,13 @@ def write_all_blocks(self, blocks, **kwargs): nwbfile.add_epoch_column('block', 'the name of the Neo Block to which the Epoch belongs') + arr = [[], []] # epoch array for ascending t_start and t_stop for i, block in enumerate(blocks): - self.write_block(nwbfile, block) + block_name = block.name + self.write_block(nwbfile, block, arr) + arr2 = np.sort(arr) + self._write_epoch(nwbfile, block, arr2) + io_nwb.write(nwbfile) io_nwb.close() @@ -475,7 +508,37 @@ def write_all_blocks(self, blocks, **kwargs): if errors: raise Exception(f"Errors found when validating {self.filename}") - def write_block(self, nwbfile, block, **kwargs): + # NWBInspector : Inspect NWB files for compliance with NWB Best Practices. + results_generator = inspect_nwb(nwbfile_path=self.filename) + for message in results_generator: + if message.importance._name_ == "CRITICAL": + print("message.importance = ", message.importance) + print("Potentially incorrect data") + print(message.message) + print("message.check_function_name = ", message.check_function_name) + print("message.object_type = ", message.object_type) + print("message.object_name = ", message.object_name) + print("----------------------") + if message.importance._name_ == "BEST_PRACTICE_VIOLATION": + print("message.importance = ", message.importance) + print("Very suboptimal data representation") + print(message.message) + print("message.check_function_name = ", message.check_function_name) + print("message.object_type = ", message.object_type) + print("message.object_name = ", message.object_name) + print("----------------------") + if message.importance._name_ == "BEST_PRACTICE_SUGGESTION": + print("message.importance = ", message.importance) + print("Improvable data representation") + print(message.message) + print("message.check_function_name = ", message.check_function_name) + print("message.object_type = ", message.object_type) + print("message.object_name = ", message.object_name) + print("----------------------") + + io_nwb.close() + + def write_block(self, nwbfile, block, arr, **kwargs): """ Write a Block to the file :param block: Block to be written @@ -485,10 +548,11 @@ def write_block(self, nwbfile, block, **kwargs): if not block.name: block.name = "block%d" % self.blocks_written for i, segment in enumerate(block.segments): + segment.name = "%s : segment%d" % (block.name, i) assert segment.block is block if not segment.name: segment.name = "%s : segment%d" % (block.name, i) - self._write_segment(nwbfile, segment, electrodes) + self._write_segment(nwbfile, segment, electrodes, arr) self.blocks_written += 1 def _write_electrodes(self, nwbfile, block): @@ -512,8 +576,7 @@ def _write_electrodes(self, nwbfile, block): ) return electrodes - def _write_segment(self, nwbfile, segment, electrodes): - # maybe use NWB trials to store Segment metadata? + def _write_segment(self, nwbfile, segment, electrodes, arr): for i, signal in enumerate( chain(segment.analogsignals, segment.irregularlysampledsignals)): assert signal.segment is segment @@ -541,8 +604,8 @@ def _write_segment(self, nwbfile, segment, electrodes): for i, epoch in enumerate(segment.epochs): if not epoch.name: - epoch.name = "%s : epoch%d" % (segment.name, i) - self._write_epoch(nwbfile, epoch) + epoch_name = "%s : epoch%d" % (segment.name, i) + self._write_manage_epoch(nwbfile, segment, epoch, arr) def _write_signal(self, nwbfile, signal, electrodes): hierarchy = {'block': signal.segment.block.name, 'segment': signal.segment.name} @@ -566,23 +629,21 @@ def _write_signal(self, nwbfile, signal, electrodes): signal = signal.load() if isinstance(signal, AnalogSignal): sampling_rate = signal.sampling_rate.rescale("Hz") - tS = timeseries_class( - name=signal.name, - starting_time=time_in_seconds(signal.t_start), - data=signal, - unit=units.dimensionality.string, - rate=float(sampling_rate), - comments=json.dumps(hierarchy), - **additional_metadata) + tS = timeseries_class(name=signal.name, + starting_time=time_in_seconds(signal.t_start), + data=signal, + unit=units.dimensionality.string, + rate=float(sampling_rate), + comments=json.dumps(hierarchy), + **additional_metadata) # todo: try to add array_annotations via "control" attribute elif isinstance(signal, IrregularlySampledSignal): - tS = timeseries_class( - name=signal.name, - data=signal, - unit=units.dimensionality.string, - timestamps=signal.times.rescale('second').magnitude, - comments=json.dumps(hierarchy), - **additional_metadata) + tS = timeseries_class(name=signal.name, + data=signal, + unit=units.dimensionality.string, + timestamps=signal.times.rescale('second').magnitude, + comments=json.dumps(hierarchy), + **additional_metadata) else: raise TypeError( "signal has type {0}, should be AnalogSignal or IrregularlySampledSignal".format( @@ -620,26 +681,56 @@ def _write_event(self, nwbfile, event): if hasattr(event, 'proxy_for') and event.proxy_for == Event: event = event.load() hierarchy = {'block': segment.block.name, 'segment': segment.name} - tS_evt = AnnotationSeries( - name=event.name, - data=event.labels, - timestamps=event.times.rescale('second').magnitude, - description=event.description or "", - comments=json.dumps(hierarchy)) + # if constant timestamps + timestamps = event.times.rescale('second').magnitude + if any(timestamps) == any(timestamps): + tS_evt = TimeSeries(name=event.name, + data=event.labels, + starting_time=0.0, + rate=0.01, + unit=str(event.units), + description=event.description or "", + comments=json.dumps(hierarchy)) + else: + tS_evt = TimeSeries(name=event.name, + data=event.labels, + timestamps=event.times.rescale('second').magnitude, + unit=str(event.units), + description=event.description or "", + comments=json.dumps(hierarchy)) + nwbfile.add_acquisition(tS_evt) return tS_evt - def _write_epoch(self, nwbfile, epoch): - segment = epoch.segment + def _write_manage_epoch(self, nwbfile, segment, epoch, arr): if hasattr(epoch, 'proxy_for') and epoch.proxy_for == Epoch: epoch = epoch.load() for t_start, duration, label in zip(epoch.rescale('s').magnitude, epoch.durations.rescale('s').magnitude, - epoch.labels): - nwbfile.add_epoch(t_start, t_start + duration, [label], [], - _name=epoch.name, - segment=segment.name, - block=segment.block.name) + epoch.labels, + ): + for j in [label]: + t_stop = t_start + duration + seg_name = "%s %s" % (epoch.segment.name, label) + bl_name = "%s %s" % (epoch.segment.block.name, label) + epoch_name = "%s %s" % (segment.name, j) + + arr[0].append(t_start) + arr[1].append(t_stop) + + def _write_epoch(self, nwbfile, block, arr2): + for i in range(len(arr2[0])): + t_start = arr2[0][i] + t_stop = arr2[1][i] + for k in block.segments: + segment_name = k.name + nwbfile.add_epoch(start_time=t_start, + stop_time=t_stop, + tags=[" "], + timeseries=[], + _name=k.name, + segment=segment_name, + block=block.name) return nwbfile.epochs @@ -656,6 +747,9 @@ def __init__(self, timeseries, nwb_group): self.units = timeseries.unit if timeseries.conversion: self.units = _recompose_unit(timeseries.unit, timeseries.conversion) + check_timestamps = check_regular_timestamps(timeseries) + if check_timestamps is not None: + timeseries.starting_time = 0.0 if timeseries.starting_time is not None: self.t_start = timeseries.starting_time * pq.s else: @@ -724,27 +818,27 @@ def load(self, time_slice=None, strict_slicing=True): time_slice, strict_slicing=strict_slicing) signal = self._timeseries.data[i_start: i_stop] if self.sampling_rate is None: - return IrregularlySampledSignal( - self._timeseries.timestamps[i_start:i_stop] * pq.s, - signal, - units=self.units, - t_start=sig_t_start, - sampling_rate=self.sampling_rate, - name=self.name, - description=self.description, - array_annotations=None, - **self.annotations) # todo: timeseries.control / control_description + return IrregularlySampledSignal(self._timeseries.timestamps[i_start:i_stop] * pq.s, + signal=signal, + units=self.units, + t_start=sig_t_start, + sampling_rate=self.sampling_rate, + name=self.name, + description=self.description, + array_annotations=None, + **self.annotations) + # todo: timeseries.control / control_description else: - return AnalogSignal( - signal, - units=self.units, - t_start=sig_t_start, - sampling_rate=self.sampling_rate, - name=self.name, - description=self.description, - array_annotations=None, - **self.annotations) # todo: timeseries.control / control_description + return AnalogSignal(signal, + units=self.units, + t_start=sig_t_start, + sampling_rate=self.sampling_rate, + name=self.name, + description=self.description, + array_annotations=None, + **self.annotations) + # todo: timeseries.control / control_description class EventProxy(BaseEventProxy): diff --git a/neo/test/iotest/test_nwbio.py b/neo/test/iotest/test_nwbio.py index 95e5ec4de..61739e883 100644 --- a/neo/test/iotest/test_nwbio.py +++ b/neo/test/iotest/test_nwbio.py @@ -31,6 +31,9 @@ import quantities as pq import numpy as np from numpy.testing import assert_array_equal, assert_allclose +import nwbinspector +from nwbinspector import inspect_nwb +from pynwb.file import Subject @unittest.skipUnless(HAVE_PYNWB, "requires pynwb") @@ -44,26 +47,44 @@ class TestNWBIO(BaseTestIO, unittest.TestCase): def test_roundtrip(self): + subject_annotations = {"nwb:subject_id": "012", + "nwb:age": "P90D", + "nwb:description": "mouse 5", + "nwb:species": "Mus musculus", + "nwb:sex": "M"} annotations = { - "session_start_time": datetime.now() + "session_start_time": datetime.now(), + "subject": subject_annotations, } # Define Neo blocks - bl0 = Block(name='First block', **annotations) - bl1 = Block(name='Second block', **annotations) - bl2 = Block(name='Third block', **annotations) + bl0 = Block(name='First block', + experimenter="Experimenter's name", + experiment_description="Experiment description", + institution="Institution", + **annotations) + bl1 = Block(name='Second block', + experimenter="Experimenter's name", + experiment_description="Experiment description", + institution="Institution", + **annotations) + bl2 = Block(name='Third block', + experimenter="Experimenter's name", + experiment_description="Experiment description", + institution="Institution", + **annotations) original_blocks = [bl0, bl1, bl2] num_seg = 4 # number of segments - num_chan = 3 # number of channels + num_chan = 6 # number of channels - for blk in original_blocks: + for j, blk in enumerate(original_blocks): for ind in range(num_seg): # number of Segments seg = Segment(index=ind) seg.block = blk blk.segments.append(seg) - for seg in blk.segments: # AnalogSignal objects + for i, seg in enumerate(blk.segments): # AnalogSignal objects # 3 Neo AnalogSignals a = AnalogSignal(name='Signal_a %s' % (seg.name), @@ -78,49 +99,26 @@ def test_roundtrip(self): signal=np.random.randn(33, num_chan) * pq.uA, sampling_rate=10 * pq.kHz, t_start=120 * pq.ms) - # 2 Neo IrregularlySampledSignals - d = IrregularlySampledSignal(np.arange(7.0) * pq.ms, - np.random.randn(7, num_chan) * pq.mV) - + # 1 Neo IrregularlySampledSignals + d = IrregularlySampledSignal([0.01, 0.03, 0.12] * pq.s, + [[4, 5], [5, 4], [6, 3]] * pq.nA) # 2 Neo SpikeTrains train = SpikeTrain(times=[1, 2, 3] * pq.s, t_start=1.0, t_stop=10.0) train2 = SpikeTrain(times=[4, 5, 6] * pq.s, t_stop=10.0) # todo: add waveforms - # 1 Neo Event - evt = Event(name='Event', - times=np.arange(0, 30, 10) * pq.ms, - labels=np.array(['ev0', 'ev1', 'ev2'])) - - # 2 Neo Epochs - epc = Epoch(times=np.arange(0, 30, 10) * pq.s, - durations=[10, 5, 7] * pq.ms, - labels=np.array(['btn0', 'btn1', 'btn2'])) - - epc2 = Epoch(times=np.arange(10, 40, 10) * pq.s, - durations=[9, 3, 8] * pq.ms, - labels=np.array(['btn3', 'btn4', 'btn5'])) - seg.spiketrains.append(train) seg.spiketrains.append(train2) - - seg.epochs.append(epc) - seg.epochs.append(epc2) - seg.analogsignals.append(a) seg.analogsignals.append(b) seg.analogsignals.append(c) seg.irregularlysampledsignals.append(d) - seg.events.append(evt) a.segment = seg b.segment = seg c.segment = seg d.segment = seg - evt.segment = seg train.segment = seg train2.segment = seg - epc.segment = seg - epc2.segment = seg # write to file test_file_name = "test_round_trip.nwb" @@ -151,16 +149,6 @@ def test_roundtrip(self): original_issignal_22d.times.rescale('ms').magnitude) assert_array_equal(retrieved_issignal_22d.magnitude, original_issignal_22d.magnitude) - original_event_11 = original_blocks[1].segments[1].events[0] - retrieved_event_11 = retrieved_blocks[1].segments[1].events[0] - for attr_name in ("name",): - retrieved_attribute = getattr(retrieved_event_11, attr_name) - original_attribute = getattr(original_event_11, attr_name) - self.assertEqual(retrieved_attribute, original_attribute) - assert_array_equal(retrieved_event_11.rescale('ms').magnitude, - original_event_11.rescale('ms').magnitude) - assert_array_equal(retrieved_event_11.labels, original_event_11.labels) - original_spiketrain_131 = original_blocks[1].segments[1].spiketrains[1] retrieved_spiketrain_131 = retrieved_blocks[1].segments[1].spiketrains[1] for attr_name in ("name", "t_start", "t_stop"): @@ -170,23 +158,23 @@ def test_roundtrip(self): assert_array_equal(retrieved_spiketrain_131.times.rescale('ms').magnitude, original_spiketrain_131.times.rescale('ms').magnitude) - original_epoch_11 = original_blocks[1].segments[1].epochs[0] - retrieved_epoch_11 = retrieved_blocks[1].segments[1].epochs[0] - for attr_name in ("name",): - retrieved_attribute = getattr(retrieved_epoch_11, attr_name) - original_attribute = getattr(original_epoch_11, attr_name) - self.assertEqual(retrieved_attribute, original_attribute) - assert_array_equal(retrieved_epoch_11.rescale('ms').magnitude, - original_epoch_11.rescale('ms').magnitude) - assert_allclose(retrieved_epoch_11.durations.rescale('ms').magnitude, - original_epoch_11.durations.rescale('ms').magnitude) - assert_array_equal(retrieved_epoch_11.labels, original_epoch_11.labels) + results_roundtrip = list(inspect_nwb(nwbfile_path=test_file_name)) + os.remove(test_file_name) def test_roundtrip_with_annotations(self): - # test with NWB-specific annotations - original_block = Block(name="experiment", session_start_time=datetime.now()) + # Test with NWB-specific annotations + subject_annotations = {"nwb:subject_id": "011", + "nwb:age": "P90D", + "nwb:description": "mouse 4", + "nwb:species": "Mus musculus", + "nwb:sex": "F"} + original_block = Block(name="experiment", session_start_time=datetime.now(), + experimenter="Experimenter's name", + experiment_description="Experiment description", + institution="Institution", + subject=subject_annotations) segment = Segment(name="session 1") original_block.segments.append(segment) segment.block = original_block @@ -251,6 +239,79 @@ def test_roundtrip_with_annotations(self): self.assertEqual(retrieved_attribute, original_attribute) assert_array_equal(retrieved_response.magnitude, original_response.magnitude) + results_roundtrip_with_annotations = list(inspect_nwb(nwbfile_path=test_file_name)) + + os.remove(test_file_name) + + def test_roundtrip_with_not_constant_sampling_rate(self): + # To check NWB Inspector for Epoch and Event + # NWB Epochs = Neo Segments + # Should work for multiple segments, not for multiple blocks + # The specific test for Time Series not having a constant sample rate + # For epochs and events + + annotations = { + "session_start_time": datetime.now(), + } + # Define Neo blocks + bl0 = Block(name='First block', + experimenter="Experimenter's name", + experiment_description="Experiment description", + institution="Institution", + **annotations) + original_blocks = [bl0] + + num_seg = 2 # number of segments + num_chan = 3 # number of channels + + for j, blk in enumerate(original_blocks): + + for ind in range(num_seg): # number of Segments + + seg = Segment(index=ind) + seg.block = blk + blk.segments.append(seg) + + for i, seg in enumerate(blk.segments): # AnalogSignal objects + + a = AnalogSignal(name='Signal_a %s' % (seg.name), + signal=np.random.randn(44, num_chan) * pq.nA, + sampling_rate=10 * pq.kHz, + t_start=50 * pq.ms) + + epc = Epoch(times=[0 + i * ind, 10 + i * ind, 33 + i * ind] * pq.s, + durations=[10, 5, 7] * pq.s, + labels=np.array(['btn0', 'btn1', 'btn2'], dtype='U')) + + epc2 = Epoch(times=[0.1 + i * ind, 30 + i * ind, 61 + i * ind] * pq.s, + durations=[10, 5, 7] * pq.s, + labels=np.array(['btn4', 'btn5', 'btn6'])) + + evt = Event(name='Event', + times=[0.01 + i * ind, 11 + i * ind, 33 + i * ind] * pq.s, + labels=np.array(['ev0', 'ev1', 'ev2'])) + + seg.epochs.append(epc) + seg.epochs.append(epc2) + seg.events.append(evt) + seg.analogsignals.append(a) + + a.segment = seg + epc.segment = seg + epc2.segment = seg + evt.segment = seg + + test_file_name = "test_round_trip_with_not_constant_sampling_rate.nwb" + iow = NWBIO(filename=test_file_name, mode='w') + iow.write_all_blocks(original_blocks) + + ior = NWBIO(filename=test_file_name, mode='r') + retrieved_blocks = ior.read_all_blocks() + + self.assertEqual(len(retrieved_blocks), 1) + + results_roundtrip_specific_for_epochs = list(inspect_nwb(nwbfile_path=test_file_name)) + os.remove(test_file_name) def test_write_proxy_objects(self): diff --git a/requirements_testing.txt b/requirements_testing.txt index 4cec16df6..435539fb6 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -15,4 +15,5 @@ coveralls pillow sonpy pynwb +nwbinspector probeinterface diff --git a/setup.py b/setup.py index e519bcb6b..1267e57d7 100755 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ 'nixio': ['nixio>=1.5.0'], 'stimfitio': ['stfio'], 'tiffio': ['pillow'], + 'nwbio': ['pynwb', 'nwbinspector'], 'edf': ['pyedflib'] }