From 118217c4f05ad4ab5c10fb22b26f5f66f256c79d Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Mon, 15 Dec 2025 14:54:07 -0500
Subject: [PATCH 01/10] calculate timestamps on demand instead of loading them into memory

---
 processor/reader.py | 124 +++++++++++++++++++++++++++++++-------------
 processor/writer.py |   4 +-
 2 files changed, 89 insertions(+), 39 deletions(-)

diff --git a/processor/reader.py b/processor/reader.py
index 15196c2..79836d8 100644
--- a/processor/reader.py
+++ b/processor/reader.py
@@ -12,14 +12,15 @@ class NWBElectricalSeriesReader:
     """
     Wrapper class around the NWB ElectricalSeries object.
 
-    Provides helper functions and attributes for understanding the object's underlying sample and timeseries data
+    Provides helper functions and attributes for understanding the object's underlying sample and timeseries data.
+
+    Timestamps are computed on-demand to avoid loading the entire array into memory.
 
     Attributes:
         electrical_series (ElectricalSeries): Raw acquired data from a NWB file
         num_samples(int): Number of samples per-channel
         num_channels (int): Number of channels
         sampling_rate (int): Sampling rate (in Hz) either given by the raw file or calculated from given timestamp values
-        timestamps (int): Timestamps (offset seconds from 0) either given by the raw file or calculated from given sampling rate
         channels (list[TimeSeriesChannel]): list of channels and their respective metadata
     """
 
@@ -32,20 +33,14 @@ def __init__(self, electrical_series, session_start_time):
         assert len(self.electrical_series.electrodes.table) == self.num_channels, 'Electrode channels do not align with data shape'
 
         self._sampling_rate = None
-        self._timestamps = None
-        self._compute_sampling_rate_and_timestamps()
-
-        assert self.num_samples == len(self.timestamps), "Differing number of sample and timestamp value"
+        self._has_explicit_timestamps = False
+        self._compute_sampling_rate()
 
         self._channels = None
 
-    def _compute_sampling_rate_and_timestamps(self):
+    def _compute_sampling_rate(self):
         """
-        Sets the sampling_rate and timestamps properties on the reader object.
-
-        Computes either the sampling_rate or the timestamps given the other
-        is provided in the NWB file.
+        Computes and stores the sampling rate.
 
         Note: NWB specifies timestamps in seconds
 
@@ -59,35 +54,55 @@ def _compute_sampling_rate(self):
         # if both the timestamps and rate properties are set on the electrical series
         # validate that the given rate is within a 2% margin of the rate calculated
         # off of the given timestamps
-        if self.electrical_series.rate and self.electrical_series.timestamps:
-            # validate sampling rate against timestamps
-            timestamps = self.electrical_series.timestamps
+        if self.electrical_series.rate and self.electrical_series.timestamps is not None:
+            self._has_explicit_timestamps = True
             sampling_rate = self.electrical_series.rate
-            inferred_sampling_rate = infer_sampling_rate(timestamps)
 
-            error = abs(inferred_sampling_rate-sampling_rate) * (1.0 / sampling_rate)
+            # sample a small portion of timestamps to infer rate
+            sample_size = min(10000, self.num_samples)
+            sample_timestamps = self.electrical_series.timestamps[:sample_size]
+            inferred_sampling_rate = infer_sampling_rate(sample_timestamps)
+
+            error = abs(inferred_sampling_rate - sampling_rate) * (1.0 / sampling_rate)
             if error > 0.02:
-                # error is greater than 2%
                 raise Exception("Inferred rate from timestamps ({inferred_rate:.4f}) does not match given rate ({given_rate:.4f})." \
                     .format(inferred_rate=inferred_sampling_rate, given_rate=sampling_rate))
+            self._sampling_rate = sampling_rate
 
-        # if only the rate is given, calculate the timestamps for the samples
-        # using the given number of samples (size of the data)
-        if self.electrical_series.rate:
-            sampling_rate = self.electrical_series.rate
-            timestamps = np.linspace(0, self.num_samples / sampling_rate, self.num_samples, endpoint = False)
+        # if only the rate is given, timestamps will be computed on-demand
+        elif self.electrical_series.rate:
+            self._sampling_rate = self.electrical_series.rate
+            self._has_explicit_timestamps = False
 
-        # if only the timestamps are given, calculate the sampling rate using the timestamps
-        if self.electrical_series.timestamps:
-            timestamps = self.electrical_series.timestamps
-            sampling_rate = round(infer_sampling_rate(self._timestamps))
+        # if only the timestamps are given, calculate the sampling rate using a sample
+        elif self.electrical_series.timestamps is not None:
+            self._has_explicit_timestamps = True
+            sample_size = min(10000, self.num_samples)
+            sample_timestamps = self.electrical_series.timestamps[:sample_size]
+            self._sampling_rate = round(infer_sampling_rate(sample_timestamps))
 
-        self._sampling_rate = sampling_rate
-        self._timestamps = timestamps + self.session_start_time_secs
+    def get_timestamp(self, index):
+        """
+        Get timestamp for a single sample index. Computes on-demand.
+        """
+        if self._has_explicit_timestamps:
+            return float(self.electrical_series.timestamps[index]) + self.session_start_time_secs
+        else:
+            return (index / self._sampling_rate) + self.session_start_time_secs
 
-    @property
-    def timestamps(self):
-        return self._timestamps
+    def get_timestamps(self, start, end):
+        """
+        Get timestamps for a range of indices [start, end). Returns a numpy array.
+        """
+        if self._has_explicit_timestamps:
+            return np.array(self.electrical_series.timestamps[start:end]) + self.session_start_time_secs
+        else:
+            return np.linspace(
+                start / self._sampling_rate,
+                end / self._sampling_rate,
+                end - start,
+                endpoint=False
+            ) + self.session_start_time_secs
 
     @property
     def sampling_rate(self):
@@ -97,6 +112,11 @@ def sampling_rate(self):
     def channels(self):
         if not self._channels:
             channels = list()
+
+            # compute start/end timestamps on-demand
+            start_timestamp = self.get_timestamp(0)
+            end_timestamp = self.get_timestamp(self.num_samples - 1)
+
             for index, electrode in enumerate(self.electrical_series.electrodes):
                 name = ""
                 if isinstance(electrode, DataFrame):
@@ -118,8 +138,8 @@
                         index = index,
                         name = name,
                         rate = self.sampling_rate,
-                        start = self.timestamps[0] * 1e6 , # safe access gaurenteed by initialization assertions
-                        end = self.timestamps[-1] * 1e6,
+                        start = start_timestamp * 1e6,
+                        end = end_timestamp * 1e6,
                         group = group_name
                     )
                 )
@@ -139,13 +159,43 @@ def contiguous_chunks(self):
         sampling_period = 1 / sampling_rate
 
         (timestamp_difference) > 2 * sampling_period
+
+        For data with only a sampling rate (no explicit timestamps), the entire
+        dataset is treated as one contiguous chunk.
""" + # if no explicit timestamps, data is continuous by definition + if not self._has_explicit_timestamps: + yield 0, self.num_samples + return + + # process timestamps in batches to find gaps without loading all into memory gap_threshold = (1.0 / self.sampling_rate) * 2 + batch_size = 100000 + + boundaries = [0] + prev_timestamp = None + + for batch_start in range(0, self.num_samples, batch_size): + batch_end = min(batch_start + batch_size, self.num_samples) + batch_timestamps = self.electrical_series.timestamps[batch_start:batch_end] + + # check gap between batches + if prev_timestamp is not None: + if batch_timestamps[0] - prev_timestamp > gap_threshold: + boundaries.append(batch_start) + + # find gaps within batch + diffs = np.diff(batch_timestamps) + gap_indices = np.where(diffs > gap_threshold)[0] + for gap_idx in gap_indices: + boundaries.append(batch_start + gap_idx + 1) + + prev_timestamp = batch_timestamps[-1] - boundaries = np.concatenate( - ([0], (np.diff(self.timestamps) > gap_threshold).nonzero()[0] + 1, [len(self.timestamps)])) + boundaries.append(self.num_samples) - for i in np.arange(len(boundaries)-1): + # yield contiguous ranges + for i in range(len(boundaries) - 1): yield boundaries[i], boundaries[i + 1] def get_chunk(self, channel_index, start = None, end = None): diff --git a/processor/writer.py b/processor/writer.py index a6c578c..47c04a1 100644 --- a/processor/writer.py +++ b/processor/writer.py @@ -37,8 +37,8 @@ def write_electrical_series(self, electrical_series): for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size): chunk_end = min(contiguous_end, chunk_start + self.chunk_size) - start_time = reader.timestamps[chunk_start] - end_time = reader.timestamps[chunk_end-1] + start_time = reader.get_timestamp(chunk_start) + end_time = reader.get_timestamp(chunk_end - 1) for channel_index in range(len(reader.channels)): chunk = reader.get_chunk(channel_index, chunk_start, chunk_end) From 902cb9431497c69cd86cce563fc8884526505d6b Mon Sep 17 00:00:00 2001 From: Rohan Shah Date: Mon, 15 Dec 2025 15:17:19 -0500 Subject: [PATCH 02/10] cleanup --- processor/reader.py | 48 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/processor/reader.py b/processor/reader.py index 79836d8..7ba41b1 100644 --- a/processor/reader.py +++ b/processor/reader.py @@ -42,23 +42,22 @@ def _compute_sampling_rate(self): """ Computes and stores the sampling rate. - Note: NWB specifies timestamps in seconds + Note: NWB specifies timestamps in seconds. Note: PyNWB disallows both sampling_rate and timestamps to be set on TimeSeries objects but its worth handling this case by validating the - sampling_rate against the timestamps if this case does somehow appear + sampling_rate against the timestamps if this case does somehow appear. 
""" if self.electrical_series.rate is None and self.electrical_series.timestamps is None: raise Exception("electrical series has no defined sampling rate or timestamp values") - # if both the timestamps and rate properties are set on the electrical series - # validate that the given rate is within a 2% margin of the rate calculated - # off of the given timestamps + # if both the timestamps and rate properties are set on the electrical + # series validate that the given rate is within a 2% margin of the + # sampling rate calculated off of the given timestamps if self.electrical_series.rate and self.electrical_series.timestamps is not None: self._has_explicit_timestamps = True sampling_rate = self.electrical_series.rate - # sample a small portion of timestamps to infer rate sample_size = min(10000, self.num_samples) sample_timestamps = self.electrical_series.timestamps[:sample_size] inferred_sampling_rate = infer_sampling_rate(sample_timestamps) @@ -74,7 +73,7 @@ def _compute_sampling_rate(self): self._sampling_rate = self.electrical_series.rate self._has_explicit_timestamps = False - # if only the timestamps are given, calculate the sampling rate using a sample + # if only the timestamps are given, calculate the sampling rate using a sample of timestamps elif self.electrical_series.timestamps is not None: self._has_explicit_timestamps = True sample_size = min(10000, self.num_samples) @@ -83,26 +82,32 @@ def _compute_sampling_rate(self): def get_timestamp(self, index): """ - Get timestamp for a single sample index. Computes on-demand. + Get timestamp for a single sample index. + Computes on-demand when timestamps are not explicitly set. """ - if self._has_explicit_timestamps: - return float(self.electrical_series.timestamps[index]) + self.session_start_time_secs - else: - return (index / self._sampling_rate) + self.session_start_time_secs + timestamp = ( + float(self.electrical_series.timestamps[index]) + if self._has_explicit_timestamps + else (index / self._sampling_rate) + return timestamp + self.session_start_time_secs def get_timestamps(self, start, end): """ - Get timestamps for a range of indices [start, end). Returns a numpy array. + Get timestamps for a range of indices [start, end). + Computes on-demand when timestamps are not explicitly set. + Returns a numpy array. """ - if self._has_explicit_timestamps: - return np.array(self.electrical_series.timestamps[start:end]) + self.session_start_time_secs - else: - return np.linspace( + timestamps = ( + np.array(self.electrical_series.timestamps[start:end]) + if self._has_explicit_timestamps + else np.linspace( start / self._sampling_rate, end / self._sampling_rate, end - start, - endpoint=False - ) + self.session_start_time_secs + endpoint=False, + ) + ) + return timestamps + self.session_start_time_secs @property def sampling_rate(self): @@ -113,7 +118,6 @@ def channels(self): if not self._channels: channels = list() - # compute start/end timestamps on-demand start_timestamp = self.get_timestamp(0) end_timestamp = self.get_timestamp(self.num_samples - 1) @@ -159,9 +163,6 @@ def contiguous_chunks(self): sampling_period = 1 / sampling_rate (timestamp_difference) > 2 * sampling_period - - For data with only a sampling rate (no explicit timestamps), the entire - dataset is treated as one contiguous chunk. 
""" # if no explicit timestamps, data is continuous by definition if not self._has_explicit_timestamps: @@ -194,7 +195,6 @@ def contiguous_chunks(self): boundaries.append(self.num_samples) - # yield contiguous ranges for i in range(len(boundaries) - 1): yield boundaries[i], boundaries[i + 1] From a420ae403041ffb8ccbf07c4334bc8820b41ebc9 Mon Sep 17 00:00:00 2001 From: Rohan Shah Date: Tue, 16 Dec 2025 14:13:43 -0500 Subject: [PATCH 03/10] fix missing parenthesis --- processor/reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/processor/reader.py b/processor/reader.py index 7ba41b1..e2d3b4f 100644 --- a/processor/reader.py +++ b/processor/reader.py @@ -89,6 +89,7 @@ def get_timestamp(self, index): float(self.electrical_series.timestamps[index]) if self._has_explicit_timestamps else (index / self._sampling_rate) + ) return timestamp + self.session_start_time_secs def get_timestamps(self, start, end): From 5ce3a983d4a5635b2150925692e5fdd5088ac2cb Mon Sep 17 00:00:00 2001 From: Rohan Shah Date: Tue, 16 Dec 2025 16:16:05 -0500 Subject: [PATCH 04/10] parallelize chucked writer --- processor/writer.py | 70 +++++++++++++++++----- tests/test_writer.py | 134 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 188 insertions(+), 16 deletions(-) diff --git a/processor/writer.py b/processor/writer.py index 769bb9f..e1801f1 100644 --- a/processor/writer.py +++ b/processor/writer.py @@ -2,6 +2,7 @@ import json import logging import os +from concurrent.futures import ProcessPoolExecutor import numpy as np from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION @@ -11,6 +12,29 @@ log = logging.getLogger() +def _write_channel_chunk_worker(args): + """ + Worker function for parallel channel chunk processing. + Must be a top-level function to be picklable for ProcessPoolExecutor. + + Args: + args: Tuple of (chunk_data, start_time, end_time, channel_index, output_dir) + """ + chunk_data, start_time, end_time, channel_index, output_dir = args + + # Convert to big-endian format + formatted_data = to_big_endian(chunk_data.astype(np.float64)) + + channel_index_str = "{index:05d}".format(index=channel_index) + file_name = "channel-{}_{}_{}{}".format( + channel_index_str, int(start_time * 1e6), int(end_time * 1e6), TIME_SERIES_BINARY_FILE_EXTENSION + ) + file_path = os.path.join(output_dir, file_name) + + with gzip.open(file_path, mode="wb", compresslevel=1) as f: + f.write(formatted_data) + + class TimeSeriesChunkWriter: """ Attributes: @@ -24,27 +48,45 @@ def __init__(self, session_start_time, output_dir, chunk_size): self.output_dir = output_dir self.chunk_size = chunk_size - def write_electrical_series(self, electrical_series): + def write_electrical_series(self, electrical_series, max_workers=None): """ Chunks the sample data in two stages: 1. Splits sample data into contiguous segments using the given or generated timestamp values 2. Chunks each contiguous segment into the given chunk_size (number of samples to include per file) - Writes each chunk to the given output directory + Writes each chunk to the given output directory. + Channel processing is parallelized using ProcessPoolExecutor for improved performance + with datasets containing many channels (64-384 typical in neuroscience). 
+
+        Args:
+            electrical_series: NWB ElectricalSeries object
+            max_workers: Maximum number of worker processes (defaults to CPU count)
         """
         reader = NWBElectricalSeriesReader(electrical_series, self.session_start_time)
-
-        for contiguous_start, contiguous_end in reader.contiguous_chunks():
-            for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size):
-                chunk_end = min(contiguous_end, chunk_start + self.chunk_size)
-
-                start_time = reader.get_timestamp(chunk_start)
-                end_time = reader.get_timestamp(chunk_end - 1)
-
-                for channel_index in range(len(reader.channels)):
-                    chunk = reader.get_chunk(channel_index, chunk_start, chunk_end)
-                    channel = reader.channels[channel_index]
-                    self.write_chunk(chunk, start_time, end_time, channel)
+        num_channels = len(reader.channels)
+
+        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+            for contiguous_start, contiguous_end in reader.contiguous_chunks():
+                for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size):
+                    chunk_end = min(contiguous_end, chunk_start + self.chunk_size)
+
+                    start_time = reader.get_timestamp(chunk_start)
+                    end_time = reader.get_timestamp(chunk_end - 1)
+
+                    # Read all channel data for this chunk at once from HDF5
+                    # (HDF5 doesn't support efficient concurrent reads)
+                    channel_chunks = [
+                        reader.get_chunk(channel_index, chunk_start, chunk_end) for channel_index in range(num_channels)
+                    ]
+
+                    # Prepare arguments for parallel processing
+                    worker_args = [
+                        (channel_chunks[i], start_time, end_time, reader.channels[i].index, self.output_dir)
+                        for i in range(num_channels)
+                    ]
+
+                    # Process all channels in parallel
+                    list(executor.map(_write_channel_chunk_worker, worker_args))
 
         for channel in reader.channels:
             self.write_channel(channel)
diff --git a/tests/test_writer.py b/tests/test_writer.py
index 4ffabdb..b09f252 100644
--- a/tests/test_writer.py
+++ b/tests/test_writer.py
@@ -6,7 +6,7 @@
 import numpy as np
 from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
 from timeseries_channel import TimeSeriesChannel
-from writer import TimeSeriesChunkWriter
+from writer import TimeSeriesChunkWriter, _write_channel_chunk_worker
 
 
 class TestTimeSeriesChunkWriterInit:
@@ -189,6 +189,7 @@ def test_write_electrical_series_single_chunk(self, temp_output_dir, session_sta
         mock_reader.timestamps = np.linspace(0, 0.5, 500, endpoint=False)
         mock_reader.contiguous_chunks.return_value = [(0, 500)]
         mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
             mock_series = Mock()
@@ -215,6 +216,7 @@ def test_write_electrical_series_multiple_chunks(self, temp_output_dir, session_
         mock_reader.timestamps = np.linspace(0, 0.25, 250, endpoint=False)
         mock_reader.contiguous_chunks.return_value = [(0, 250)]
         mock_reader.get_chunk.return_value = np.random.randn(100).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
             mock_series = Mock()
@@ -241,6 +243,7 @@ def test_write_electrical_series_with_gap(self, temp_output_dir, session_start_t
         mock_reader.timestamps = np.concatenate([timestamps_seg1, timestamps_seg2])
         mock_reader.contiguous_chunks.return_value = [(0, 100), (100, 200)]
         mock_reader.get_chunk.return_value = np.random.randn(100).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
             mock_series = Mock()
@@ -259,9 +262,11 @@ def test_write_electrical_series_chunk_timestamps(self, temp_output_dir, session
         mock_reader = Mock()
         mock_reader.channels = [TimeSeriesChannel(index=0, name="Ch0", rate=1000.0, start=0, end=1000)]
         # 100 samples at 1000 Hz = 0.1 seconds
-        mock_reader.timestamps = np.linspace(1.0, 1.1, 100, endpoint=False)
+        timestamps = np.linspace(1.0, 1.1, 100, endpoint=False)
+        mock_reader.timestamps = timestamps
         mock_reader.contiguous_chunks.return_value = [(0, 100)]
         mock_reader.get_chunk.return_value = np.random.randn(50).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(timestamps[idx])
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
             mock_series = Mock()
@@ -339,3 +344,128 @@ def test_write_chunk_special_float_values(self, temp_output_dir, session_start_t
         assert np.isinf(result[0]) and result[0] > 0
         assert np.isinf(result[1]) and result[1] < 0
         assert np.isnan(result[2])
+
+
+class TestParallelProcessing:
+    """Tests for parallel channel processing functionality."""
+
+    def test_write_channel_chunk_worker(self, temp_output_dir):
+        """Test the worker function directly."""
+        chunk_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float64)
+        start_time = 1.0
+        end_time = 1.005
+        channel_index = 0
+
+        args = (chunk_data, start_time, end_time, channel_index, temp_output_dir)
+        _write_channel_chunk_worker(args)
+
+        # Check file was created
+        expected_filename = (
+            f"channel-00000_{int(start_time * 1e6)}_{int(end_time * 1e6)}{TIME_SERIES_BINARY_FILE_EXTENSION}"
+        )
+        file_path = os.path.join(temp_output_dir, expected_filename)
+        assert os.path.exists(file_path)
+
+        # Verify data integrity
+        with gzip.open(file_path, "rb") as f:
+            data = f.read()
+        result = np.frombuffer(data, dtype=">f8")
+        np.testing.assert_array_equal(result, [1.0, 2.0, 3.0, 4.0, 5.0])
+
+    def test_write_channel_chunk_worker_big_endian(self, temp_output_dir):
+        """Test that worker function writes data in big-endian format."""
+        chunk_data = np.array([1.5, -2.5], dtype=np.float64)
+        args = (chunk_data, 0.0, 0.001, 5, temp_output_dir)
+        _write_channel_chunk_worker(args)
+
+        file_path = os.path.join(temp_output_dir, "channel-00005_0_1000.bin.gz")
+        with gzip.open(file_path, "rb") as f:
+            data = f.read()
+
+        result = np.frombuffer(data, dtype=">f8")
+        np.testing.assert_array_equal(result, [1.5, -2.5])
+
+    def test_parallel_processing_many_channels(self, temp_output_dir, session_start_time):
+        """Test parallel processing with many channels (typical neuroscience scenario)."""
+        num_channels = 64
+        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, chunk_size=1000)
+
+        mock_reader = Mock()
+        mock_reader.channels = [
+            TimeSeriesChannel(index=i, name=f"Ch{i}", rate=30000.0, start=0, end=1000) for i in range(num_channels)
+        ]
+        mock_reader.contiguous_chunks.return_value = [(0, 500)]
+        mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 30000.0
+
+        with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
+            mock_series = Mock()
+            writer.write_electrical_series(mock_series)
+
+        files = os.listdir(temp_output_dir)
+        bin_files = [f for f in files if f.endswith(".bin.gz")]
+        json_files = [f for f in files if f.endswith(".metadata.json")]
+
+        # Should have 64 binary files (1 chunk x 64 channels)
+        assert len(bin_files) == num_channels
+        # Should have 64 metadata files
+        assert len(json_files) == num_channels
+
+    def test_parallel_processing_with_max_workers(self, temp_output_dir, session_start_time):
+        """Test that max_workers parameter is respected."""
+        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, chunk_size=1000)
+
+        mock_reader = Mock()
+        mock_reader.channels = [
+            TimeSeriesChannel(index=i, name=f"Ch{i}", rate=1000.0, start=0, end=1000) for i in range(8)
+        ]
+        mock_reader.contiguous_chunks.return_value = [(0, 500)]
+        mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64)
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
+
+        with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
+            mock_series = Mock()
+            # Limit to 2 workers
+            writer.write_electrical_series(mock_series, max_workers=2)
+
+        files = os.listdir(temp_output_dir)
+        bin_files = [f for f in files if f.endswith(".bin.gz")]
+
+        # Should still produce correct output
+        assert len(bin_files) == 8
+
+    def test_parallel_processing_data_integrity(self, temp_output_dir, session_start_time):
+        """Test that parallel processing maintains data integrity across channels."""
+        num_channels = 4
+        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, chunk_size=100)
+
+        # Create distinct data for each channel
+        channel_data = {i: np.arange(100, dtype=np.float64) + i * 1000 for i in range(num_channels)}
+
+        mock_reader = Mock()
+        mock_reader.channels = [
+            TimeSeriesChannel(index=i, name=f"Ch{i}", rate=1000.0, start=0, end=1000) for i in range(num_channels)
+        ]
+        mock_reader.contiguous_chunks.return_value = [(0, 100)]
+        mock_reader.get_chunk.side_effect = lambda ch_idx, start, end: channel_data[ch_idx]
+        mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
+
+        with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
+            mock_series = Mock()
+            writer.write_electrical_series(mock_series)
+
+        # Verify each channel's data
+        for i in range(num_channels):
+            file_pattern = f"channel-{i:05d}_"
+            matching_files = [
+                f for f in os.listdir(temp_output_dir) if f.startswith(file_pattern) and f.endswith(".bin.gz")
+            ]
+            assert len(matching_files) == 1
+
+            file_path = os.path.join(temp_output_dir, matching_files[0])
+            with gzip.open(file_path, "rb") as f:
+                data = f.read()
+
+            result = np.frombuffer(data, dtype=">f8")
+            expected = np.arange(100, dtype=np.float64) + i * 1000
+            np.testing.assert_array_equal(result, expected)

From ce0604a85078130705b16f4976e47e8a260dcc90 Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Tue, 16 Dec 2025 16:20:16 -0500
Subject: [PATCH 05/10] cleanup

---
 processor/writer.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/processor/writer.py b/processor/writer.py
index e1801f1..fdd9573 100644
--- a/processor/writer.py
+++ b/processor/writer.py
@@ -12,16 +12,18 @@
 log = logging.getLogger()
 
 
-def _write_channel_chunk_worker(args):
+def _write_channel_chunk_worker(chunk_data, start_time, end_time, channel_index, output_dir):
     """
     Worker function for parallel channel chunk processing.
     Must be a top-level function to be picklable for ProcessPoolExecutor.
 
     Args:
-        args: Tuple of (chunk_data, start_time, end_time, channel_index, output_dir)
+        chunk_data: numpy array of sample data for the channel
+        start_time: start timestamp in seconds
+        end_time: end timestamp in seconds
+        channel_index: channel index for filename
+        output_dir: directory to write output file
     """
-    chunk_data, start_time, end_time, channel_index, output_dir = args
-
-    # Convert to big-endian format
     formatted_data = to_big_endian(chunk_data.astype(np.float64))
 
     channel_index_str = "{index:05d}".format(index=channel_index)
@@ -79,14 +81,22 @@ def write_electrical_series(self, electrical_series, max_workers=None):
                         reader.get_chunk(channel_index, chunk_start, chunk_end) for channel_index in range(num_channels)
                     ]
 
-                    # Prepare arguments for parallel processing
-                    worker_args = [
-                        (channel_chunks[i], start_time, end_time, reader.channels[i].index, self.output_dir)
+                    # Submit all channels for parallel processing
+                    futures = [
+                        executor.submit(
+                            _write_channel_chunk_worker,
+                            channel_chunks[i],
+                            start_time,
+                            end_time,
+                            reader.channels[i].index,
+                            self.output_dir,
+                        )
                         for i in range(num_channels)
                     ]
 
-                    # Process all channels in parallel
-                    list(executor.map(_write_channel_chunk_worker, worker_args))
+                    # Wait for all to complete
+                    for future in futures:
+                        future.result()
 
         for channel in reader.channels:
             self.write_channel(channel)

From 138d1669f2033f65457cd9b49d373f5177fbd87c Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Tue, 16 Dec 2025 16:21:01 -0500
Subject: [PATCH 06/10] fix tests

---
 tests/test_writer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/test_writer.py b/tests/test_writer.py
index b09f252..dce9a1c 100644
--- a/tests/test_writer.py
+++ b/tests/test_writer.py
@@ -356,8 +356,7 @@ def test_write_channel_chunk_worker(self, temp_output_dir):
         end_time = 1.005
         channel_index = 0
 
-        args = (chunk_data, start_time, end_time, channel_index, temp_output_dir)
-        _write_channel_chunk_worker(args)
+        _write_channel_chunk_worker(chunk_data, start_time, end_time, channel_index, temp_output_dir)
 
         # Check file was created
         expected_filename = (
@@ -375,8 +374,7 @@ def test_write_channel_chunk_worker_big_endian(self, temp_output_dir):
     def test_write_channel_chunk_worker_big_endian(self, temp_output_dir):
         """Test that worker function writes data in big-endian format."""
         chunk_data = np.array([1.5, -2.5], dtype=np.float64)
-        args = (chunk_data, 0.0, 0.001, 5, temp_output_dir)
-        _write_channel_chunk_worker(args)
+        _write_channel_chunk_worker(chunk_data, 0.0, 0.001, 5, temp_output_dir)
 
         file_path = os.path.join(temp_output_dir, "channel-00005_0_1000.bin.gz")
         with gzip.open(file_path, "rb") as f:

From bd5205f249dfa50e8c57f4adacca6012ade964b2 Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Wed, 17 Dec 2025 11:17:54 -0500
Subject: [PATCH 07/10] refactor to read chunks across channels and split in memory

---
 processor/reader.py  |  38 ++++++++++-----
 processor/writer.py  |  48 +++++++------------
 tests/test_reader.py | 111 ++++++++++++++++++-------------------------
 tests/test_writer.py |  39 +++++++--------
 4 files changed, 107 insertions(+), 129 deletions(-)

diff --git a/processor/reader.py b/processor/reader.py
index fafd247..903e367 100644
--- a/processor/reader.py
+++ b/processor/reader.py
@@ -204,19 +204,35 @@ def contiguous_chunks(self):
         for i in range(len(boundaries) - 1):
             yield boundaries[i], boundaries[i + 1]
 
-    def get_chunk(self, channel_index, start=None, end=None):
+    def get_chunk(self, start=None, end=None):
         """
-        Returns a chunk of sample data from the electrical series
-        for the given channel (index)
+        Returns chunks of sample data across all channels in a single HDF5 read.
 
-        If start and end are not specified the entire channel's data is read into memory.
+        If start and end are not specified all data is read into memory.
 
-        The sample data is scaled by the conversion and offset factors
-        set in the electrical series.
-        """
-        scale_factor = self.electrical_series.conversion
+        HDF5 is optimized for contiguous reads, so reading all channels at once
+        and splitting in memory is much faster than column-by-column access.
 
-        if self.electrical_series.channel_conversion:
-            scale_factor *= self.electrical_series.channel_conversion[channel_index]
+        Args:
+            start: Start sample index (default: 0)
+            end: End sample index (default: num_samples)
 
-        return self.electrical_series.data[start:end, channel_index] * scale_factor + self.electrical_series.offset
+        Returns:
+            list of numpy arrays, one per channel, with scaling applied
+        """
+        # Single HDF5 read for all channels
+        all_data = self.electrical_series.data[start:end, :]
+
+        base_scale = self.electrical_series.conversion
+        offset = self.electrical_series.offset
+
+        # Apply per-channel scaling if present
+        if self.electrical_series.channel_conversion is not None:
+            channel_scales = np.array(self.electrical_series.channel_conversion) * base_scale
+            # Broadcast multiply: (samples, channels) * (channels,) -> (samples, channels)
+            scaled_data = all_data * channel_scales + offset
+        else:
+            scaled_data = all_data * base_scale + offset
+
+        # Split into list of per-channel arrays
+        return [scaled_data[:, i] for i in range(self.num_channels)]
diff --git a/processor/writer.py b/processor/writer.py
index fdd9573..3204e7c 100644
--- a/processor/writer.py
+++ b/processor/writer.py
@@ -2,7 +2,7 @@ import gzip
 import json
 import logging
 import os
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
@@ -12,10 +12,9 @@
 log = logging.getLogger()
 
 
-def _write_channel_chunk_worker(chunk_data, start_time, end_time, channel_index, output_dir):
+def _write_channel_chunk(chunk_data, start_time, end_time, channel_index, output_dir):
     """
-    Worker function for parallel channel chunk processing.
-    Must be a top-level function to be picklable for ProcessPoolExecutor.
+    Write a single channel's chunk data to a gzipped file.
 
     Args:
         chunk_data: numpy array of sample data for the channel
@@ -24,7 +23,6 @@ def _write_channel_chunk(chunk_data, start_time, end_time, channel_index, output
         channel_index: channel index for filename
         output_dir: directory to write output file
     """
-    # Convert to big-endian format
    formatted_data = to_big_endian(chunk_data.astype(np.float64))
 
     channel_index_str = "{index:05d}".format(index=channel_index)
@@ -33,7 +31,7 @@ def _write_channel_chunk(chunk_data, start_time, end_time, channel_index, output
     )
     file_path = os.path.join(output_dir, file_name)
 
-    with gzip.open(file_path, mode="wb", compresslevel=1) as f:
+    with gzip.open(file_path, mode="wb", compresslevel=0) as f:
         f.write(formatted_data)
@@ -57,17 +55,17 @@ def write_electrical_series(self, electrical_series, max_workers=None):
         2. Chunks each contiguous segment into the given chunk_size (number of samples to include per file)
 
         Writes each chunk to the given output directory.
-        Channel processing is parallelized using ProcessPoolExecutor for improved performance
-        with datasets containing many channels (64-384 typical in neuroscience).
+        Channel processing is parallelized using ThreadPoolExecutor. Threads share memory
+        (no serialization overhead) and the GIL is released during gzip compression and file I/O.
 
         Args:
             electrical_series: NWB ElectricalSeries object
-            max_workers: Maximum number of worker processes (defaults to CPU count)
+            max_workers: Maximum number of threads (defaults to min(32, cpu_count + 4))
         """
         reader = NWBElectricalSeriesReader(electrical_series, self.session_start_time)
         num_channels = len(reader.channels)
 
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
             for contiguous_start, contiguous_end in reader.contiguous_chunks():
                 for chunk_start in range(contiguous_start, contiguous_end, self.chunk_size):
                     chunk_end = min(contiguous_end, chunk_start + self.chunk_size)
@@ -75,16 +73,11 @@ def write_electrical_series(self, electrical_series, max_workers=None):
                     start_time = reader.get_timestamp(chunk_start)
                     end_time = reader.get_timestamp(chunk_end - 1)
 
-                    # Read all channel data for this chunk at once from HDF5
-                    # (HDF5 doesn't support efficient concurrent reads)
-                    channel_chunks = [
-                        reader.get_chunk(channel_index, chunk_start, chunk_end) for channel_index in range(num_channels)
-                    ]
+                    channel_chunks = reader.get_chunk(chunk_start, chunk_end)
 
-                    # Submit all channels for parallel processing
                     futures = [
                         executor.submit(
-                            _write_channel_chunk_worker,
+                            _write_channel_chunk,
                             channel_chunks[i],
                             start_time,
                             end_time,
@@ -94,7 +87,6 @@ def write_electrical_series(self, electrical_series, max_workers=None):
                         for i in range(num_channels)
                     ]
 
-                    # Wait for all to complete
                     for future in futures:
                         future.result()
@@ -103,21 +95,15 @@ def write_chunk(self, chunk, start_time, end_time, channel):
 
     def write_chunk(self, chunk, start_time, end_time, channel):
         """
-        Formats the chunked sample data into 64-bit (8 byte) values in big-endian.
-
-        Writes the chunked sample data to a gzipped binary file.
-        """
-        # ensure the samples are 64-bit float-pointing numbers in big-endian before converting to bytes
-        formatted_data = to_big_endian(chunk.astype(np.float64))
-        channel_index = "{index:05d}".format(index=channel.index)
-        file_name = "channel-{}_{}_{}{}".format(
-            channel_index, int(start_time * 1e6), int(end_time * 1e6), TIME_SERIES_BINARY_FILE_EXTENSION
-        )
-        file_path = os.path.join(self.output_dir, file_name)
-
-        with gzip.open(file_path, mode="wb", compresslevel=1) as f:
-            f.write(formatted_data)
+        Args:
+            chunk: numpy array of sample data
+            start_time: start timestamp in seconds
+            end_time: end timestamp in seconds
+            channel: TimeSeriesChannel object
+        """
+        _write_channel_chunk(chunk, start_time, end_time, channel.index, self.output_dir)
 
     def write_channel(self, channel):
         file_name = f"channel-{channel.index:05d}{TIME_SERIES_METADATA_FILE_EXTENSION}"
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 19cb160..d49ca7b 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -60,7 +60,6 @@ def test_basic_initialization_with_rate(self):
         assert reader.num_samples == 1000
         assert reader.num_channels == 4
         assert reader.sampling_rate == 1000.0
-        assert len(reader.timestamps) == 1000
 
     def test_initialization_with_timestamps(self):
         """Test initialization with timestamps specified.
@@ -114,38 +113,25 @@ def test_session_start_time_offset(self):
 
         # Timestamps should be offset by session_start_time_secs
         expected_start = session_start.timestamp()
-        assert reader.timestamps[0] == pytest.approx(expected_start, rel=1e-6)
+        assert reader.get_timestamp(0) == pytest.approx(expected_start, rel=1e-6)
 
 
 class TestSamplingRateAndTimestampComputation:
-    """Tests for _compute_sampling_rate_and_timestamps method."""
+    """Tests for sampling rate and timestamp computation."""
 
-    def test_rate_only_generates_timestamps(self):
+    def test_rate_only_generates_correct_timestamps(self):
         """Test that timestamps are generated from rate when only rate is provided."""
         series = create_mock_electrical_series(1000, 2, rate=1000.0)
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
 
-        # Should have 1000 timestamps spanning 1 second
-        assert len(reader.timestamps) == 1000
         # First timestamp should be at session start
-        assert reader.timestamps[0] == pytest.approx(session_start.timestamp(), rel=1e-6)
+        assert reader.get_timestamp(0) == pytest.approx(session_start.timestamp(), rel=1e-6)
         # Time span should be ~1 second (1000 samples at 1000 Hz)
-        time_span = reader.timestamps[-1] - reader.timestamps[0]
+        time_span = reader.get_timestamp(999) - reader.get_timestamp(0)
         assert time_span == pytest.approx(0.999, rel=1e-3)
 
-    def test_rate_generates_correct_timestamps(self):
-        """Test that timestamps are generated correctly from rate."""
-        series = create_mock_electrical_series(100, 2, rate=100.0)  # 100 Hz
-        session_start = datetime(2023, 1, 1, 12, 0, 0)
-
-        reader = NWBElectricalSeriesReader(series, session_start)
-
-        # Timestamps should span 1 second (100 samples at 100 Hz)
-        time_span = reader.timestamps[-1] - reader.timestamps[0]
-        assert abs(time_span - 0.99) < 0.01  # Approximately 0.99 seconds
-
     def test_rate_stored_correctly(self):
         """Test that rate is stored correctly."""
         series = create_mock_electrical_series(100, 2, rate=30000.0)
@@ -155,15 +141,6 @@ def test_rate_stored_correctly(self):
 
         assert reader.sampling_rate == 30000.0
 
-    def test_timestamps_count_matches_samples(self):
-        """Test that number of timestamps matches number of samples."""
-        series = create_mock_electrical_series(500, 3, rate=1000.0)
-        session_start = datetime(2023, 1, 1, 12, 0, 0)
-
-        reader = NWBElectricalSeriesReader(series, session_start)
-
-        assert len(reader.timestamps) == 500
-
 
 class TestChannelsProperty:
     """Tests for channels property."""
@@ -251,73 +228,73 @@ def test_chunk_boundaries_format(self):
             assert start < end
 
 
-class TestGetChunk:
-    """Tests for get_chunk method."""
+class TestGetAllChannelsChunk:
+    """Tests for batch channel reading."""
 
-    def test_get_full_channel_data(self):
-        """Test getting full channel data without start/end."""
-        series = create_mock_electrical_series(10, 2, rate=1000.0)
-        # Set specific data values
-        series.data = np.arange(20).reshape(10, 2).astype(np.float64)
+    def test_returns_all_channels(self):
+        """Test that get_chunk returns data for all channels."""
+        series = create_mock_electrical_series(100, 4, rate=1000.0)
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
-        chunk = reader.get_chunk(0)  # First channel
+        chunks = reader.get_chunk(0, 50)
 
-        np.testing.assert_array_equal(chunk, series.data[:, 0])
+        assert len(chunks) == 4
+        for chunk in chunks:
+            assert len(chunk) == 50
 
-    def test_get_partial_channel_data(self):
-        """Test getting partial channel data with start/end."""
+    def test_get_full_data(self):
+        """Test getting full data without start/end."""
         series = create_mock_electrical_series(10, 2, rate=1000.0)
         series.data = np.arange(20).reshape(10, 2).astype(np.float64)
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
-        chunk = reader.get_chunk(1, start=2, end=5)  # Second channel, samples 2-5
+        chunks = reader.get_chunk()
 
-        np.testing.assert_array_equal(chunk, series.data[2:5, 1])
+        np.testing.assert_array_equal(chunks[0], series.data[:, 0])
+        np.testing.assert_array_equal(chunks[1], series.data[:, 1])
 
-    def test_conversion_factor_applied(self):
-        """Test that conversion factor is applied to data."""
-        series = create_mock_electrical_series(10, 2, rate=1000.0, conversion=2.0)
-        series.data = np.ones((10, 2))
+    def test_get_partial_data(self):
+        """Test getting partial data with start/end."""
+        series = create_mock_electrical_series(10, 2, rate=1000.0)
+        series.data = np.arange(20).reshape(10, 2).astype(np.float64)
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
-        chunk = reader.get_chunk(0)
+        chunks = reader.get_chunk(start=2, end=5)
 
-        np.testing.assert_array_equal(chunk, np.ones(10) * 2.0)
+        np.testing.assert_array_equal(chunks[0], series.data[2:5, 0])
+        np.testing.assert_array_equal(chunks[1], series.data[2:5, 1])
 
-    def test_offset_applied(self):
-        """Test that offset is applied to data."""
-        series = create_mock_electrical_series(10, 2, rate=1000.0, offset=5.0)
+    def test_conversion_factor_applied(self):
+        """Test that conversion factor is applied in batch read."""
+        series = create_mock_electrical_series(10, 2, rate=1000.0, conversion=2.0)
         series.data = np.ones((10, 2))
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
-        chunk = reader.get_chunk(0)
+        chunks = reader.get_chunk()
 
-        np.testing.assert_array_equal(chunk, np.ones(10) * 1.0 + 5.0)
+        for chunk in chunks:
+            np.testing.assert_array_equal(chunk, np.ones(10) * 2.0)
 
     def test_channel_conversion_applied(self):
-        """Test that per-channel conversion is applied."""
-        channel_conversion = [2.0, 3.0]
-        series = create_mock_electrical_series(10, 2, rate=1000.0, channel_conversion=channel_conversion)
-        series.data = np.ones((10, 2))
+        """Test that per-channel conversion is applied in batch read."""
+        channel_conversion = [2.0, 3.0, 4.0]
+        series = create_mock_electrical_series(10, 3, rate=1000.0, channel_conversion=channel_conversion)
+        series.data = np.ones((10, 3))
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
+        chunks = reader.get_chunk()
 
-        chunk0 = reader.get_chunk(0)
-        chunk1 = reader.get_chunk(1)
-
-        np.testing.assert_array_equal(chunk0, np.ones(10) * 2.0)
-        np.testing.assert_array_equal(chunk1, np.ones(10) * 3.0)
+        np.testing.assert_array_equal(chunks[0], np.ones(10) * 2.0)
+        np.testing.assert_array_equal(chunks[1], np.ones(10) * 3.0)
+        np.testing.assert_array_equal(chunks[2], np.ones(10) * 4.0)
 
     def test_all_scaling_factors_combined(self):
-        """Test that conversion, channel_conversion, and offset are all applied."""
-        # Result should be: data * conversion * channel_conversion + offset
-        # = 1.0 * 2.0 * 3.0 + 1.0 = 7.0
+        """Test that all scaling factors are applied in batch read."""
         series = create_mock_electrical_series(
             10, 2, rate=1000.0, conversion=2.0, channel_conversion=[3.0, 4.0], offset=1.0
         )
@@ -325,6 +302,8 @@ def test_all_scaling_factors_combined(self):
         session_start = datetime(2023, 1, 1, 12, 0, 0)
 
         reader = NWBElectricalSeriesReader(series, session_start)
-        chunk = reader.get_chunk(0)
+        chunks = reader.get_chunk()
 
-        np.testing.assert_array_equal(chunk, np.ones(10) * 7.0)
+        # Result: data * conversion * channel_conversion + offset
+        np.testing.assert_array_equal(chunks[0], np.ones(10) * 7.0)  # 1 * 2 * 3 + 1 = 7
+        np.testing.assert_array_equal(chunks[1], np.ones(10) * 9.0)  # 1 * 2 * 4 + 1 = 9
diff --git a/tests/test_writer.py b/tests/test_writer.py
index dce9a1c..ff722c8 100644
--- a/tests/test_writer.py
+++ b/tests/test_writer.py
@@ -6,7 +6,7 @@
 import numpy as np
 from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
 from timeseries_channel import TimeSeriesChannel
-from writer import TimeSeriesChunkWriter, _write_channel_chunk_worker
+from writer import TimeSeriesChunkWriter, _write_channel_chunk
 
 
 class TestTimeSeriesChunkWriterInit:
@@ -186,9 +186,8 @@ def test_write_electrical_series_single_chunk(self, temp_output_dir, session_sta
         # Create mock electrical series with 500 samples (less than chunk_size)
         mock_reader = Mock()
         mock_reader.channels = [TimeSeriesChannel(index=0, name="Ch0", rate=1000.0, start=0, end=1000)]
-        mock_reader.timestamps = np.linspace(0, 0.5, 500, endpoint=False)
         mock_reader.contiguous_chunks.return_value = [(0, 500)]
-        mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64)
+        mock_reader.get_chunk.return_value = [np.random.randn(500).astype(np.float64)]
         mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
@@ -213,9 +212,11 @@ def test_write_electrical_series_multiple_chunks(self, temp_output_dir, session_
             TimeSeriesChannel(index=0, name="Ch0", rate=1000.0, start=0, end=1000),
             TimeSeriesChannel(index=1, name="Ch1", rate=1000.0, start=0, end=1000),
         ]
-        mock_reader.timestamps = np.linspace(0, 0.25, 250, endpoint=False)
         mock_reader.contiguous_chunks.return_value = [(0, 250)]
-        mock_reader.get_chunk.return_value = np.random.randn(100).astype(np.float64)
+        mock_reader.get_chunk.return_value = [
+            np.random.randn(100).astype(np.float64),
+            np.random.randn(100).astype(np.float64),
+        ]
         mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
@@ -238,11 +239,8 @@ def test_write_electrical_series_with_gap(self, temp_output_dir, session_start_t
         mock_reader = Mock()
         mock_reader.channels = [TimeSeriesChannel(index=0, name="Ch0", rate=1000.0, start=0, end=1000)]
         # Two contiguous segments
-        timestamps_seg1 = np.linspace(0, 0.1, 100, endpoint=False)
-        timestamps_seg2 = np.linspace(0.2, 0.3, 100, endpoint=False)
-        mock_reader.timestamps = np.concatenate([timestamps_seg1, timestamps_seg2])
         mock_reader.contiguous_chunks.return_value = [(0, 100), (100, 200)]
-        mock_reader.get_chunk.return_value = np.random.randn(100).astype(np.float64)
+        mock_reader.get_chunk.return_value = [np.random.randn(100).astype(np.float64)]
         mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0
 
         with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader):
@@ -259,9 +261,8 @@ def test_write_electrical_series_chunk_timestamps(self, temp_output_dir, session
         mock_reader.channels = [TimeSeriesChannel(index=0, name="Ch0", rate=1000.0, start=0, end=1000)]
         # 100 samples at 1000 Hz = 0.1 seconds
         timestamps = np.linspace(1.0, 1.1, 100, endpoint=False)
-        mock_reader.timestamps = timestamps
mock_reader.contiguous_chunks.return_value = [(0, 100)] - mock_reader.get_chunk.return_value = np.random.randn(50).astype(np.float64) + mock_reader.get_chunk.return_value = [np.random.randn(50).astype(np.float64)] mock_reader.get_timestamp.side_effect = lambda idx: float(timestamps[idx]) with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader): @@ -349,14 +346,14 @@ def test_write_chunk_special_float_values(self, temp_output_dir, session_start_t class TestParallelProcessing: """Tests for parallel channel processing functionality.""" - def test_write_channel_chunk_worker(self, temp_output_dir): - """Test the worker function directly.""" + def test_write_channel_chunk(self, temp_output_dir): + """Test the channel chunk write function directly.""" chunk_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float64) start_time = 1.0 end_time = 1.005 channel_index = 0 - _write_channel_chunk_worker(chunk_data, start_time, end_time, channel_index, temp_output_dir) + _write_channel_chunk(chunk_data, start_time, end_time, channel_index, temp_output_dir) # Check file was created expected_filename = ( @@ -371,10 +368,10 @@ def test_write_channel_chunk_worker(self, temp_output_dir): result = np.frombuffer(data, dtype=">f8") np.testing.assert_array_equal(result, [1.0, 2.0, 3.0, 4.0, 5.0]) - def test_write_channel_chunk_worker_big_endian(self, temp_output_dir): - """Test that worker function writes data in big-endian format.""" + def test_write_channel_chunk_big_endian(self, temp_output_dir): + """Test that chunk write function outputs big-endian format.""" chunk_data = np.array([1.5, -2.5], dtype=np.float64) - _write_channel_chunk_worker(chunk_data, 0.0, 0.001, 5, temp_output_dir) + _write_channel_chunk(chunk_data, 0.0, 0.001, 5, temp_output_dir) file_path = os.path.join(temp_output_dir, "channel-00005_0_1000.bin.gz") with gzip.open(file_path, "rb") as f: @@ -393,7 +390,7 @@ def test_parallel_processing_many_channels(self, temp_output_dir, session_start_ TimeSeriesChannel(index=i, name=f"Ch{i}", rate=30000.0, start=0, end=1000) for i in range(num_channels) ] mock_reader.contiguous_chunks.return_value = [(0, 500)] - mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64) + mock_reader.get_chunk.return_value = [np.random.randn(500).astype(np.float64) for _ in range(num_channels)] mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 30000.0 with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader): @@ -418,7 +415,7 @@ def test_parallel_processing_with_max_workers(self, temp_output_dir, session_sta TimeSeriesChannel(index=i, name=f"Ch{i}", rate=1000.0, start=0, end=1000) for i in range(8) ] mock_reader.contiguous_chunks.return_value = [(0, 500)] - mock_reader.get_chunk.return_value = np.random.randn(500).astype(np.float64) + mock_reader.get_chunk.return_value = [np.random.randn(500).astype(np.float64) for _ in range(8)] mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0 with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader): @@ -438,14 +435,14 @@ def test_parallel_processing_data_integrity(self, temp_output_dir, session_start writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, chunk_size=100) # Create distinct data for each channel - channel_data = {i: np.arange(100, dtype=np.float64) + i * 1000 for i in range(num_channels)} + channel_data = [np.arange(100, dtype=np.float64) + i * 1000 for i in range(num_channels)] mock_reader = Mock() mock_reader.channels = [ TimeSeriesChannel(index=i, 
name=f"Ch{i}", rate=1000.0, start=0, end=1000) for i in range(num_channels) ] mock_reader.contiguous_chunks.return_value = [(0, 100)] - mock_reader.get_chunk.side_effect = lambda ch_idx, start, end: channel_data[ch_idx] + mock_reader.get_chunk.return_value = channel_data mock_reader.get_timestamp.side_effect = lambda idx: float(idx) / 1000.0 with patch("writer.NWBElectricalSeriesReader", return_value=mock_reader): From 01d31df454f08144e16d466b3539253b9b6ddad1 Mon Sep 17 00:00:00 2001 From: Rohan Shah Date: Wed, 17 Dec 2025 11:41:07 -0500 Subject: [PATCH 08/10] cleanup --- processor/writer.py | 49 ++++++++++------------- tests/test_writer.py | 95 +++++++++----------------------------------- 2 files changed, 39 insertions(+), 105 deletions(-) diff --git a/processor/writer.py b/processor/writer.py index 3204e7c..589d3ae 100644 --- a/processor/writer.py +++ b/processor/writer.py @@ -12,29 +12,6 @@ log = logging.getLogger() -def _write_channel_chunk(chunk_data, start_time, end_time, channel_index, output_dir): - """ - Write a single channel's chunk data to a gzipped file. - - Args: - chunk_data: numpy array of sample data for the channel - start_time: start timestamp in seconds - end_time: end timestamp in seconds - channel_index: channel index for filename - output_dir: directory to write output file - """ - formatted_data = to_big_endian(chunk_data.astype(np.float64)) - - channel_index_str = "{index:05d}".format(index=channel_index) - file_name = "channel-{}_{}_{}{}".format( - channel_index_str, int(start_time * 1e6), int(end_time * 1e6), TIME_SERIES_BINARY_FILE_EXTENSION - ) - file_path = os.path.join(output_dir, file_name) - - with gzip.open(file_path, mode="wb", compresslevel=0) as f: - f.write(formatted_data) - - class TimeSeriesChunkWriter: """ Attributes: @@ -77,7 +54,7 @@ def write_electrical_series(self, electrical_series, max_workers=None): futures = [ executor.submit( - _write_channel_chunk, + self.write_chunk, channel_chunks[i], start_time, end_time, @@ -93,17 +70,33 @@ def write_electrical_series(self, electrical_series, max_workers=None): for channel in reader.channels: self.write_channel(channel) - def write_chunk(self, chunk, start_time, end_time, channel): + @staticmethod + def write_chunk(chunk, start_time, end_time, channel_index, output_dir): """ + Formats the chunked sample data into 64-bit (8 byte) values in big-endian. + Writes the chunked sample data to a gzipped binary file. 
         Args:
-            chunk: numpy array of sample data
+            chunk: numpy array of sample data for the channel
             start_time: start timestamp in seconds
             end_time: end timestamp in seconds
-            channel: TimeSeriesChannel object
+            channel_index: channel index for output filename
+            output_dir: directory to write chunked output file
         """
-        _write_channel_chunk(chunk, start_time, end_time, channel.index, self.output_dir)
+        # ensure the samples are 64-bit floating-point numbers in big-endian byte order before converting to bytes
+        formatted_data = to_big_endian(chunk.astype(np.float64))
+
+        file_name = "channel-{}_{}_{}{}".format(
+            "{index:05d}".format(index=channel_index),
+            int(start_time * 1e6),
+            int(end_time * 1e6),
+            TIME_SERIES_BINARY_FILE_EXTENSION,
+        )
+        file_path = os.path.join(output_dir, file_name)
+
+        with gzip.open(file_path, mode="wb", compresslevel=0) as f:
+            f.write(formatted_data)
 
     def write_channel(self, channel):
         file_name = f"channel-{channel.index:05d}{TIME_SERIES_METADATA_FILE_EXTENSION}"
diff --git a/tests/test_writer.py b/tests/test_writer.py
index ff722c8..9578459 100644
--- a/tests/test_writer.py
+++ b/tests/test_writer.py
@@ -6,7 +6,7 @@
 import numpy as np
 from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
 from timeseries_channel import TimeSeriesChannel
-from writer import TimeSeriesChunkWriter, _write_channel_chunk
+from writer import TimeSeriesChunkWriter
 
 
 class TestTimeSeriesChunkWriterInit:
@@ -24,19 +24,15 @@ def test_initialization(self, temp_output_dir, session_start_time):
 
 
 class TestWriteChunk:
-    """Tests for write_chunk method."""
+    """Tests for write_chunk static method."""
 
-    def test_write_chunk_creates_file(self, temp_output_dir, session_start_time):
+    def test_write_chunk_creates_file(self, temp_output_dir):
         """Test that write_chunk creates a binary file."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test Channel", rate=1000.0, start=1000000, end=2000000)
-
         start_time = 1.0
         end_time = 1.005
 
-        writer.write_chunk(chunk, start_time, end_time, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, start_time, end_time, 0, temp_output_dir)
 
         # Check file was created
         expected_filename = (
@@ -45,14 +41,11 @@ def test_write_chunk_creates_file(self, temp_output_dir, session_start_time):
         file_path = os.path.join(temp_output_dir, expected_filename)
         assert os.path.exists(file_path)
 
-    def test_write_chunk_gzip_compressed(self, temp_output_dir, session_start_time):
+    def test_write_chunk_gzip_compressed(self, temp_output_dir):
         """Test that output file is gzip compressed."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([1.0, 2.0, 3.0], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 1.0, 1.003, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 1.0, 1.003, 0, temp_output_dir)
 
         # Find the file (timestamps may vary slightly)
         files = [f for f in os.listdir(temp_output_dir) if f.endswith(".bin.gz")]
@@ -64,14 +57,11 @@ def test_write_chunk_gzip_compressed(self, temp_output_dir, session_start_time):
             data = f.read()
         assert len(data) > 0
 
-    def test_write_chunk_big_endian_format(self, temp_output_dir, session_start_time):
+    def test_write_chunk_big_endian_format(self, temp_output_dir):
         """Test that data is written in big-endian format."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([1.0, 2.0, 3.0], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 1.0, 1.003, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 1.0, 1.003, 0, temp_output_dir)
 
         # Find the file
         files = [f for f in os.listdir(temp_output_dir) if f.endswith(".bin.gz")]
@@ -85,32 +75,26 @@ def test_write_chunk_big_endian_format(self, temp_output_dir, session_start_time
         result = np.frombuffer(data, dtype=">f8")
         np.testing.assert_array_equal(result, [1.0, 2.0, 3.0])
 
-    def test_write_chunk_channel_index_formatting(self, temp_output_dir, session_start_time):
+    def test_write_chunk_channel_index_formatting(self, temp_output_dir):
         """Test that channel index is zero-padded to 5 digits."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([1.0], dtype=np.float64)
 
         # Test various channel indices with unique timestamps to avoid overwriting
         for i, index in enumerate([0, 5, 42, 999, 12345]):
-            channel = TimeSeriesChannel(index=index, name="Test", rate=1000.0, start=0, end=1000)
             start_time = 1.0 + i * 0.1
             end_time = start_time + 0.001
-            writer.write_chunk(chunk, start_time, end_time, channel)
+            TimeSeriesChunkWriter.write_chunk(chunk, start_time, end_time, index, temp_output_dir)
 
             # Check that file with correct channel index prefix exists
             files = [f for f in os.listdir(temp_output_dir) if f.startswith(f"channel-{index:05d}_")]
             assert len(files) >= 1, f"No file found for channel index {index}"
 
-    def test_write_chunk_preserves_data_precision(self, temp_output_dir, session_start_time):
+    def test_write_chunk_preserves_data_precision(self, temp_output_dir):
         """Test that float64 precision is preserved."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         # Use values that require float64 precision
         chunk = np.array([1.123456789012345, -9.87654321098765e10, 1e-15], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 1.0, 1.003, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 1.0, 1.003, 0, temp_output_dir)
 
         # Find the file
         files = [f for f in os.listdir(temp_output_dir) if f.endswith(".bin.gz")]
@@ -284,14 +268,11 @@ def test_write_electrical_series_chunk_timestamps(self, temp_output_dir, session
 
 class TestWriteChunkEdgeCases:
     """Edge case tests for chunk writing."""
 
-    def test_write_empty_chunk(self, temp_output_dir, session_start_time):
+    def test_write_empty_chunk(self, temp_output_dir):
         """Test writing an empty chunk."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 1.0, 1.0, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 1.0, 1.0, 0, temp_output_dir)
 
         file_path = os.path.join(temp_output_dir, "channel-00000_1000000_1000000.bin.gz")
@@ -300,15 +281,12 @@ def test_write_empty_chunk(self, temp_output_dir, session_start_time):
 
         assert len(data) == 0
 
-    def test_write_large_chunk(self, temp_output_dir, session_start_time):
+    def test_write_large_chunk(self, temp_output_dir):
         """Test writing a large chunk."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         # 1 million samples
         chunk = np.random.randn(1000000).astype(np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 0.0, 1000.0, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 0.0, 1000.0, 0, temp_output_dir)
 
         file_path = os.path.join(temp_output_dir, "channel-00000_0_1000000000.bin.gz")
         assert os.path.exists(file_path)
@@ -320,14 +298,11 @@ def test_write_large_chunk(self, temp_output_dir, session_start_time):
         result = np.frombuffer(data, dtype=">f8")
         assert len(result) == 1000000
 
-    def test_write_chunk_special_float_values(self, temp_output_dir, session_start_time):
+    def test_write_chunk_special_float_values(self, temp_output_dir):
         """Test writing chunks with special float values."""
-        writer = TimeSeriesChunkWriter(session_start_time, temp_output_dir, 1000)
-
         chunk = np.array([np.inf, -np.inf, np.nan, 0.0, -0.0], dtype=np.float64)
-        channel = TimeSeriesChannel(index=0, name="Test", rate=1000.0, start=0, end=1000)
 
-        writer.write_chunk(chunk, 1.0, 1.005, channel)
+        TimeSeriesChunkWriter.write_chunk(chunk, 1.0, 1.005, 0, temp_output_dir)
 
         # Find the file
         files = [f for f in os.listdir(temp_output_dir) if f.endswith(".bin.gz")]
@@ -346,40 +321,6 @@ def test_write_chunk_special_float_values(self, temp_output_dir, session_start_t
 class TestParallelProcessing:
     """Tests for parallel channel processing functionality."""
 
-    def test_write_channel_chunk(self, temp_output_dir):
-        """Test the channel chunk write function directly."""
-        chunk_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float64)
-        start_time = 1.0
-        end_time = 1.005
-        channel_index = 0
-
-        _write_channel_chunk(chunk_data, start_time, end_time, channel_index, temp_output_dir)
-
-        # Check file was created
-        expected_filename = (
-            f"channel-00000_{int(start_time * 1e6)}_{int(end_time * 1e6)}{TIME_SERIES_BINARY_FILE_EXTENSION}"
-        )
-        file_path = os.path.join(temp_output_dir, expected_filename)
-        assert os.path.exists(file_path)
-
-        # Verify data integrity
-        with gzip.open(file_path, "rb") as f:
-            data = f.read()
-        result = np.frombuffer(data, dtype=">f8")
-        np.testing.assert_array_equal(result, [1.0, 2.0, 3.0, 4.0, 5.0])
-
-    def test_write_channel_chunk_big_endian(self, temp_output_dir):
-        """Test that chunk write function outputs big-endian format."""
-        chunk_data = np.array([1.5, -2.5], dtype=np.float64)
-        _write_channel_chunk(chunk_data, 0.0, 0.001, 5, temp_output_dir)
-
-        file_path = os.path.join(temp_output_dir, "channel-00005_0_1000.bin.gz")
-        with gzip.open(file_path, "rb") as f:
-            data = f.read()
-
-        result = np.frombuffer(data, dtype=">f8")
-        np.testing.assert_array_equal(result, [1.5, -2.5])
-
     def test_parallel_processing_many_channels(self, temp_output_dir, session_start_time):
         """Test parallel processing with many channels (typical neuroscience scenario)."""
         num_channels = 64

From 1cf870f9bdb2aebe9c1075efadf25ed2e8bef7f3 Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Wed, 17 Dec 2025 11:47:29 -0500
Subject: [PATCH 09/10] cleanup explicit timestamps

---
 processor/reader.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/processor/reader.py b/processor/reader.py
index 903e367..03bc2fa 100644
--- a/processor/reader.py
+++ b/processor/reader.py
@@ -35,11 +35,14 @@ def __init__(self, electrical_series, session_start_time):
         ), "Electrode channels do not align with data shape"
 
         self._sampling_rate = None
-        self._has_explicit_timestamps = False
         self._compute_sampling_rate()
 
         self._channels = None
 
+    @property
+    def has_explicit_timestamps(self):
+        return self.electrical_series.timestamps is not None
+
     def _compute_sampling_rate(self):
         """
         Computes and stores the sampling rate.
@@ -56,8 +59,7 @@ def _compute_sampling_rate(self):
         # if both the timestamps and rate properties are set on the electrical
         # series validate that the given rate is within a 2% margin of the
         # sampling rate calculated off of the given timestamps
-        if self.electrical_series.rate and self.electrical_series.timestamps is not None:
-            self._has_explicit_timestamps = True
+        if self.electrical_series.rate and self.has_explicit_timestamps:
             sampling_rate = self.electrical_series.rate
 
             sample_size = min(10000, self.num_samples)
@@ -76,11 +78,9 @@ def _compute_sampling_rate(self):
         # if only the rate is given, timestamps will be computed on-demand
         elif self.electrical_series.rate:
             self._sampling_rate = self.electrical_series.rate
-            self._has_explicit_timestamps = False
 
         # if only the timestamps are given, calculate the sampling rate using a sample of timestamps
-        elif self.electrical_series.timestamps is not None:
-            self._has_explicit_timestamps = True
+        elif self.has_explicit_timestamps:
             sample_size = min(10000, self.num_samples)
             sample_timestamps = self.electrical_series.timestamps[:sample_size]
             self._sampling_rate = round(infer_sampling_rate(sample_timestamps))
@@ -92,7 +92,7 @@ def get_timestamp(self, index):
         """
         timestamp = (
             float(self.electrical_series.timestamps[index])
-            if self._has_explicit_timestamps
+            if self.has_explicit_timestamps
             else (index / self._sampling_rate)
         )
         return timestamp + self.session_start_time_secs
@@ -105,7 +105,7 @@ def get_timestamps(self, start, end):
         """
         timestamps = (
             np.array(self.electrical_series.timestamps[start:end])
-            if self._has_explicit_timestamps
+            if self.has_explicit_timestamps
             else np.linspace(
                 start / self._sampling_rate,
                 end / self._sampling_rate,
@@ -171,7 +171,7 @@ def contiguous_chunks(self):
         (timestamp_difference) > 2 * sampling_period
         """
         # if no explicit timestamps, data is continuous by definition
-        if not self._has_explicit_timestamps:
+        if not self.has_explicit_timestamps:
             yield 0, self.num_samples
             return
 

From 87d6dee029652f3ebbd174b96293f7c8c4f249f4 Mon Sep 17 00:00:00 2001
From: Rohan Shah
Date: Wed, 17 Dec 2025 11:50:56 -0500
Subject: [PATCH 10/10] re-add assertion

---
 processor/reader.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/processor/reader.py b/processor/reader.py
index 03bc2fa..e427fb7 100644
--- a/processor/reader.py
+++ b/processor/reader.py
@@ -37,6 +37,11 @@ def __init__(self, electrical_series, session_start_time):
         self._sampling_rate = None
         self._compute_sampling_rate()
 
+        if self.has_explicit_timestamps:
+            assert self.num_samples == len(
+                self.electrical_series.timestamps
+            ), "Differing number of sample and timestamp values"
+
         self._channels = None
 
     @property
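
With the series applied, timestamp access goes through the reader instead of a
precomputed array. A minimal usage sketch follows; the NWB file name, the
"ElectricalSeries" acquisition key, and the flat "reader" import path are
illustrative assumptions, not part of the patches:

    from pynwb import NWBHDF5IO

    from reader import NWBElectricalSeriesReader  # processor/reader.py

    # "session.nwb" is a placeholder for a real recording
    with NWBHDF5IO("session.nwb", "r") as io:
        nwbfile = io.read()
        series = nwbfile.acquisition["ElectricalSeries"]  # assumed key

        reader = NWBElectricalSeriesReader(series, nwbfile.session_start_time)

        # timestamps are derived per index or range; the full array is never
        # materialized in memory
        first_timestamp = reader.get_timestamp(0)
        first_second = reader.get_timestamps(0, int(reader.sampling_rate))

        # contiguous_chunks() yields (start, end) sample ranges split at
        # timestamp gaps; rate-only data is one chunk by definition
        for start, end in reader.contiguous_chunks():
            chunk_timestamps = reader.get_timestamps(start, end)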