1 change: 1 addition & 0 deletions processor/clients/__init__.py
@@ -3,6 +3,7 @@
from .base_client import SessionManager as SessionManager
from .import_client import ImportClient as ImportClient
from .import_client import ImportFile as ImportFile
from .packages_client import PackagesClient as PackagesClient
from .timeseries_client import TimeSeriesClient as TimeSeriesClient
from .workflow_client import WorkflowClient as WorkflowClient
from .workflow_client import WorkflowInstance as WorkflowInstance
106 changes: 106 additions & 0 deletions processor/clients/packages_client.py
@@ -0,0 +1,106 @@
import json
import logging

import requests

from .base_client import BaseClient

log = logging.getLogger()


class PackagesClient(BaseClient):
def __init__(self, api_host, session_manager):
super().__init__(session_manager)

self.api_host = api_host

@BaseClient.retry_with_refresh
def get_parent_package_id(self, package_id: str) -> str:
"""
Get the parent package ID for a given package.

Args:
package_id: The package ID to query

Returns:
str: The parent node ID

Raises:
requests.HTTPError: If the API request fails
"""
url = f"{self.api_host}/packages/{package_id}?includeAncestors=true&startAtEpoch=false&limit=100&offset=0"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {self.session_manager.session_token}",
}

try:
log.info(f"Fetching parent package ID for package: {package_id}")
response = requests.get(url, headers=headers)
response.raise_for_status()
package_info = response.json()
parent_node_id = package_info["parent"]["content"]["nodeId"]
return parent_node_id
except requests.HTTPError as e:
log.error(f"failed to get parent package ID for {package_id}: {e}")
raise e
except json.JSONDecodeError as e:
log.error(f"failed to decode package response: {e}")
raise e
except Exception as e:
log.error(f"failed to get parent package ID: {e}")
raise e

@BaseClient.retry_with_refresh
def update_properties(self, package_id: str, properties: list[dict]) -> None:
"""
Updates a package's properties on the Pennsieve API.

Args:
package_id: The package (node) ID
properties: List of property dicts with keys: key, value, dataType, category, fixed, hidden
"""
url = f"{self.api_host}/packages/{package_id}?updateStorage=true"

payload = {"properties": properties}

headers = {
"accept": "*/*",
"content-type": "application/json",
"Authorization": f"Bearer {self.session_manager.session_token}",
}

try:
response = requests.put(url, json=payload, headers=headers)
response.raise_for_status()
return None
except Exception as e:
log.error(f"failed to update package {package_id} properties: {e}")
raise e

def set_timeseries_properties(self, package_id: str) -> None:
"""
Sets the time series viewer properties on a package.

Args:
package_id: The package (node) ID
"""
properties = [
{
"key": "subtype",
"value": "pennsieve_timeseries",
"dataType": "string",
"category": "Viewer",
"fixed": False,
"hidden": True,
},
{
"key": "icon",
"value": "timeseries",
"dataType": "string",
"category": "Pennsieve",
"fixed": False,
"hidden": True,
},
]
return self.update_properties(package_id, properties)
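A minimal usage sketch of the new client, assuming a SessionManager that already holds a valid session_token and a response shaped the way get_parent_package_id parses it (parent.content.nodeId); the host, credentials, and IDs below are illustrative only, not part of this PR:

# Hypothetical usage -- constructor arguments and node IDs are placeholders.
from clients import PackagesClient, SessionManager

session_manager = SessionManager(api_key="<key>", api_secret="<secret>")  # assumed constructor
packages = PackagesClient("https://api.pennsieve.io", session_manager)

# get_parent_package_id expects a payload roughly like:
# {"parent": {"content": {"nodeId": "N:collection:1234-..."}}, ...}
parent_id = packages.get_parent_package_id("N:package:abcd-ef01")

# Tag the target package so the platform opens it with the time series viewer.
packages.set_timeseries_properties(parent_id)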
1 change: 0 additions & 1 deletion processor/config.py
@@ -9,7 +9,6 @@ def __init__(self):
self.INPUT_DIR = os.getenv("INPUT_DIR")
self.OUTPUT_DIR = os.getenv("OUTPUT_DIR")


self.CHUNK_SIZE_MB = int(os.getenv("CHUNK_SIZE_MB", "1"))

# continue to use INTEGRATION_ID environment variable until runner
67 changes: 62 additions & 5 deletions processor/importer.py
@@ -5,10 +5,19 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Lock, Value
from typing import Optional

import backoff
import requests
from clients import AuthenticationClient, ImportClient, ImportFile, SessionManager, TimeSeriesClient, WorkflowClient
from clients import (
AuthenticationClient,
ImportClient,
ImportFile,
PackagesClient,
SessionManager,
TimeSeriesClient,
WorkflowClient,
)
from constants import TIME_SERIES_BINARY_FILE_EXTENSION, TIME_SERIES_METADATA_FILE_EXTENSION
from timeseries_channel import TimeSeriesChannel

@@ -49,10 +58,15 @@ def import_timeseries(api_host, api2_host, api_key, api_secret, workflow_instanc
workflow_client = WorkflowClient(api2_host, session_manager)
workflow_instance = workflow_client.get_workflow_instance(workflow_instance_id)

# constraint until we implement (upstream) performing imports over directories
# and specifying how to group time series files together into an imported package
assert len(workflow_instance.package_ids) == 1, "NWB post processor only supports a single package for import"
package_id = workflow_instance.package_ids[0]
# fetch the target package for channel data and time series properties
packages_client = PackagesClient(api_host, session_manager)
package_id = determine_target_package(packages_client, workflow_instance.package_ids)
if not package_id:
log.error("dataset_id={workflow_instance.dataset_id} could not determine target time series package")
return None

packages_client.set_timeseries_properties(package_id)
log.info(f"updated package {package_id} with time series properties")

log.info(f"dataset_id={workflow_instance.dataset_id} package_id={package_id} starting import of time series files")

@@ -140,3 +154,46 @@ def upload_timeseries_file(timeseries_file):
log.info(f"import_id={import_id} uploaded {upload_counter.value} time series files")

assert sum(successful_uploads) == len(import_files), "Failed to upload all time series files"


def determine_target_package(packages_client: PackagesClient, package_ids: list[str]) -> Optional[str]:
"""
Determine which package should receive the time series data and properties.

If there's only one package ID, use that package directly.
If there are multiple package IDs, find the first one with 'N:package:' prefix
and get its parent package ID.

Args:
packages_client: PackagesClient instance for API calls
package_ids: List of package IDs from the workflow instance

Returns:
The package ID to update with properties, or None if unable to determine
"""
if not package_ids:
log.warning("No package IDs provided")
return None

if len(package_ids) == 1:
log.info("Single package ID found, using it directly: %s", package_ids[0])
return package_ids[0]

first_package = None
for package_id in package_ids:
if package_id.startswith("N:package:"):
first_package = package_id
break

if first_package is None:
log.warning("No package ID with 'N:package:' prefix found in: %s", package_ids)
return None

log.info("Multiple package IDs found, getting parent of first package: %s", first_package)
try:
parent_id = packages_client.get_parent_package_id(first_package)
log.info("Parent package ID: %s", parent_id)
return parent_id
except Exception as e:
log.error("Failed to get parent package ID: %s", e)
return None
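The branching above can be exercised directly with a stub in place of the real client; a sketch of the three paths, with made-up node IDs:

# Illustrative check -- the stub mimics only the single method determine_target_package calls.
from importer import determine_target_package

class StubPackagesClient:
    def get_parent_package_id(self, package_id: str) -> str:
        return "N:collection:parent-0001"  # pretend parent returned by the API

# A single package ID is used as-is.
assert determine_target_package(StubPackagesClient(), ["N:package:only-one"]) == "N:package:only-one"

# With multiple IDs, the first "N:package:"-prefixed one is resolved to its parent.
ids = ["N:collection:skip-me", "N:package:chunk-1", "N:package:chunk-2"]
assert determine_target_package(StubPackagesClient(), ids) == "N:collection:parent-0001"

# No usable IDs: None comes back and the caller logs an error and returns early.
assert determine_target_package(StubPackagesClient(), []) is None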
13 changes: 9 additions & 4 deletions processor/reader.py
@@ -43,14 +43,19 @@ def __init__(self, electrical_series, session_start_time):
assert (
len(self.electrical_series.electrodes.table) == self.num_channels
), "Electrode channels do not align with data shape"
log.info(f"NWB file has {self.num_samples} samples")

self._sampling_rate = None
self._compute_sampling_rate()
log.info(f"NWB file has sampling rate: {self.sampling_rate} Hz")

if self.has_explicit_timestamps:
log.info("NWB file has explicit timestamps")
assert self.num_samples == len(
self.electrical_series.timestamps
), "Differing number of sample and timestamp value"
else:
log.info("NWB file has implicit timestamps")

self._channels = None

@@ -68,7 +73,7 @@ def _compute_sampling_rate(self):
TimeSeries objects but it's worth handling this case by validating the
sampling_rate against the timestamps if this case does somehow appear.
"""
if self.electrical_series.rate is None and self.electrical_series.timestamps is None:
if self.electrical_series.rate is None and not self.has_explicit_timestamps:
raise Exception("electrical series has no defined sampling rate or timestamp values")

# if both the timestamps and rate properties are set on the electrical
@@ -78,7 +83,7 @@ def _compute_sampling_rate(self):
sampling_rate = self.electrical_series.rate

sample_size = min(10000, self.num_samples)
sample_timestamps = self.electrical_series.timestamps[:sample_size]
sample_timestamps = self.get_timestamps(0, sample_size)
inferred_sampling_rate = infer_sampling_rate(sample_timestamps)

error = abs(inferred_sampling_rate - sampling_rate) * (1.0 / sampling_rate)
@@ -97,7 +102,7 @@ def _compute_sampling_rate(self):
# if only the timestamps are given, calculate the sampling rate using a sample of timestamps
elif self.has_explicit_timestamps:
sample_size = min(10000, self.num_samples)
sample_timestamps = self.electrical_series.timestamps[:sample_size]
sample_timestamps = self.get_timestamps(0, sample_size)
self._sampling_rate = round(infer_sampling_rate(sample_timestamps))

def get_timestamp(self, index):
@@ -199,7 +204,7 @@ def contiguous_chunks(self):

for batch_start in range(0, self.num_samples, batch_size):
batch_end = min(batch_start + batch_size, self.num_samples)
batch_timestamps = self.electrical_series.timestamps[batch_start:batch_end]
batch_timestamps = self.get_timestamps(batch_start, batch_end)

# check gap between batches
if prev_timestamp is not None:
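These hunks route all timestamp slicing through the reader's get_timestamps helper instead of indexing electrical_series.timestamps directly, so the same code path works when the NWB file only has implicit, rate-based timestamps. The helper itself sits outside the hunks shown here; a rough sketch of what it presumably does, using the attributes already referenced above:

# Sketch only -- the real get_timestamps is a reader method not shown in this diff.
import numpy as np

def get_timestamps(self, start, end):
    """Return timestamps for samples [start, end), explicit or derived from the rate."""
    end = min(end, self.num_samples)
    if self.has_explicit_timestamps:
        # Explicit timestamps: slice them straight out of the NWB dataset.
        return self.electrical_series.timestamps[start:end]
    # Implicit timestamps: derive them from starting_time and the sampling rate.
    starting_time = self.electrical_series.starting_time or 0.0
    return starting_time + np.arange(start, end) / self.sampling_rate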
4 changes: 2 additions & 2 deletions processor/writer.py
@@ -89,8 +89,8 @@ def write_chunk(chunk, start_time, end_time, channel_index, output_dir):

file_name = "channel-{}_{}_{}{}".format(
"{index:05d}".format(index=channel_index),
round(start_time * 1e6),
round(end_time * 1e6),
int(start_time * 1e6),
int(end_time * 1e6),
TIME_SERIES_BINARY_FILE_EXTENSION,
)
file_path = os.path.join(output_dir, file_name)
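The only change here is how seconds are converted to microseconds for the file name: int() truncates the fractional microsecond where round() could round it up. With illustrative values:

# Truncation vs. rounding at a non-integer microsecond boundary.
start_time = 12.3456789           # seconds
print(int(start_time * 1e6))      # 12345678 -- fractional microsecond dropped
print(round(start_time * 1e6))    # 12345679 -- rounded up to the nearest microsecond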