174 changes: 123 additions & 51 deletions processor/reader.py
@@ -12,14 +12,15 @@ class NWBElectricalSeriesReader:
"""
Wrapper class around the NWB ElectricalSeries object.

Provides helper functions and attributes for understanding the object's underlying sample and timeseries data
Provides helper functions and attributes for understanding the object's underlying sample and timeseries data.

Timestamps are computed on-demand to avoid loading the entire array into memory.

Attributes:
electrical_series (ElectricalSeries): Raw acquired data from a NWB file
num_samples (int): Number of samples per channel
num_channels (int): Number of channels
sampling_rate (int): Sampling rate (in Hz), either given by the raw file or calculated from the given timestamp values
timestamps (np.ndarray): Timestamps (offset seconds from 0), either given by the raw file or calculated from the given sampling rate
channels (list[TimeSeriesChannel]): list of channels and their respective metadata
"""

@@ -34,64 +35,90 @@ def __init__(self, electrical_series, session_start_time):
), "Electrode channels do not align with data shape"

self._sampling_rate = None
self._timestamps = None
self._compute_sampling_rate_and_timestamps()
self._compute_sampling_rate()

assert self.num_samples == len(self.timestamps), "Differing number of sample and timestamp values"
if self.has_explicit_timestamps:
assert self.num_samples == len(
self.electrical_series.timestamps
), "Differing number of sample and timestamp value"

self._channels = None

def _compute_sampling_rate_and_timestamps(self):
"""
Sets the sampling_rate and timestamps properties on the reader object.
@property
def has_explicit_timestamps(self):
return self.electrical_series.timestamps is not None

Computes either the sampling_rate or the timestamps given the other
is provided in the NWB file.
def _compute_sampling_rate(self):
"""
Computes and stores the sampling rate.

Note: NWB specifies timestamps in seconds
Note: NWB specifies timestamps in seconds.

Note: PyNWB disallows both sampling_rate and timestamps to be set on
TimeSeries objects, but it's worth handling this case by validating the
sampling_rate against the timestamps if this case does somehow appear
sampling_rate against the timestamps if this case does somehow appear.
"""
if self.electrical_series.rate is None and self.electrical_series.timestamps is None:
raise Exception("electrical series has no defined sampling rate or timestamp values")

# if both the timestamps and rate properties are set on the electrical series
# validate that the given rate is within a 2% margin of the rate calculated
# off of the given timestamps
if self.electrical_series.rate and self.electrical_series.timestamps:
# validate sampling rate against timestamps
timestamps = self.electrical_series.timestamps
# if both the timestamps and rate properties are set on the electrical
# series validate that the given rate is within a 2% margin of the
# sampling rate calculated off of the given timestamps
if self.electrical_series.rate and self.has_explicit_timestamps:
sampling_rate = self.electrical_series.rate

inferred_sampling_rate = infer_sampling_rate(timestamps)
sample_size = min(10000, self.num_samples)
sample_timestamps = self.electrical_series.timestamps[:sample_size]
inferred_sampling_rate = infer_sampling_rate(sample_timestamps)

error = abs(inferred_sampling_rate - sampling_rate) * (1.0 / sampling_rate)
if error > 0.02:
# error is greater than 2%
raise Exception(
"Inferred rate from timestamps ({inferred_rate:.4f}) does not match given rate ({given_rate:.4f}).".format(
inferred_rate=inferred_sampling_rate, given_rate=sampling_rate
)
)
self._sampling_rate = sampling_rate

# if only the rate is given, calculate the timestamps for the samples
# using the given number of samples (size of the data)
if self.electrical_series.rate:
sampling_rate = self.electrical_series.rate
timestamps = np.linspace(0, self.num_samples / sampling_rate, self.num_samples, endpoint=False)
# if only the rate is given, timestamps will be computed on-demand
elif self.electrical_series.rate:
self._sampling_rate = self.electrical_series.rate

# if only the timestamps are given, calculate the sampling rate using the timestamps
if self.electrical_series.timestamps:
timestamps = self.electrical_series.timestamps
sampling_rate = round(infer_sampling_rate(self._timestamps))
# if only the timestamps are given, calculate the sampling rate using a sample of timestamps
elif self.has_explicit_timestamps:
sample_size = min(10000, self.num_samples)
sample_timestamps = self.electrical_series.timestamps[:sample_size]
self._sampling_rate = round(infer_sampling_rate(sample_timestamps))

self._sampling_rate = sampling_rate
self._timestamps = timestamps + self.session_start_time_secs
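The `infer_sampling_rate` helper imported by the reader is not shown in this diff; a minimal sketch of what it plausibly computes, assuming it inverts the median spacing of consecutive timestamps (the median is robust to the occasional gap):

```python
import numpy as np

def infer_sampling_rate(timestamps):
    # Hypothetical sketch of the helper imported by reader.py: invert the
    # median difference between consecutive timestamps (seconds) to get Hz.
    diffs = np.diff(np.asarray(timestamps, dtype=np.float64))
    return 1.0 / np.median(diffs)
```

Under that assumption, the 2% tolerance check above reduces to `abs(inferred - given) / given > 0.02`.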
def get_timestamp(self, index):
"""
Get timestamp for a single sample index.
Computes on-demand when timestamps are not explicitly set.
"""
timestamp = (
float(self.electrical_series.timestamps[index])
if self.has_explicit_timestamps
else (index / self._sampling_rate)
)
return timestamp + self.session_start_time_secs

@property
def timestamps(self):
return self._timestamps
def get_timestamps(self, start, end):
"""
Get timestamps for a range of indices [start, end).
Computes on-demand when timestamps are not explicitly set.
Returns a numpy array.
"""
timestamps = (
np.array(self.electrical_series.timestamps[start:end])
if self.has_explicit_timestamps
else np.linspace(
start / self._sampling_rate,
end / self._sampling_rate,
end - start,
endpoint=False,
)
)
return timestamps + self.session_start_time_secs
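For a rate-only series, `get_timestamps` is just an arithmetic sequence; a small self-contained check of the `linspace` form used above (rate and indices are illustrative):

```python
import numpy as np

rate = 250.0
start, end = 500, 504

# linspace with endpoint=False yields start/rate, (start+1)/rate, ..., (end-1)/rate
ts = np.linspace(start / rate, end / rate, end - start, endpoint=False)
assert np.allclose(ts, np.arange(start, end) / rate)  # [2.0, 2.004, 2.008, 2.012]
```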

@property
def sampling_rate(self):
@@ -101,6 +128,10 @@ def sampling_rate(self):
def channels(self):
if not self._channels:
channels = []

start_timestamp = self.get_timestamp(0)
end_timestamp = self.get_timestamp(self.num_samples - 1)

for index, electrode in enumerate(self.electrical_series.electrodes):
name = ""
if isinstance(electrode, DataFrame):
@@ -122,8 +153,8 @@ def channels(self):
index=index,
name=name,
rate=self.sampling_rate,
start=self.timestamps[0] * 1e6,  # safe access guaranteed by initialization assertions
end=self.timestamps[-1] * 1e6,
start=start_timestamp * 1e6,
end=end_timestamp * 1e6,
group=group_name,
)
)
@@ -144,28 +175,69 @@ def contiguous_chunks(self):

(timestamp_difference) > 2 * sampling_period
"""
# if no explicit timestamps, data is continuous by definition
if not self.has_explicit_timestamps:
yield 0, self.num_samples
return

# process timestamps in batches to find gaps without loading all into memory
gap_threshold = (1.0 / self.sampling_rate) * 2
batch_size = 100000

boundaries = np.concatenate(
([0], (np.diff(self.timestamps) > gap_threshold).nonzero()[0] + 1, [len(self.timestamps)])
)
boundaries = [0]
prev_timestamp = None

for batch_start in range(0, self.num_samples, batch_size):
batch_end = min(batch_start + batch_size, self.num_samples)
batch_timestamps = self.electrical_series.timestamps[batch_start:batch_end]

for i in np.arange(len(boundaries) - 1):
# check gap between batches
if prev_timestamp is not None:
if batch_timestamps[0] - prev_timestamp > gap_threshold:
boundaries.append(batch_start)

# find gaps within batch
diffs = np.diff(batch_timestamps)
gap_indices = np.where(diffs > gap_threshold)[0]
for gap_idx in gap_indices:
boundaries.append(batch_start + gap_idx + 1)

prev_timestamp = batch_timestamps[-1]

boundaries.append(self.num_samples)

for i in range(len(boundaries) - 1):
yield boundaries[i], boundaries[i + 1]
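A toy run of the gap rule on explicit timestamps, assuming a 1 kHz series with one dropped stretch (values are illustrative, not from the PR):

```python
import numpy as np

rate = 1000.0
gap_threshold = (1.0 / rate) * 2  # 2 ms, i.e. twice the sampling period

# 1 kHz timestamps with a 10 ms gap after the fourth sample
ts = np.array([0.000, 0.001, 0.002, 0.003, 0.013, 0.014, 0.015])

boundaries = [0] + list((np.diff(ts) > gap_threshold).nonzero()[0] + 1) + [len(ts)]
# boundaries == [0, 4, 7] -> contiguous segments [0, 4) and [4, 7)
```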

def get_chunk(self, channel_index, start=None, end=None):
def get_chunk(self, start=None, end=None):
"""
Returns a chunk of sample data from the electrical series
for the given channel (index)
Returns chunks of sample data across all channels in a single HDF5 read.

If start and end are not specified the entire channel's data is read into memory.
If start and end are not specified all data is read into memory.

The sample data is scaled by the conversion and offset factors
set in the electrical series.
"""
scale_factor = self.electrical_series.conversion
HDF5 is optimized for contiguous reads, so reading all channels at once
and splitting in memory is much faster than column-by-column access.

if self.electrical_series.channel_conversion:
scale_factor *= self.electrical_series.channel_conversion[channel_index]
Args:
start: Start sample index (default: 0)
end: End sample index (default: num_samples)

return self.electrical_series.data[start:end, channel_index] * scale_factor + self.electrical_series.offset
Returns:
list of numpy arrays, one per channel, with scaling applied
"""
# Single HDF5 read for all channels
all_data = self.electrical_series.data[start:end, :]

base_scale = self.electrical_series.conversion
offset = self.electrical_series.offset

# Apply per-channel scaling if present
if self.electrical_series.channel_conversion is not None:
channel_scales = np.array(self.electrical_series.channel_conversion) * base_scale
# Broadcast multiply: (samples, channels) * (channels,) -> (samples, channels)
scaled_data = all_data * channel_scales + offset
else:
scaled_data = all_data * base_scale + offset

# Split into list of per-channel arrays
return [scaled_data[:, i] for i in range(self.num_channels)]
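The per-channel scaling relies on NumPy broadcasting across the trailing axis; a minimal illustration with made-up conversion factors:

```python
import numpy as np

data = np.ones((4, 3))                              # (samples, channels)
channel_scales = np.array([1.0, 2.0, 0.5]) * 1e-6   # per-channel gain * base conversion
offset = 0.01

scaled = data * channel_scales + offset  # (4, 3) * (3,) broadcasts over rows
per_channel = [scaled[:, i] for i in range(scaled.shape[1])]
```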
63 changes: 47 additions & 16 deletions processor/writer.py
@@ -2,6 +2,7 @@
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
@@ -24,47 +25,77 @@ def __init__(self, session_start_time, output_dir, chunk_size):
self.output_dir = output_dir
self.chunk_size = chunk_size

def write_electrical_series(self, electrical_series):
def write_electrical_series(self, electrical_series, max_workers=None):
"""
Chunks the sample data in two stages:
1. Splits sample data into contiguous segments using the given or generated timestamp values
2. Chunks each contiguous segment into the given chunk_size (number of samples to include per file)

Writes each chunk to the given output directory
Writes each chunk to the given output directory.
Channel processing is parallelized using ThreadPoolExecutor. Threads share memory
(no serialization overhead) and the GIL is released during gzip compression and file I/O.

Args:
electrical_series: NWB ElectricalSeries object
max_workers: Maximum number of threads (defaults to min(32, cpu_count + 4))
"""
reader = NWBElectricalSeriesReader(electrical_series, self.session_start_time)
num_channels = len(reader.channels)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
for contiguous_start, contiguous_end in reader.contiguous_chunks():
for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size):
chunk_end = min(contiguous_end, chunk_start + self.chunk_size)

start_time = reader.get_timestamp(chunk_start)
end_time = reader.get_timestamp(chunk_end - 1)

for contiguous_start, contiguous_end in reader.contiguous_chunks():
for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size):
chunk_end = min(contiguous_end, chunk_start + self.chunk_size)
channel_chunks = reader.get_chunk(chunk_start, chunk_end)

start_time = reader.timestamps[chunk_start]
end_time = reader.timestamps[chunk_end - 1]
futures = [
executor.submit(
self.write_chunk,
channel_chunks[i],
start_time,
end_time,
reader.channels[i].index,
self.output_dir,
)
for i in range(num_channels)
]

for channel_index in range(len(reader.channels)):
chunk = reader.get_chunk(channel_index, chunk_start, chunk_end)
channel = reader.channels[channel_index]
self.write_chunk(chunk, start_time, end_time, channel)
for future in futures:
future.result()

for channel in reader.channels:
self.write_channel(channel)
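The submit/collect idiom used above generalizes to any batch of independent, I/O-bound writes; a stripped-down sketch (paths and payloads are placeholders, not from the PR):

```python
from concurrent.futures import ThreadPoolExecutor

def write_one(payload, path):
    # file I/O (and gzip compression) releases the GIL, so threads overlap here
    with open(path, "wb") as f:
        f.write(payload)

payloads = [b"a" * 1024, b"b" * 1024]
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(write_one, payload, "/tmp/chunk-{:05d}.bin".format(i))
        for i, payload in enumerate(payloads)
    ]
    for future in futures:
        future.result()  # re-raises any exception from the worker
```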

def write_chunk(self, chunk, start_time, end_time, channel):
@staticmethod
def write_chunk(chunk, start_time, end_time, channel_index, output_dir):
"""
Formats the chunked sample data into 64-bit (8 byte) values in big-endian.

Writes the chunked sample data to a gzipped binary file.

Args:
chunk: numpy array of sample data for the channel
start_time: start timestamp in seconds
end_time: end timestamp in seconds
channel_index: channel index for output filename
output_dir: directory to write chunked output file
"""
# ensure the samples are 64-bit floating-point numbers in big-endian before converting to bytes
formatted_data = to_big_endian(chunk.astype(np.float64))

channel_index = "{index:05d}".format(index=channel.index)
file_name = "channel-{}_{}_{}{}".format(
channel_index, int(start_time * 1e6), int(end_time * 1e6), TIME_SERIES_BINARY_FILE_EXTENSION
"{index:05d}".format(index=channel_index),
int(start_time * 1e6),
int(end_time * 1e6),
TIME_SERIES_BINARY_FILE_EXTENSION,
)
file_path = os.path.join(self.output_dir, file_name)
file_path = os.path.join(output_dir, file_name)

with gzip.open(file_path, mode="wb", compresslevel=1) as f:
with gzip.open(file_path, mode="wb", compresslevel=0) as f:
f.write(formatted_data)
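`to_big_endian` also comes from elsewhere in the repo; assuming it simply reinterprets the float64 samples as big-endian before serialization, a plausible one-liner:

```python
import numpy as np

def to_big_endian(arr):
    # Hypothetical sketch of the imported helper: cast to big-endian
    # float64 ('>f8') and hand back raw bytes for f.write().
    return arr.astype(np.dtype(">f8")).tobytes()
```

Note the switch from `compresslevel=1` to `compresslevel=0`: the output keeps the gzip container but skips compression entirely, trading file size for write throughput.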

def write_channel(self, channel):