Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions caveclient/annotationengine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json
import time
from typing import Dict, Iterable, List, Mapping, Optional, Union

import numpy as np
import pandas as pd

try:
import tqdm
except ImportError:
tqdm = None

from .auth import AuthClient
from .base import BaseEncoder, ClientBase, _api_endpoints, handle_response
from .endpoints import annotation_api_versions, annotation_common
Expand Down Expand Up @@ -800,6 +806,9 @@ def upload_staged_annotations(
self,
staged_annos: stage.StagedAnnotations,
aligned_volume_name: Optional[str] = None,
batch_size: int = 10_000,
progress: bool = True,
retries: int = 3,
) -> Union[list[int], dict[int, int]]:
"""
Upload annotations directly from an Annotation Guide object.
Expand All @@ -811,7 +820,13 @@ def upload_staged_annotations(
AnnotationGuide object with a specified table name and a collection of annotations already filled in.
aligned_volume_name : str or None, optional
Name of the aligned_volume. If None, uses the one specified in the client.

batch_size : int, optional
If the number of annotations exceeds this batch size, the upload will be split into multiple requests, by default 10,000.
progress : bool, optional
Whether to show a progress bar during upload, by default True.
retries : int, optional
Number of times to retry a batch if it fails, by default 3. Will sleep for 2^n seconds between retries, where n is the number of attempts so far.

Returns
-------
List or dict
Expand All @@ -823,14 +838,36 @@ def upload_staged_annotations(
"Only annotation guide objects with a specified table name can be used here"
)
if staged_annos.is_update:
return self.update_annotation(
staged_annos.table_name,
staged_annos.annotation_list,
aligned_volume_name=aligned_volume_name,
)
upload_function = self.update_annotation
ids_all = {} # type: ignore
else:
return self.post_annotation(
staged_annos.table_name,
staged_annos.annotation_list,
aligned_volume_name=aligned_volume_name,
)
upload_function = self.post_annotation
ids_all = [] # type: ignore

batches = staged_annos._annotation_batches(batch_size)

if tqdm is not None:
progress_ = tqdm.tqdm(batches, desc="Annotation Batches") if progress else batches
else:
progress_ = batches
for batch in progress_:
attempts = 0
while attempts < retries:
try:
batch_ids = upload_function(
staged_annos.table_name,
[staged_annos._process_annotation(a) for a in batch],
aligned_volume_name=aligned_volume_name,
)
staged_annos._apply_upload_result(batch, batch_ids)
if staged_annos.is_update:
ids_all.update(batch_ids) # type: ignore
else:
ids_all.extend(batch_ids)
break
except Exception as e:
attempts += 1
if attempts >= retries:
raise e
time.sleep(2 ** attempts) # Exponential backoff
return ids_all
184 changes: 156 additions & 28 deletions caveclient/tools/stage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import warnings

import attrs
import jsonschema
import numpy as np
import pandas as pd

SPATIAL_POINT_CLASSES = ["SpatialPoint", "BoundSpatialPoint"]

ADD_FUNC_DOCSTSRING = (
ADD_FUNC_DOCSTRING = (
"Add annotation to a local collection. Note that this does not upload annotations."
)


class StagedAnnotations(object):
IS_UPLOADED_FIELD = "_IS_UPLOADED_"
UPLOADED_ID_FIELD = "_UPLOADED_ID_"
IS_UPLOADED_COLUMN = "is_uploaded"
NEW_ID_COLUMN = "new_id"

def __init__(
self,
schema,
Expand All @@ -31,7 +38,16 @@ def __init__(
name : _type_, optional
_description_, by default None
id_field : bool, optional
_description_, by default False
Name of the id field, by default False
update : bool, optional
Whether these annotations are intended to update existing annotations (True) or create new annotations (False). If True, an "id" field will be added to the annotation class, by default False
table_resolution : list, optional
Resolution of the table that these annotations will be uploaded to, in units of nm/px
annotation_resolution : list, optional
Resolution of the annotations being added, in units of nm/px. If table_resolution is also provided, annotation coordinates will be automatically scaled to match table resolution. If not provided, coordinates will
be added as-is, and it is the user's responsibility to ensure they are in the correct units.
table_name : str, optional
Name of the table that these annotations will be uploaded to. If not provided, it is the user's responsibility to ensure that the schema provided matches the table they intend to upload to and that it is uploaded to the correct table.
"""
self._schema = schema
if update:
Expand All @@ -50,7 +66,7 @@ def __init__(
for x, y in zip(self._table_resolution, self._annotation_resolution)
]
elif self._annotation_resolution:
raise Warning(
warnings.warn(
"No table resolution set. Coordinates cannot be scaled automatically."
)

Expand Down Expand Up @@ -82,9 +98,9 @@ def __init__(
self.add = self._make_anno_func(
id_field=self._id_field, mixin=(self._build_mixin(),)
)
self.add.__doc__ = ADD_FUNC_DOCSTSRING
self.add.__doc__ = ADD_FUNC_DOCSTRING

def __repr__(self):
def __repr__(self) -> str:
if self._update:
update = "updated"
else:
Expand All @@ -94,12 +110,19 @@ def __repr__(self):
table_text = f"table '{self.table_name}'"
else:
table_text = f"schema '{self._ref_class}' with no table"
return f"Staged annotations for {table_text} ({len(self)} {update} annotations)"
n_total = len(self)
n_uploaded = sum(
1 for a in self._anno_list if getattr(a, self.IS_UPLOADED_FIELD, False)
)
return (
f"Staged annotations for {table_text} "
f"({n_total} {update} annotations, {n_uploaded} uploaded)"
)

def __len__(self):
def __len__(self) -> int:
return len(self._anno_list)

def add_dataframe(self, df):
def add_dataframe(self, df) -> None:
"""Add multiple annotations via a dataframe. Note that dataframe columns must exactly match fields in the schema (see the "fields" property to check)

Parameters
Expand Down Expand Up @@ -133,53 +156,155 @@ def add_dataframe(self, df):
self.add(**anno)

@property
def table_name(self):
def table_name(self) -> str:
    """Name of the table these annotations will be uploaded to."""
    name = self._table_name
    return name

@table_name.setter
def table_name(self, x):
def table_name(self, x: str) -> None:
    """Set the name of the table these annotations will be uploaded to."""
    setattr(self, "_table_name", x)

@property
def is_update(self):
def is_update(self) -> bool:
    """Whether the stage updates existing annotations (True) or creates new ones (False)."""
    update_mode = self._update
    return update_mode

@property
def fields(self):
def fields(self) -> list:
    """List the field names of an annotation.

    Returns
    -------
    list
        The stored property names, preceded by ``"id"`` when the id field
        is enabled.
    """
    # Guard clause: without an id field the stored list is returned as-is.
    if not self._id_field:
        return self._prop_names
    return ["id"] + self._prop_names

@property
def fields_required(self):
def fields_required(self) -> list:
    """List the required field names of an annotation.

    Returns
    -------
    list
        The names from ``_name_positions_required()``, preceded by ``"id"``
        when the id field is enabled.
    """
    required = self._name_positions_required()
    if self._id_field:
        return ["id"] + required
    return required

@property
def annotation_list(self):
def annotation_list(self) -> list:
    """Return every staged annotation as a processed, non-flat dictionary.

    Returns
    -------
    list
        One ``_process_annotation(..., flat=False)`` result per staged
        annotation, in staging order.
    """
    processed = []
    for anno in self._anno_list:
        processed.append(self._process_annotation(anno, flat=False))
    return processed

@property
def annotation_dataframe(self):
return pd.DataFrame.from_records(
[self._process_annotation(a, flat=True) for a in self._anno_list],
)
def annotation_list_nonuploaded(self) -> list:
    """Return the processed, non-flat form of annotations not yet marked uploaded.

    Returns
    -------
    list
        One ``_process_annotation(..., flat=False)`` result for each staged
        annotation whose upload flag is missing or falsy, in staging order.
    """
    pending = []
    for anno in self._anno_list:
        if getattr(anno, self.IS_UPLOADED_FIELD, False):
            continue
        pending.append(self._process_annotation(anno, flat=False))
    return pending

def _annotation_batches(self, batch_size) -> list:
"""Split the non-uploaded annotations into batches.

Parameters
----------
batch_size : int
The number of annotations to include in each batch.

Returns
-------
list
A list of batches, where each batch is a list of annotations that
have not yet been uploaded.
"""
nonuploaded = [a for a in self._anno_list if not getattr(a, self.IS_UPLOADED_FIELD, False)]
return [nonuploaded[i : i + batch_size] for i in range(0, len(nonuploaded), batch_size)]

def _apply_upload_result(self, batch, batch_ids) -> None:
"""Stamp a batch of annotations with the server's response.

For update stages, ``batch_ids`` is a ``{old_id: new_id}`` mapping with
string-keyed old ids (as returned over JSON); each annotation's new id
is looked up by ``str(a.id)``. For new-annotation stages, ``batch_ids``
is a list of server-assigned ids in batch order.
"""
if self.is_update:
for a in batch:
new_id = batch_ids[str(a.id)]
setattr(a, self.UPLOADED_ID_FIELD, new_id)
setattr(a, self.IS_UPLOADED_FIELD, True)
else:
assert len(batch_ids) == len(batch), (
f"Server returned {len(batch_ids)} ids for batch of size {len(batch)}"
)
for a, new_id in zip(batch, batch_ids):
setattr(a, self.UPLOADED_ID_FIELD, new_id)
setattr(a, self.IS_UPLOADED_FIELD, True)

def clear_annotations(self):
def annotation_dataframe(
    self,
    only_nonuploaded: bool = False,
    include_tracking: bool = False,
) -> pd.DataFrame:
    """Get a dataframe of staged annotations.

    Parameters
    ----------
    only_nonuploaded : bool, optional
        If True, only include annotations that have not been uploaded yet.
        By default False.
    include_tracking : bool, optional
        If True, include two extra upload-tracking columns: ``is_uploaded``
        (bool) and ``new_id`` (nullable Int64, the id assigned by the
        server). Both columns are always present in this mode, even when no
        annotation has been uploaded yet (``new_id`` is then all ``<NA>``)
        or when there are no annotations at all. For update-mode stages this
        preserves the old→new id mapping: the ``id`` column holds the
        original id, ``new_id`` holds the id returned by the server.
        By default False.

    Returns
    -------
    pd.DataFrame
        One row per annotation, with spatial point fields flattened into
        ``<name>_position`` columns.
    """
    annos = self._anno_list
    if only_nonuploaded:
        annos = [a for a in annos if not getattr(a, self.IS_UPLOADED_FIELD, False)]
    records = [
        self._process_annotation(
            a,
            flat=True,
            pop_is_uploaded=not include_tracking,
            pop_uploaded_id=not include_tracking,
        )
        for a in annos
    ]
    df = pd.DataFrame.from_records(records)
    if not include_tracking:
        return df

    df = df.rename(
        columns={
            self.IS_UPLOADED_FIELD: self.IS_UPLOADED_COLUMN,
            self.UPLOADED_ID_FIELD: self.NEW_ID_COLUMN,
        }
    )
    # Only reachable when there are no records; keeps the tracking schema
    # stable for callers that index these columns unconditionally.
    if self.IS_UPLOADED_COLUMN not in df.columns:
        df[self.IS_UPLOADED_COLUMN] = pd.Series(False, index=df.index, dtype=bool)
    if self.NEW_ID_COLUMN in df.columns:
        df[self.NEW_ID_COLUMN] = df[self.NEW_ID_COLUMN].astype("Int64")
    else:
        # None-valued fields are dropped while processing annotations, so the
        # id column never materializes until at least one annotation has a
        # server id. Add it explicitly as an all-missing nullable column.
        df[self.NEW_ID_COLUMN] = pd.array([pd.NA] * len(df), dtype="Int64")
    return df

def clear_annotations(self) -> None:
    """Drop every staged annotation by replacing the internal list with a new empty one.

    This cannot be undone; annotations that were never uploaded are lost.
    """
    emptied: list = []
    self._anno_list = emptied

def purge_uploaded_annotations(self) -> None:
    """Keep only the annotations that are not marked as uploaded.

    The internal annotation list is replaced by a new list holding, in the
    same order, every annotation whose upload flag is missing or falsy.
    """
    remaining = []
    for anno in self._anno_list:
        if getattr(anno, self.IS_UPLOADED_FIELD, False):
            continue
        remaining.append(anno)
    self._anno_list = remaining

def _process_annotation(self, anno, flat=False):
def _process_annotation(self, anno, flat=False, pop_is_uploaded=True, pop_uploaded_id=True) -> dict:
    """Convert a staged annotation object into a plain dictionary.

    Parameters
    ----------
    anno :
        An attrs-based annotation instance.
    flat : bool, optional
        If True, return the dictionary as produced by ``_process_spatial``.
        If False (default), additionally pass it through
        ``_unflatten_spatial_points``.
    pop_is_uploaded : bool, optional
        If True (default), drop the upload-flag tracking field.
    pop_uploaded_id : bool, optional
        If True (default), drop the uploaded-id tracking field.

    Returns
    -------
    dict
        The annotation's fields, with ``None``-valued fields omitted.
    """
    record = attrs.asdict(anno, filter=lambda _attribute, value: value is not None)

    # Collect the tracking keys to hide, then remove them in one pass.
    hidden = []
    if pop_is_uploaded:
        hidden.append(self.IS_UPLOADED_FIELD)
    if pop_uploaded_id:
        hidden.append(self.UPLOADED_ID_FIELD)
    for key in hidden:
        record.pop(key, None)

    record = self._process_spatial(record)
    return record if flat else self._unflatten_spatial_points(record)

def _build_mixin(self):
def _build_mixin(self) -> type:
class AddAndValidate(object):
def __attrs_post_init__(inner_self):
d = self._process_annotation(inner_self)
Expand All @@ -190,9 +315,9 @@ def __attrs_post_init__(inner_self):

return AddAndValidate

def _make_anno_func(self, id_field=False, mixin=()):
def _make_anno_func(self, id_field=False, mixin=()) -> callable:
cdict = {}

if id_field:
cdict["id"] = attrs.field()
for prop, prop_name in zip(self._props, self._prop_names):
Expand All @@ -202,20 +327,23 @@ def _make_anno_func(self, id_field=False, mixin=()):
if prop not in self._required_props:
cdict[prop_name] = attrs.field(default=None)

cdict[self.IS_UPLOADED_FIELD] = attrs.field(default=False)
cdict[self.UPLOADED_ID_FIELD] = attrs.field(type=int, default=None)

return attrs.make_class(self.name, cdict, bases=mixin)

def _name_positions(self):
def _name_positions(self) -> list:
return [
x if x not in self._spatial_pts else f"{x}_position" for x in self._props
]

def _name_positions_required(self):
def _name_positions_required(self) -> list:
return [
x if x not in self._spatial_pts else f"{x}_position"
for x in self._required_props
]

def _process_spatial(self, d):
def _process_spatial(self, d) -> dict:
dout = {}
for k, v in d.items():
if isinstance(v, np.ndarray):
Expand All @@ -226,13 +354,13 @@ def _process_spatial(self, d):
dout[k] = v
return dout

def _process_spatial_point(self, v):
def _process_spatial_point(self, v) -> list:
if self._anno_scaling is None:
return v
else:
return [x * y for x, y in zip(v, self._anno_scaling)]

def _unflatten_spatial_points(self, d):
def _unflatten_spatial_points(self, d) -> dict:
dout = {}
for k, v in d.items():
if k in self._convert_pts:
Expand Down
Loading
Loading