From 829957ac0c2492e23a7ec355f2ae72ff72fada10 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:35:10 +0000 Subject: [PATCH 01/16] feat(sv_split): track max sv id to create new ids; convert ws seg to ocdbt --- pychunkedgraph/graph/ocdbt.py | 63 ++++++++++++++++++++ pychunkedgraph/ingest/cli.py | 2 + pychunkedgraph/ingest/cluster.py | 4 ++ pychunkedgraph/ingest/create/atomic_layer.py | 5 +- 4 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 pychunkedgraph/graph/ocdbt.py diff --git a/pychunkedgraph/graph/ocdbt.py b/pychunkedgraph/graph/ocdbt.py new file mode 100644 index 000000000..03c6d9b65 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt.py @@ -0,0 +1,63 @@ +import os +import numpy as np +import tensorstore as ts + +OCDBT_SEG_COMPRESSION_LEVEL = 17 + + +def get_seg_source_and_destination_ocdbt(ws_path: str, create: bool = False) -> tuple: + src_spec = { + "driver": "neuroglancer_precomputed", + "kvstore": ws_path, + } + src = ts.open(src_spec).result() + schema = src.schema + + ocdbt_path = os.path.join(ws_path, "ocdbt", "base") + dst_spec = { + "driver": "neuroglancer_precomputed", + "kvstore": { + "driver": "ocdbt", + "base": ocdbt_path, + "config": { + "compression": {"id": "zstd", "level": OCDBT_SEG_COMPRESSION_LEVEL}, + }, + }, + } + + dst = ts.open( + dst_spec, + create=create, + rank=schema.rank, + dtype=schema.dtype, + codec=schema.codec, + domain=schema.domain, + shape=schema.shape, + chunk_layout=schema.chunk_layout, + dimension_units=schema.dimension_units, + delete_existing=create, + ).result() + return (src, dst) + + +def copy_ws_chunk( + source, + destination, + chunk_size: tuple, + coords: list, + voxel_bounds: np.ndarray, +): + coords = np.array(coords, dtype=int) + chunk_size = np.array(chunk_size, dtype=int) + vx_start = coords * chunk_size + voxel_bounds[:, 0] + vx_end = vx_start + chunk_size + xE, yE, zE = voxel_bounds[:, 1] + + x0, y0, z0 = vx_start + x1, y1, z1 = vx_end + x1 = min(x1, xE) + y1 = min(y1, 
yE) + z1 = min(z1, zE) + + data = source[x0:x1, y0:y1, z0:z1].read().result() + destination[x0:x1, y0:y1, z0:z1].write(data).result() diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index c50525ec6..8d44bf276 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -23,6 +23,7 @@ from .simple_tests import run_all from .create.parent_layer import add_parent_chunk from ..graph.chunkedgraph import ChunkedGraph +from ..graph.ocdbt import get_seg_source_and_destination_ocdbt from ..utils.redis import get_redis_connection, keys as r_keys group_name = "ingest" @@ -71,6 +72,7 @@ def ingest_graph( imanager = IngestionManager(ingest_config, meta) enqueue_l2_tasks(imanager, create_atomic_chunk) + get_seg_source_and_destination_ocdbt(cg.meta, create=True) @ingest_cli.command("imanager") diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 360b5a15d..473a61b22 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -26,6 +26,7 @@ from .upgrade.parent_layer import update_chunk as update_parent_chunk from ..graph.edges import EDGE_TYPES, Edges, put_edges from ..graph import ChunkedGraph, ChunkedGraphMeta +from ..graph.ocdbt import copy_ws_chunk, get_seg_source_and_destination_ocdbt from ..graph.chunks.hierarchy import get_children_chunk_coords from ..graph.basetypes import NODE_ID from ..io.edges import get_chunk_edges @@ -141,6 +142,9 @@ def create_atomic_chunk(coords: Sequence[int]): logging.debug(f"{k}: {len(v)}") for k, v in chunk_edges_active.items(): logging.debug(f"active_{k}: {len(v)}") + + src, dst = get_seg_source_and_destination_ocdbt(imanager.cg.meta) + copy_ws_chunk(imanager.cg, coords, src, dst) _post_task_completion(imanager, 2, coords) diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py index b226004f2..30043710d 100644 --- a/pychunkedgraph/ingest/create/atomic_layer.py +++ 
b/pychunkedgraph/ingest/create/atomic_layer.py @@ -32,7 +32,10 @@ def add_atomic_chunk( return chunk_ids = cg.get_chunk_ids_from_node_ids(chunk_node_ids) - assert len(np.unique(chunk_ids)) == 1 + assert len(np.unique(chunk_ids)) == 1, np.unique(chunk_ids) + + max_node_id = np.max(chunk_node_ids) + cg.id_client.set_max_node_id(chunk_ids[0], max_node_id) graph, _, _, unique_ids = build_gt_graph(chunk_edge_ids, make_directed=True) ccs = connected_components(graph) From 990673392b56dc7da353ca180ca5c828a49a5939 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:36:20 +0000 Subject: [PATCH 02/16] feat(sv_split): metadata changes to support ocdbt seg --- pychunkedgraph/graph/meta.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py index 83d670ffe..6a938f802 100644 --- a/pychunkedgraph/graph/meta.py +++ b/pychunkedgraph/graph/meta.py @@ -2,17 +2,16 @@ from datetime import timedelta from typing import Dict from typing import List -from typing import Tuple from typing import Sequence from collections import namedtuple import numpy as np from cloudvolume import CloudVolume +from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt + from .utils.generic import compute_bitmasks from .chunks.utils import get_chunks_boundary -from ..utils.redis import keys as r_keys -from ..utils.redis import get_rq_queue from ..utils.redis import get_redis_connection @@ -64,9 +63,11 @@ def __init__( self._custom_data = custom_data self._ws_cv = None + self._ws_ocdbt = None self._layer_bounds_d = None self._layer_count = None self._bitmasks = None + self._ocdbt_seg = None @property def graph_config(self): @@ -91,15 +92,33 @@ def ws_cv(self): # useful to avoid md5 errors on high gcs load redis = get_redis_connection() cached_info = json.loads(redis.get(cache_key)) - self._ws_cv = CloudVolume(self._data_source.WATERSHED, info=cached_info) + 
self._ws_cv = CloudVolume( + self._data_source.WATERSHED, info=cached_info, progress=False + ) except Exception: - self._ws_cv = CloudVolume(self._data_source.WATERSHED) + self._ws_cv = CloudVolume(self._data_source.WATERSHED, progress=False) try: redis.set(cache_key, json.dumps(self._ws_cv.info)) except Exception: ... return self._ws_cv + @property + def ocdbt_seg(self) -> bool: + if self._ocdbt_seg is None: + self._ocdbt_seg = self._custom_data.get("seg", {}).get("ocdbt", False) + return self._ocdbt_seg + + @property + def ws_ocdbt(self): + assert self.ocdbt_seg, "make sure this pcg has segmentation in ocdbt format" + if self._ws_ocdbt: + return self._ws_ocdbt + + _, _ocdbt_seg = get_seg_source_and_destination_ocdbt(self.data_source.WATERSHED) + self._ws_ocdbt = _ocdbt_seg + return self._ws_ocdbt + @property def resolution(self): return self.ws_cv.resolution # pylint: disable=no-member @@ -235,11 +254,14 @@ def split_bounding_offset(self): @property def dataset_info(self) -> Dict: info = self.ws_cv.info # pylint: disable=no-member - info.update( { "chunks_start_at_voxel_offset": True, - "data_dir": self.data_source.WATERSHED, + "data_dir": ( + self.ws_ocdbt.kvstore.base.url + if self.ocdbt_seg + else self.data_source.WATERSHED + ), "graph": { "chunk_size": self.graph_config.CHUNK_SIZE, "bounding_box": [2048, 2048, 512], From 8111ddd8c2de45c725e797ef9b8dea5529d39772 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:39:58 +0000 Subject: [PATCH 03/16] feat(sv_split): split sv, update seg and edges, read and write new edges from pcg --- pychunkedgraph/graph/chunkedgraph.py | 42 +- pychunkedgraph/graph/chunks/utils.py | 54 +- pychunkedgraph/graph/cutting_sv.py | 1284 ++++++++++++++++++++++ pychunkedgraph/graph/edits_sv.py | 439 ++++++++ pychunkedgraph/graph/types.py | 3 +- pychunkedgraph/graph/utils/__init__.py | 1 + pychunkedgraph/graph/utils/generic.py | 12 + pychunkedgraph/graph/utils/id_helpers.py | 6 +- 
pychunkedgraph/meshing/meshgen_utils.py | 18 +- requirements.in | 3 + 10 files changed, 1828 insertions(+), 34 deletions(-) create mode 100644 pychunkedgraph/graph/cutting_sv.py create mode 100644 pychunkedgraph/graph/edits_sv.py diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 89282a58c..4dbdcdac9 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -672,22 +672,44 @@ def get_subgraph_leaves( self, node_id_or_ids, bbox, bbox_is_coordinate, False, True ) - def get_fake_edges( + def get_edited_edges( self, chunk_ids: np.ndarray, time_stamp: datetime.datetime = None ) -> typing.Dict: + """ + Edges stored within a pcg that were created as a result of edits. + Either 'fake' edges that were adding for a merge edit; + Or 'split' edges resulting from a supervoxel split. + """ result = {} - fake_edges_d = self.client.read_nodes( + properties = [ + attributes.Connectivity.FakeEdges, + attributes.Connectivity.SplitEdges, + attributes.Connectivity.Affinity, + attributes.Connectivity.Area, + ] + _edges_d = self.client.read_nodes( node_ids=chunk_ids, - properties=attributes.Connectivity.FakeEdges, + properties=properties, end_time=time_stamp, end_time_inclusive=True, fake_edges=True, ) - for id_, val in fake_edges_d.items(): - edges = np.concatenate( - [np.asarray(e.value, dtype=basetypes.NODE_ID) for e in val] - ) - result[id_] = Edges(edges[:, 0], edges[:, 1]) + for id_, val in _edges_d.items(): + edges = val.get(attributes.Connectivity.FakeEdges, []) + edges = np.concatenate([types.empty_2d, *[e.value for e in edges]]) + fake_edges_ = Edges(edges[:, 0], edges[:, 1]) + + edges = val.get(attributes.Connectivity.SplitEdges, []) + edges = np.concatenate([types.empty_2d, *[e.value for e in edges]]) + + aff = val.get(attributes.Connectivity.Affinity, []) + aff = np.concatenate([types.empty_affinities, *[e.value for e in aff]]) + + areas = val.get(attributes.Connectivity.Area, []) + areas = 
np.concatenate([types.empty_areas, *[e.value for e in areas]]) + split_edges_ = Edges(edges[:, 0], edges[:, 1], affinities=aff, areas=areas) + + result[id_] = fake_edges_ + split_edges_ return result def copy_fake_edges(self, chunk_id: np.uint64) -> None: @@ -726,10 +748,10 @@ def get_l2_agglomerations( if self.mock_edges is None: edges_d = self.read_chunk_edges(chunk_ids) - fake_edges = self.get_fake_edges(chunk_ids) + edited_edges = self.get_edited_edges(chunk_ids) all_chunk_edges = reduce( lambda x, y: x + y, - chain(edges_d.values(), fake_edges.values()), + chain(edges_d.values(), edited_edges.values()), Edges([], []), ) if self.mock_edges is not None: diff --git a/pychunkedgraph/graph/chunks/utils.py b/pychunkedgraph/graph/chunks/utils.py index 5b6d0ae78..0e39fbf9f 100644 --- a/pychunkedgraph/graph/chunks/utils.py +++ b/pychunkedgraph/graph/chunks/utils.py @@ -169,9 +169,7 @@ def _compute_chunk_id( z: int, ) -> np.uint64: s_bits_per_dim = meta.bitmasks[layer] - if not ( - x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim - ): + if not (x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim): raise ValueError( f"Coordinate is out of range \ layer: {layer} bits/dim {s_bits_per_dim}. \ @@ -284,3 +282,53 @@ def get_l2chunkids_along_boundary(cg_meta, mlayer: int, coord_a, coord_b, paddin l2chunk_ids_a = get_chunk_ids_from_coords(cg_meta, 2, l2chunks_a) l2chunk_ids_b = get_chunk_ids_from_coords(cg_meta, 2, l2chunks_b) return l2chunk_ids_a, l2chunk_ids_b + + +def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size) -> dict: + """ + Find octree chunks overlapping with a bounding box in 3D + and return a dictionary mapping chunk indices to clipped bounding boxes. 
+ """ + bbox_min = np.asarray(bbox_min, dtype=int) + bbox_max = np.asarray(bbox_max, dtype=int) + chunk_size = np.asarray(chunk_size, dtype=int) + + start_idx = np.floor_divide(bbox_min, chunk_size).astype(int) + end_idx = np.floor_divide(bbox_max, chunk_size).astype(int) + + ix = np.arange(start_idx[0], end_idx[0] + 1) + iy = np.arange(start_idx[1], end_idx[1] + 1) + iz = np.arange(start_idx[2], end_idx[2] + 1) + grid = np.stack(np.meshgrid(ix, iy, iz, indexing="ij"), axis=-1, dtype=int) + grid = grid.reshape(-1, 3) + + chunk_min = grid * chunk_size + chunk_max = chunk_min + chunk_size + clipped_min = np.maximum(chunk_min, bbox_min) + clipped_max = np.minimum(chunk_max, bbox_max) + return { + tuple(idx): np.stack([cmin, cmax], axis=0, dtype=int) + for idx, cmin, cmax in zip(grid, clipped_min, clipped_max) + } + + +def get_neighbors(coord, inclusive: bool = True, min_coord=None, max_coord=None): + """ + Get all valid coordinates in the 3×3×3 cube around a given chunk, + including the chunk itself (if inclusive=True), + respecting bounding box constraints. 
+ """ + offsets = np.array(np.meshgrid([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])).T.reshape(-1, 3) + if not inclusive: + offsets = offsets[~np.all(offsets == 0, axis=1)] + + neighbors = np.array(coord) + offsets + if min_coord is None: + min_coord = (0, 0, 0) + min_coord = np.array(min_coord) + neighbors = neighbors[(neighbors >= min_coord).all(axis=1)] + + if max_coord is not None: + max_coord = np.array(max_coord) + neighbors = neighbors[(neighbors <= max_coord).all(axis=1)] + return neighbors diff --git a/pychunkedgraph/graph/cutting_sv.py b/pychunkedgraph/graph/cutting_sv.py new file mode 100644 index 000000000..5f9ba58c5 --- /dev/null +++ b/pychunkedgraph/graph/cutting_sv.py @@ -0,0 +1,1284 @@ +from time import perf_counter + +import numpy as np +from typing import Dict, Tuple, Optional, Sequence +from scipy.spatial import cKDTree + + +# EDT backends: prefer Seung-Lab edt, fallback to scipy.ndimage +try: + from edt import edt as _edt_fast + + _HAVE_EDT_FAST = True +except Exception: + _HAVE_EDT_FAST = False + +from scipy import ndimage as ndi +from scipy.spatial import cKDTree +from skimage.graph import MCP_Geometric +from skimage.morphology import ( + ball, +) # keep only ball; use ndi.binary_dilation everywhere + +# ---------- Fast CC wrappers ---------- +try: + import cc3d + + _HAVE_CC3D = True +except Exception: + _HAVE_CC3D = False + from skimage.measure import label as _sk_label + +try: + import fastremap as _fr + + _HAVE_FASTREMAP = True +except Exception: + _HAVE_FASTREMAP = False + + +def _cc_label_26(mask: np.ndarray): + """ + Fast 3D connected components (26-connectivity). + Returns (labels:int32, n_components:int). 
+ """ + if _HAVE_CC3D: + lbl = cc3d.connected_components( + mask.astype(np.uint8, copy=False), connectivity=26, out_dtype=np.uint32 + ) + return lbl, int(lbl.max()) + # Fallback: skimage (connectivity=3 ~ 26-neighborhood) + lbl = _sk_label(mask, connectivity=3).astype(np.int32, copy=False) + return lbl, int(lbl.max()) + + +def _largest_component_id(lbl: np.ndarray): + """ + Return the label ID (>=1) of the largest component in 'lbl'. + lbl should already be a CC label image where 0=background. + """ + if _HAVE_FASTREMAP: + u, counts = _fr.unique(lbl, return_counts=True) + if u.size: + bg = np.where(u == 0)[0] + if bg.size: + counts[bg[0]] = 0 + return int(u[np.argmax(counts)]) + return 0 + cnt = np.bincount(lbl.ravel()) + if cnt.size: + cnt[0] = 0 + return int(np.argmax(cnt)) if cnt.size else 0 + + +# ========================= +# Order / utility helpers +# ========================= +def _to_zyx_sampling(vs, vox_order): + vs = tuple(map(float, vs)) + if vox_order.lower() == "xyz": # (x,y,z) -> (z,y,x) + return (vs[2], vs[1], vs[0]) + if vox_order.lower() == "zyx": + return vs + raise ValueError("vox_order must be 'xyz' or 'zyx'") + + +def _to_internal_zyx_volume(vol, vol_order): + if vol_order.lower() == "zyx": + return vol, False + if vol_order.lower() == "xyz": # (x,y,z) -> (z,y,x) + return np.transpose(vol, (2, 1, 0)), True + raise ValueError("vol_order must be 'xyz' or 'zyx'") + + +def _from_internal_zyx_volume(vol_zyx, vol_order): + if vol_order.lower() == "zyx": + return vol_zyx + if vol_order.lower() == "xyz": # (z,y,x) -> (x,y,z) + return np.transpose(vol_zyx, (2, 1, 0)) + raise ValueError("vol_order must be 'xyz' or 'zyx'") + + +def _seeds_to_zyx(seeds, seed_order): + arr = np.asarray(seeds, dtype=float).reshape(-1, 3) + if seed_order.lower() == "xyz": + arr = arr[:, [2, 1, 0]] # (x,y,z) -> (z,y,x) + elif seed_order.lower() != "zyx": + raise ValueError("seed_order must be 'xyz' or 'zyx'") + return np.round(arr).astype(int) + + +def 
_seeds_from_zyx(seeds_zyx, seed_order): + arr = np.asarray(seeds_zyx, dtype=int).reshape(-1, 3) + if seed_order.lower() == "xyz": + return arr[:, [2, 1, 0]] # (z,y,x) -> (x,y,z) + elif seed_order.lower() == "zyx": + return arr + else: + raise ValueError("seed_order must be 'xyz' or 'zyx'") + + +# ========================= +# Snapping (KDTree-based) +# ========================= +def _extract_mask_boundary(mask, erosion_iters=1): + """ + Extract boundary voxels of a 3D mask using binary erosion. + Boundary = mask & (~eroded(mask)) + + Parameters: + mask : 3D boolean array + erosion_iters : number of erosion iterations (higher removes thicker border) + + Returns: + boundary_mask : 3D boolean array of the same shape + """ + if erosion_iters < 1: + # No erosion => boundary = mask (not recommended unless extremely thin structures) + return mask.copy() + + structure = np.ones((3, 3, 3), dtype=bool) + interior = ndi.binary_erosion( + mask, structure=structure, iterations=erosion_iters, border_value=0 + ) + boundary = mask & (~interior) + return boundary + + +def _downsample_points(points, mode="stride", stride=2, target=None, rng=None): + """ + Downsample a set of points (N,3) by either: + - 'stride': take one every 'stride' points (fast, deterministic), + - 'random': keep ~target points uniformly at random. 
+ + Args: + points : (N, 3) int or float array of coordinates + mode : 'stride' or 'random' + stride : int >= 1 (for 'stride' mode) + target : number of points to keep (for 'random' mode); if None, default is 50k + rng : np.random.Generator for reproducible random sampling + + Returns: + (M, 3) array with M <= N + """ + n = points.shape[0] + if n == 0: + return points + + if mode == "stride": + stride = max(1, int(stride)) + return points[::stride] + + elif mode == "random": + if target is None: + target = min(n, 50_000) # default target + target = max(1, int(target)) + if target >= n: + return points + if rng is None: + rng = np.random.default_rng() + idx = rng.choice(n, size=target, replace=False) + return points[idx] + + else: + raise ValueError("downsample mode must be 'stride' or 'random'") + + +def snap_seeds_to_segment( + seeds_xyz, + mask, + mask_order="zyx", + voxel_size=(1.0, 1.0, 1.0), + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="stride", # 'stride' or 'random' + downsample_stride=2, # used if mode='stride' + downsample_target=None, # used if mode='random' + rng=None, + return_index=False, + leafsize=16, + log=lambda x: None, + tag="snap", + method="kdtree", # accepted for compatibility; only 'kdtree' currently +): + """ + Snap seeds (in XYZ) to the closest True voxel of a 3D mask using cKDTree over + a *reduced* set of candidate voxels: + - boundary-only (mask & ~eroded(mask)), if use_boundary=True + - optionally downsampled (stride or random) + + This approach works well for speed while retaining high accuracy for snapping. + + Parameters: + seeds_xyz : (N,3) float or int array in XYZ order. + mask : 3D boolean array; binary segment. + mask_order : 'zyx' (default) or 'xyz' indicating memory layout of mask. + voxel_size : (vx, vy, vz) in XYZ physical units (e.g., (8.0, 8.0, 40.0)). + use_boundary : If True, only use boundary voxels for KDTree. + erosion_iters : Number of erosion iterations for boundary extraction. 
+ downsample : If True, further reduce boundary points (stride or random). + downsample_mode : 'stride' or 'random' for boundary sampling. + downsample_stride : If stride mode, use every Nth boundary voxel. + downsample_target : If random mode, target number of boundary points to keep. + rng : Optional np.random.Generator for reproducible random sampling. + return_index : If True, also return indices of nearest boundary points. + leafsize : cKDTree leafsize parameter. + log : callable for logging + tag : string to prefix timings + method : currently only 'kdtree' supported. Present for backward compatibility. + + Returns: + snapped_xyz : (N,3) int array in XYZ order, coordinates within volume bounds. + match_idx : (optional) indices into the candidate points array, if return_index=True. + + Notes: + - Seeds outside the volume are supported; they will snap to the nearest segment voxel. + - If use_boundary=True yields no boundary (thin segment), we fall back to the full mask. + - If the mask is empty, we raise ValueError. 
+ """ + t0 = perf_counter() + if method != "kdtree": + log(f"[{tag}] Warning: 'method={method}' not supported; using 'kdtree'.") + + # Validate mask + if mask.ndim != 3: + raise ValueError("mask must be a 3D boolean array") + if mask.dtype != bool: + mask = mask.astype(bool) + + if mask_order not in ("zyx", "xyz"): + raise ValueError("mask_order must be 'zyx' or 'xyz'") + + # Optional boundary extraction for speed + tb = perf_counter() + if use_boundary: + candidate_mask = _extract_mask_boundary(mask, erosion_iters=erosion_iters) + # Fallback to full mask if boundary is empty + if not candidate_mask.any(): + candidate_mask = mask + log(f"[{tag}] boundary empty → fallback to full mask") + else: + candidate_mask = mask + log(f"[{tag}] candidate extraction | {perf_counter()-tb:.3f}s") + + # Obtain candidate voxel coordinates in XYZ order + tc = perf_counter() + if mask_order == "zyx": + # mask shape is (Z, Y, X), np.where -> (z, y, x) + zc, yc, xc = np.where(candidate_mask) + points_xyz = np.stack([xc, yc, zc], axis=1) + max_x, max_y, max_z = mask.shape[2] - 1, mask.shape[1] - 1, mask.shape[0] - 1 + else: + # mask shape is (X, Y, Z), np.where -> (x, y, z) + xc, yc, zc = np.where(candidate_mask) + points_xyz = np.stack([xc, yc, zc], axis=1) + max_x, max_y, max_z = mask.shape[0] - 1, mask.shape[1] - 1, mask.shape[2] - 1 + log( + f"[{tag}] candidate coordinates | {perf_counter()-tc:.3f}s (n={len(points_xyz)})" + ) + + if points_xyz.shape[0] == 0: + raise ValueError( + "The mask (or boundary) contains no True voxels (empty segment)." 
+ ) + + # Optional: further downsample candidate points + td = perf_counter() + if downsample: + before = len(points_xyz) + points_xyz = _downsample_points( + points_xyz, + mode=downsample_mode, + stride=downsample_stride, + target=downsample_target, + rng=rng, + ) + after = len(points_xyz) + log(f"[{tag}] downsample points {before} → {after} | {perf_counter()-td:.3f}s") + + # Prepare seeds array + seeds_xyz = np.asarray(seeds_xyz, dtype=np.float64) + if seeds_xyz.ndim == 1: + seeds_xyz = seeds_xyz[None, :] + if seeds_xyz.shape[1] != 3: + raise ValueError("seeds_xyz must be shape (N, 3)") + + # Scale coordinates to physical space to respect anisotropy + vx, vy, vz = voxel_size + scale = np.array([vx, vy, vz], dtype=np.float64) + + points_scaled = points_xyz * scale[None, :] + seeds_scaled = seeds_xyz * scale[None, :] + + # cKDTree nearest neighbor lookup + te = perf_counter() + tree = cKDTree(points_scaled, leafsize=leafsize) + _, nn_indices = tree.query(seeds_scaled, k=1, workers=-1) + log(f"[{tag}] KDTree build+query | {perf_counter()-te:.3f}s") + + # Map back to integer voxel coords (XYZ) + snapped_xyz = points_xyz[nn_indices].astype(np.int64) + + # Ensure snapped coords are valid (should already be in bounds) + snapped_xyz[:, 0] = np.clip(snapped_xyz[:, 0], 0, max_x) + snapped_xyz[:, 1] = np.clip(snapped_xyz[:, 1], 0, max_y) + snapped_xyz[:, 2] = np.clip(snapped_xyz[:, 2], 0, max_z) + + log(f"[{tag}] snapped {len(seeds_xyz)} seeds | total {perf_counter()-t0:.3f}s") + if return_index: + return snapped_xyz, nn_indices + else: + return snapped_xyz + + +# ============================================================ +# EDT wrapper (Seung-Lab edt preferred, fallback to scipy) +# ============================================================ +def _compute_edt(mask: np.ndarray, sampling_zyx, log=lambda x: None, tag="edt"): + """ + Compute Euclidean distance transform using Seung-Lab edt if available, + otherwise fallback to scipy.ndimage.distance_transform_edt. 
+ + - mask: boolean array in ZYX order + - sampling_zyx: anisotropy tuple in ZYX (float) + """ + t0 = perf_counter() + if _HAVE_EDT_FAST: + dist = _edt_fast(mask.astype(np.uint8, copy=False), anisotropy=sampling_zyx) + log(f"[{tag}] Seung-Lab edt | {perf_counter()-t0:.3f}s") + return dist + else: + dist = ndi.distance_transform_edt(mask, sampling=sampling_zyx) + log(f"[{tag}] SciPy EDT | {perf_counter()-t0:.3f}s") + return dist + + +# ------------------------------------------------------------ +# Helpers for upsampling +# ------------------------------------------------------------ +def _upsample_bool(mask_ds, steps, target_shape): + up = mask_ds.repeat(steps[0], 0).repeat(steps[1], 1).repeat(steps[2], 2) + return up[: target_shape[0], : target_shape[1], : target_shape[2]] + + +def _upsample_labels(lbl_ds, steps, target_shape): + up = lbl_ds.repeat(steps[0], 0).repeat(steps[1], 1).repeat(steps[2], 2) + return up[: target_shape[0], : target_shape[1], : target_shape[2]] + + +# ============================================================ +# Combined connector (ROI + DS + MST paths) — uses snapping + fast EDT +# ============================================================ +def connect_both_seeds_via_ridge( + binary_sv: np.ndarray, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + *, + vol_order: str = "xyz", + vox_order: str = "xyz", + seed_order: str = "xyz", + ridge_power: float = 2.0, + roi_pad_zyx=(24, 48, 48), + downsample=(2, 2, 1), + refine_fullres_when_fail: bool = True, + snap_method: str = "kdtree", + snap_kwargs: dict | None = None, + verbose: bool = True, +): + def log(msg: str): + if verbose: + print(msg, flush=True) + + def _bbox_pad_zyx(points_zyx, shape, pad=(24, 48, 48)): + pts = np.asarray(points_zyx, int) + if pts.size == 0: + return (0, 0, 0, shape[0], shape[1], shape[2]) + z0, y0, x0 = pts.min(0) + z1, y1, x1 = pts.max(0) + 1 + z0 = max(0, z0 - pad[0]) + y0 = max(0, y0 - pad[1]) + x0 = max(0, x0 - pad[2]) + z1 = min(shape[0], z1 + pad[0]) + y1 
= min(shape[1], y1 + pad[1]) + x1 = min(shape[2], x1 + pad[2]) + return (z0, y0, x0, z1, y1, x1) + + def _mst_edges_phys(pts_zyx, sampling): + P = np.asarray(pts_zyx, float) + if len(P) <= 1: + return [] + S = np.array(sampling, float)[None, :] + phys = P * S + n = len(P) + in_tree = np.zeros(n, bool) + in_tree[0] = True + best = np.full(n, np.inf) + parent = np.full(n, -1, int) + d0 = np.sqrt(((phys - phys[0]) ** 2).sum(1)) + best[:] = d0 + best[0] = np.inf + parent[:] = 0 + edges = [] + for _ in range(n - 1): + i = int(np.argmin(best)) + if not np.isfinite(best[i]): + break + edges.append((int(parent[i]), i)) + in_tree[i] = True + best[i] = np.inf + di = np.sqrt(((phys - phys[i]) ** 2).sum(1)) + relax = (~in_tree) & (di < best) + parent[relax] = i + best[relax] = di[relax] + return edges + + t0 = perf_counter() + log( + f"[connect] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}" + ) + log( + f"[connect] mask shape: {binary_sv.shape}, ridge_power={ridge_power}, ds={downsample}" + ) + + sv_zyx, _ = _to_internal_zyx_volume(binary_sv, vol_order) + sampling = _to_zyx_sampling(voxel_size, vox_order) + + # SNAP seeds to mask + A_in_zyx = _seeds_to_zyx(seeds_a, seed_order) + B_in_zyx = _seeds_to_zyx(seeds_b, seed_order) + + # Default snapping config; override via snap_kwargs + snap_cfg = dict( + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="random", + downsample_target=50_000, + method=snap_method, # allow pass-through compatibility + ) + if snap_kwargs is not None: + snap_cfg.update(snap_kwargs) + + def _snap(pts_zyx, name): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + # Convert ZYX -> XYZ for snapper + pts_xyz = pts_zyx[:, [2, 1, 0]] + # Use snapping over full 3D sv_zyx with ZYX mask + snapped_xyz = snap_seeds_to_segment( + pts_xyz, + mask=sv_zyx, + mask_order="zyx", + voxel_size=( + sampling[2], + sampling[1], + sampling[0], + ), # convert ZYX->XYZ spacing + log=log, + tag=f"{name}@snap", + 
**snap_cfg, + ) + # Back to ZYX + return snapped_xyz[:, [2, 1, 0]] + + A_zyx = _snap(A_in_zyx, "A") + B_zyx = _snap(B_in_zyx, "B") + + if len(A_zyx) == 0 or len(B_zyx) == 0: + log("[connect] after snapping, one side has no seeds; skipping connection") + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + (len(A_zyx) > 0), + (len(B_zyx) > 0), + ) + + # ROI for speed + z0, y0, x0, z1, y1, x1 = _bbox_pad_zyx( + np.vstack([A_zyx, B_zyx]), sv_zyx.shape, pad=roi_pad_zyx + ) + roi = sv_zyx[z0:z1, y0:y1, x0:x1] + log(f"[connect] ROI: z[{z0}:{z1}] y[{y0}:{y1}] x[{x0}:{x1}] → shape {roi.shape}") + + # Downsample ROI + sz, sy, sx = map(int, downsample) + ti_ds = perf_counter() + if (sz, sy, sx) != (1, 1, 1): + roi_ds = roi[::sz, ::sy, ::sx] + else: + roi_ds = roi + sampling_ds = (sampling[0] * sz, sampling[1] * sy, sampling[2] * sx) + log( + f"[connect] ROI downsampled {roi.shape} -> {roi_ds.shape} | {perf_counter()-ti_ds:.3f}s" + ) + + # Robust seed placement on the downsampled grid: + # (1) Map to ROI-local coords + # (2) Divide by (sz,sy,sx) to approximate DS coords + # (3) SNAP them to the nearest True voxel in roi_ds using KDTree + def _to_roi_ds_snapped(pts_zyx, name="seedDS"): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + local = np.asarray(pts_zyx, int) - np.array([z0, y0, x0]) # roi-local + seeds_ds = local / np.array( + [sz, sy, sx], dtype=float + ) # DS coordinates (float OK) + # Convert to XYZ for snapper + seeds_ds_xyz = seeds_ds[:, [2, 1, 0]] + try: + snapped_ds_xyz = snap_seeds_to_segment( + seeds_ds_xyz, + mask=roi_ds, + mask_order="zyx", + voxel_size=(sampling_ds[2], sampling_ds[1], sampling_ds[0]), + log=log, + tag=f"{name}@roi_ds", + use_boundary=False, + downsample=False, + method="kdtree", + ) + snapped_ds_zyx = snapped_ds_xyz[:, [2, 1, 0]] + return snapped_ds_zyx.astype(int) + except ValueError as e: + # If roi_ds is empty or degenerate, bail out gracefully: + log( + f"[{name}@roi_ds] snapping failed 
({e}); falling back to nearest-int grid & mask check." + ) + approx = np.floor(seeds_ds + 0.5).astype(int) + Z, Y, X = roi_ds.shape + approx[:, 0] = np.clip(approx[:, 0], 0, Z - 1) + approx[:, 1] = np.clip(approx[:, 1], 0, Y - 1) + approx[:, 2] = np.clip(approx[:, 2], 0, X - 1) + # Keep only those approx coords that are inside mask + valid = [tuple(p) for p in approx if roi_ds[tuple(p)]] + return np.array(valid, dtype=int) + + A_ds = _to_roi_ds_snapped(A_zyx, "A") + B_ds = _to_roi_ds_snapped(B_zyx, "B") + + okA = len(A_ds) >= 1 + okB = len(B_ds) >= 1 + if not (okA and okB): + log( + "[connect] seeds disappeared or failed to map on DS grid; consider smaller ds or use_boundary=False/downsample=False in snapping." + ) + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + okA, + okB, + ) + + # EDT and cost on DS ROI (Seung-Lab edt if available) + t1 = perf_counter() + dist = _compute_edt(roi_ds, sampling_ds, log=log, tag="connect:EDT") + if dist.max() <= 0: + log("[connect] empty EDT in ROI; skipping connection") + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + False, + False, + ) + dn = dist / dist.max() + eps = 1e-6 + cost = np.full_like(dn, 1e12, dtype=float) + cost[roi_ds] = 1.0 / (eps + np.clip(dn[roi_ds], 0, 1) ** max(0.0, ridge_power)) + log(f"[connect] EDT/cost ready on DS-ROI | {perf_counter()-t1:.3f}s") + + # Shortest paths via MST + def _path_mask_ds(start, end): + tmcp = perf_counter() + mcp = MCP_Geometric(cost, sampling=sampling_ds) + costs, _ = mcp.find_costs([tuple(start)], find_all_ends=False) + mid = perf_counter() + v = costs[tuple(end)] + if not np.isfinite(v): + log( + f"[MCP] start={tuple(start)} -> end={tuple(end)} FAILED | setup+run={mid-tmcp:.3f}s" + ) + return None + path = np.asarray(mcp.traceback(tuple(end)), int) + m = np.zeros_like(roi_ds, bool) + m[tuple(path.T)] = True + log( + f"[MCP] start={tuple(start)} -> end={tuple(end)} OK | total={perf_counter()-tmcp:.3f}s" 
+ ) + return m + + def _augment_team_ds(team_name, pts_ds): + if len(pts_ds) <= 1: + return np.zeros_like(roi_ds, bool), True + edges = _mst_edges_phys(pts_ds, sampling_ds) + pmask = np.zeros_like(roi_ds, bool) + ok = True + for i, j in edges: + m = _path_mask_ds(pts_ds[i], pts_ds[j]) + if m is None: + log(f"[connect:{team_name}] DS path FAILED for edge {i}-{j}") + ok = False + if refine_fullres_when_fail: + # fallback full-res EDT and path + tfr = perf_counter() + dist_fr = _compute_edt( + roi, sampling, log=log, tag="connect:EDT(fullres)" + ) + dnm = dist_fr / (dist_fr.max() if dist_fr.max() > 0 else 1.0) + cost_fr = np.full_like(dist_fr, 1e12, dtype=float) + cost_fr[roi] = 1.0 / ( + eps + np.clip(dnm[roi], 0, 1) ** max(0.0, ridge_power) + ) + s = np.array(pts_ds[i]) * np.array([sz, sy, sx]) + e = np.array(pts_ds[j]) * np.array([sz, sy, sx]) + mcp_fr = MCP_Geometric(cost_fr, sampling=sampling) + costs_fr, _ = mcp_fr.find_costs([tuple(s)], find_all_ends=False) + if np.isfinite(costs_fr[tuple(e)]): + path_fr = np.asarray(mcp_fr.traceback(tuple(e)), int) + m_fr = np.zeros_like(roi, bool) + m_fr[tuple(path_fr.T)] = True + m = m_fr[::sz, ::sy, ::sx] + ok = True + log( + f"[connect:{team_name}] fallback full-res path OK | {perf_counter()-tfr:.3f}s" + ) + else: + log( + f"[connect:{team_name}] Full-res ROI path also FAILED for edge {i}-{j}" + ) + m = None + if m is not None: + pmask |= m + return pmask, ok + + t_aug = perf_counter() + pA_ds, okA2 = _augment_team_ds("A", A_ds) + pB_ds, okB2 = _augment_team_ds("B", B_ds) + okA &= okA2 + okB &= okB2 + log(f"[connect] MST+paths built | {perf_counter()-t_aug:.3f}s") + + if not (okA and okB): + log( + "[connect] connection failed for at least one team — consider smaller downsample or refine_fullres_when_fail." 
+ ) + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + okA, + okB, + ) + + # Up-project to full resolution and dilate + pA = _upsample_bool(pA_ds, (sz, sy, sx), roi.shape) & roi + pB = _upsample_bool(pB_ds, (sz, sy, sx), roi.shape) & roi + struc = ball(1) + tpost = perf_counter() + pA = ndi.binary_dilation(pA, structure=struc) & roi + pB = ndi.binary_dilation(pB, structure=struc) & roi + log(f"[connect] postproc dilation on paths | {perf_counter()-tpost:.3f}s") + + A_aug = set(map(tuple, A_zyx)) + B_aug = set(map(tuple, B_zyx)) + Az, Ay, Ax = np.nonzero(pA) + Bz, By, Bx = np.nonzero(pB) + for z, y, x in zip(Az, Ay, Ax): + A_aug.add((z0 + z, y0 + y, x0 + x)) + for z, y, x in zip(Bz, By, Bx): + B_aug.add((z0 + z, y0 + y, x0 + x)) + + A_aug = _seeds_from_zyx(np.array(sorted(list(A_aug)), int), seed_order) + B_aug = _seeds_from_zyx(np.array(sorted(list(B_aug)), int), seed_order) + log( + f"[connect] done; +{len(A_aug)-len(seeds_a)} vox for A, +{len(B_aug)-len(seeds_b)} for B | total {perf_counter()-t0:.3f}s" + ) + return A_aug, B_aug, True, True + + +def split_supervoxel_growing( + binary_sv: np.ndarray, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + *, + # conventions / orders + vol_order: str = "xyz", + vox_order: str = "xyz", + seed_order: str = "xyz", + # geometry / cost + halo: int = 1, + gamma_neck: float = 1.6, # boundary slowdown + k_prox: float = 2.0, # proximity boost strength + lambda_prox: float = 1.0, # proximity decay + narrow_band_rel: float = 0.08, # relative difference threshold + nb_dilate: int = 1, # dilate band to stabilize + # optional: compute TA/TB on a downsampled grid + downsample_geodesic: tuple | None = None, # e.g. 
(1,2,2) + # post-processing / guarantees + allow_third_label: bool = True, + enforce_single_cc: bool = True, + # final validation + check_seeds_same_cc: bool = True, + raise_if_seed_split: bool = True, + raise_if_multi_cc: bool = False, + # snapping control (NEW) + snap_method: str = "kdtree", + snap_kwargs: dict | None = None, + # logging + verbose: bool = True, +): + def log(msg: str): + if verbose: + print(msg, flush=True) + + # Helpers reused from the module: _cc_label_26, _largest_component_id, _to_internal_zyx_volume, _from_internal_zyx_volume + # _seeds_to_zyx, _compute_edt, etc. are assumed available. + + # ---------- helpers ---------- + def _enforce_single_component(out_labels, lab, seed_pts_global, allow3=True): + t = perf_counter() + mask = out_labels == lab + if not np.any(mask): + return 0, 0 + comp, ncomp = _cc_label_26(mask) + if ncomp <= 1: + log(f"[single-cc:{lab}] ncomp=1 | {perf_counter()-t:.3f}s") + return 1, 0 + + keep_ids = set() + for z, y, x in seed_pts_global: + if ( + 0 <= z < out_labels.shape[0] + and 0 <= y < out_labels.shape[1] + and 0 <= x < out_labels.shape[2] + ): + if out_labels[z, y, x] == lab: + cid = comp[z, y, x] + if cid > 0: + keep_ids.add(int(cid)) + + if not keep_ids: + keep_ids = {_largest_component_id(comp)} + + lut = np.zeros(ncomp + 1, dtype=np.bool_) + lut[list(keep_ids)] = True + bad_mask = (comp > 0) & (~lut[comp]) + moved = int(bad_mask.sum()) + if allow3 and moved: + out_labels[bad_mask] = 3 + log( + f"[single-cc:{lab}] kept={len(keep_ids)}, moved_to_3={moved} | {perf_counter()-t:.3f}s" + ) + return len(keep_ids), moved + + def _resolve_label3_touching_vectorized( + out_labels, seedsA=None, seedsB=None, sampling=(1, 1, 1) + ): + t0 = perf_counter() + comp3, n3 = _cc_label_26(out_labels == 3) + n3_vox = int((out_labels == 3).sum()) + log(f"[touching] n3 comps={n3}, vox={n3_vox}") + if n3 == 0: + log(f"[touching] no label-3 components | {perf_counter()-t0:.3f}s") + return 0, 0 + + t1 = perf_counter() + struc = 
np.ones((3, 3, 3), bool) + N1 = ndi.binary_dilation(out_labels == 1, structure=struc) & (comp3 > 0) + N2 = ndi.binary_dilation(out_labels == 2, structure=struc) & (comp3 > 0) + + cnt1 = np.bincount(comp3[N1], minlength=n3 + 1) + cnt2 = np.bincount(comp3[N2], minlength=n3 + 1) + + assign = np.zeros(n3 + 1, dtype=np.int16) # 0=undecided, 1 or 2 otherwise + assign[cnt1 > cnt2] = 1 + assign[cnt2 > cnt1] = 2 + undec = np.where(assign[1:] == 0)[0] + 1 + log( + f"[touching] maj→1={int((assign==1).sum())}, maj→2={int((assign==2).sum())}, ties={len(undec)} | {perf_counter()-t1:.3f}s" + ) + + if ( + len(undec) > 0 + and (seedsA is not None) + and (seedsB is not None) + and len(seedsA) + and len(seedsB) + ): + t2 = perf_counter() + sA = np.zeros_like(out_labels, bool) + sA[tuple(np.array(seedsA).T)] = True + sB = np.zeros_like(out_labels, bool) + sB[tuple(np.array(seedsB).T)] = True + dA = _compute_edt(~sA, sampling, log=log, tag="split:EDT(dA)") + dB = _compute_edt(~sB, sampling, log=log, tag="split:EDT(dB)") + closer2 = (dB < dA) & (comp3 > 0) + + pref2 = np.bincount(comp3[closer2], minlength=n3 + 1) + total = np.bincount(comp3[comp3 > 0], minlength=n3 + 1) + + tie_ids = np.array(undec, dtype=int) + choose2 = pref2[tie_ids] > (total[tie_ids] - pref2[tie_ids]) + assign[tie_ids[choose2]] = 2 + assign[tie_ids[~choose2]] = 1 + log( + f"[touching] tie-break EDT done: to2={int(choose2.sum())}, to1={int((~choose2).sum())} | {perf_counter()-t2:.3f}s" + ) + + moved1 = moved2 = 0 + if (assign == 1).any(): + mask1 = assign[comp3] == 1 + moved1 = int(mask1.sum()) + out_labels[mask1] = 1 + if (assign == 2).any(): + mask2 = assign[comp3] == 2 + moved2 = int(mask2.sum()) + out_labels[mask2] = 2 + + log( + f"[touching] reassigned 3→1: {moved1}, 3→2: {moved2} | total {perf_counter()-t0:.3f}s" + ) + return moved1, moved2 + + # ---------- begin ---------- + t0 = perf_counter() + log(f"[init] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}") + log(f"[init] input volume 
shape: {binary_sv.shape}") + + # Convert input volumes and sampling into internal ZYX + sv_zyx, _ = _to_internal_zyx_volume(binary_sv, vol_order) + sampling = _to_zyx_sampling(voxel_size, vox_order) + log(f"[init] internal shape (z,y,x): {sv_zyx.shape}") + log(f"[init] sampling (z,y,x): {sampling}") + + # SNAP seeds to mask using the same KDTree-based method + A_all = _seeds_to_zyx(seeds_a, seed_order) + B_all = _seeds_to_zyx(seeds_b, seed_order) + log("[snap] snapping seeds to segment mask...") + + snap_cfg = dict( + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="random", + downsample_target=50_000, + method=snap_method, # compatibility key + ) + if snap_kwargs is not None: + snap_cfg.update(snap_kwargs) + + def _snap_ZYX(pts_zyx, tagname): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + # Convert ZYX -> XYZ for snapper + pts_xyz = pts_zyx[:, [2, 1, 0]] + snapped_xyz = snap_seeds_to_segment( + pts_xyz, + mask=sv_zyx, + mask_order="zyx", + voxel_size=( + sampling[2], + sampling[1], + sampling[0], + ), # convert ZYX→XYZ spacing + log=log, + tag=tagname, + **snap_cfg, + ) + return snapped_xyz[:, [2, 1, 0]] + + A = _snap_ZYX(A_all, "A@snap") + B = _snap_ZYX(B_all, "B@snap") + log(f"[seeds] A={len(A)}, B={len(B)}") + + out_zyx = np.zeros_like(sv_zyx, dtype=np.int16) + if A.size == 0 or B.size == 0 or not np.any(sv_zyx): + log("[seeds] missing seeds or empty SV; returning label=1 for entire SV") + out_zyx[sv_zyx] = 1 + return _from_internal_zyx_volume(out_zyx, vol_order) + + # Tight bbox ROI around mask with halo + t_bbox = perf_counter() + Z, Y, X = sv_zyx.shape + coords = np.argwhere(sv_zyx) + z0, y0, x0 = coords.min(0) + z1, y1, x1 = coords.max(0) + 1 + z0h = max(z0 - halo, 0) + y0h = max(y0 - halo, 0) + x0h = max(x0 - halo, 0) + z1h = min(z1 + halo, Z) + y1h = min(y1 + halo, Y) + x1h = min(x1 + halo, X) + sv = sv_zyx[z0h:z1h, y0h:y1h, x0h:x1h] + A_roi = A - np.array([z0h, y0h, x0h]) + B_roi = B - np.array([z0h, y0h, x0h]) 
+ log( + f"[crop] ROI shape (internal): {sv.shape} (halo {halo}) | {perf_counter()-t_bbox:.3f}s" + ) + + # Build travel cost via EDT (Seung-Lab edt if available) + t1 = perf_counter() + dist = _compute_edt(sv, sampling, log=log, tag="split:EDT(mask)") + distn = dist / dist.max() if dist.max() > 0 else dist + eps = 1e-6 + speed = np.clip(distn ** max(gamma_neck, 0.0), eps, 1.0) + travel_cost = np.full_like(speed, 1e12, dtype=float) + travel_cost[sv] = 1.0 / speed[sv] + log( + f"[speed] EDT + speed map | {perf_counter()-t1:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Optional downsample for geodesic + use_ds = downsample_geodesic is not None + if use_ds: + dz, dy, dx = map(int, downsample_geodesic) + log(f"[geodesic] downsample grid: {downsample_geodesic}") + cost_ds = travel_cost[::dz, ::dy, ::dx] + mask_ds = sv[::dz, ::dy, ::dx] + sampling_ds = (sampling[0] * dz, sampling[1] * dy, sampling[2] * dx) + + def _to_ds(pts): + pts = (np.asarray(pts, int) // np.array([dz, dy, dx])).astype(int) + Zs, Ys, Xs = mask_ds.shape + keep = [] + for z, y, x in pts: + if 0 <= z < Zs and 0 <= y < Ys and 0 <= x < Xs and mask_ds[z, y, x]: + keep.append((z, y, x)) + return keep + + A_sub = _to_ds(A_roi) + B_sub = _to_ds(B_roi) + log(f"[geodesic] seeds on DS grid: A={len(A_sub)}, B={len(B_sub)}") + if len(A_sub) == 0 or len(B_sub) == 0: + log("[geodesic] DS removed all seeds; falling back to full-res") + use_ds = False + if not use_ds: + cost_ds = travel_cost + mask_ds = sv + sampling_ds = sampling + A_sub = [tuple(p) for p in A_roi.tolist()] + B_sub = [tuple(p) for p in B_roi.tolist()] + + # Geodesic arrival times + t2 = perf_counter() + mcpA = MCP_Geometric(cost_ds, sampling=sampling_ds) + TA, _ = mcpA.find_costs(A_sub, find_all_ends=False) + mcpB = MCP_Geometric(cost_ds, sampling=sampling_ds) + TB, _ = mcpB.find_costs(B_sub, find_all_ends=False) + TA = np.where(mask_ds, TA, np.inf) + TB = np.where(mask_ds, TB, np.inf) + log( + f"[geodesic] TA/TB computed | 
{perf_counter()-t2:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Narrow band + t3 = perf_counter() + finite = np.isfinite(TA) & np.isfinite(TB) & mask_ds + denom = TA + TB + 1e-12 + reldiff = np.zeros_like(TA) + reldiff[finite] = np.abs(TA[finite] - TB[finite]) / denom[finite] + band = finite & (reldiff <= narrow_band_rel) + if nb_dilate > 0: + band = ndi.binary_dilation(band, structure=ball(nb_dilate)) & mask_ds + if band.sum() < 64: + band = mask_ds.copy() + log("[band] tiny band -> using full ROI on current grid") + log( + f"[band] voxels: {int(band.sum())} | {perf_counter()-t3:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Proximity-boosted labeling + t4 = perf_counter() + denomA = 1.0 + k_prox * np.exp(-lambda_prox * np.clip(TB, 0, np.inf)) + denomB = 1.0 + k_prox * np.exp(-lambda_prox * np.clip(TA, 0, np.inf)) + CA = TA / denomA + CB = TB / denomB + sub_labels_ds = np.zeros_like(mask_ds, dtype=np.int16) + sub_labels_ds[(CA <= CB) & band] = 1 + sub_labels_ds[(CB < CA) & band] = 2 + outer = mask_ds & (sub_labels_ds == 0) + sub_labels_ds[(TA <= TB) & outer] = 1 + sub_labels_ds[(TB < TA) & outer] = 2 + for z, y, x in A_sub: + sub_labels_ds[z, y, x] = 1 + for z, y, x in B_sub: + sub_labels_ds[z, y, x] = 2 + log( + f"[label] DS labeling done | {perf_counter()-t4:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Upsample if needed + if use_ds: + sub_labels = _upsample_labels(sub_labels_ds, (dz, dy, dx), sv.shape) + sub_labels[~sv] = 0 + for z, y, x in A_roi: + sub_labels[z, y, x] = 1 + for z, y, x in B_roi: + sub_labels[z, y, x] = 2 + log(f"[label] upsampled DS→full ROI") + else: + sub_labels = sub_labels_ds + + # Writeback + out_zyx[sv_zyx] = 1 + out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 1] = 1 + out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 2] = 2 + log("[writeback] labels written to full volume") + + # Enforce single CC per label + if enforce_single_cc: + keptA, movedA = _enforce_single_component( + out_zyx, 1, A, allow3=allow_third_label + ) + 
def build_kdtrees_by_label(
    vol: np.ndarray,
    *,
    background: int = 0,
    leafsize: int = 16,
    balanced_tree: bool = True,
    compact_nodes: bool = True,
    min_points: int = 1,
    dtype: np.dtype = np.float32,
) -> Tuple[Dict[int, cKDTree], Dict[int, int]]:
    """
    Build one cKDTree of voxel coordinates per non-background label.

    Parameters
    ----------
    vol : np.ndarray
        3D label volume; axes are treated as (z, y, x).
    background : int, default 0
        Label that is ignored.
    leafsize : int, default 16
        Forwarded to cKDTree.
    balanced_tree : bool, default True
        Forwarded to cKDTree.
    compact_nodes : bool, default True
        Forwarded to cKDTree.
    min_points : int, default 1
        Labels with fewer voxels than this get no tree.
    dtype : np.dtype, default np.float32
        Dtype the coordinate array is cast to before it is handed to cKDTree.

    Returns
    -------
    trees : Dict[int, cKDTree]
        label -> tree over the (z, y, x) voxel coordinates of that label.
    counts : Dict[int, int]
        label -> number of voxels in the corresponding tree.

    Raises
    ------
    ValueError
        If `vol` is not 3-dimensional.

    Notes
    -----
    Voxels are grouped with a single stable sort over the foreground labels,
    so no per-label boolean mask over the whole volume is ever built.
    Labels appear in both dicts in ascending order.
    """
    if vol.ndim != 3:
        raise ValueError("`vol` must be a 3D array.")

    flat_labels = vol.ravel()
    # `flatnonzero(flat)` is the cheap form of `flat != 0`.
    foreground = (
        np.flatnonzero(flat_labels)
        if background == 0
        else np.flatnonzero(flat_labels != background)
    )
    if foreground.size == 0:
        return {}, {}

    voxel_labels = flat_labels[foreground]
    points = np.column_stack(np.unravel_index(foreground, vol.shape)).astype(
        dtype, copy=False
    )

    # Stable sort keeps voxels of one label in their original scan order.
    by_label = np.argsort(voxel_labels, kind="mergesort")
    sorted_labels = voxel_labels[by_label]
    sorted_points = points[by_label]

    # Each run of equal labels in the sorted array is one segment.
    run_starts = np.flatnonzero(np.r_[True, sorted_labels[1:] != sorted_labels[:-1]])
    run_ends = np.r_[run_starts[1:], sorted_labels.size]

    trees: Dict[int, cKDTree] = {}
    counts: Dict[int, int] = {}
    for lo, hi in zip(run_starts, run_ends):
        segment_points = sorted_points[lo:hi]
        size = segment_points.shape[0]
        if size < min_points:
            continue
        label = int(sorted_labels[lo])  # plain int key, also for uint64 labels
        trees[label] = cKDTree(
            segment_points,
            leafsize=leafsize,
            balanced_tree=balanced_tree,
            compact_nodes=compact_nodes,
        )
        counts[label] = size

    return trees, counts
def split_supervoxel_helper(
    binary_seg: np.ndarray,
    source_coords: np.ndarray,
    sink_coords: np.ndarray,
    voxel_size: tuple,
    verbose: bool = False,
):
    """
    Split a binary supervoxel mask between a source and a sink seed team.

    Runs two stages with identical snapping settings:
    1. `connect_both_seeds_via_ridge` augments each seed team.
    2. `split_supervoxel_growing` labels the mask from the augmented seeds.

    All volumes, voxel sizes and seeds are passed in "xyz" order.

    Returns whatever `split_supervoxel_growing` returns.

    Raises
    ------
    RuntimeError
        If the connection stage reports failure for either team.
    """
    voxel_size = np.array(voxel_size)
    # Per-axis factor relative to the coarsest axis.
    downsample = voxel_size.max() // voxel_size

    axis_orders = dict(vol_order="xyz", vox_order="xyz", seed_order="xyz")

    def _snap_settings():
        # A fresh dict per call so the two stages never share a mutable object.
        return dict(
            use_boundary=False,  # disables boundary-only snapping for maximum safety
            downsample=False,  # avoids losing candidates
            method="kdtree",
        )

    # 1) Connect seed teams first
    A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge(
        binary_seg,
        source_coords,
        sink_coords,
        voxel_size=voxel_size,
        downsample=downsample,
        snap_method="kdtree",
        snap_kwargs=_snap_settings(),
        verbose=verbose,
        **axis_orders,
    )
    if not okA or not okB:
        raise RuntimeError(
            "In-mask connection failed for at least one team; skipping split."
        )

    # 2) Run the corridor-free splitter with same snapping settings
    growing_options = dict(
        halo=1,
        gamma_neck=1.6,
        narrow_band_rel=0.08,
        nb_dilate=1,
        downsample_geodesic=(1, 2, 2),
        enforce_single_cc=True,
        raise_if_seed_split=True,
        raise_if_multi_cc=True,
    )
    return split_supervoxel_growing(
        binary_seg,
        A_aug,
        B_aug,
        voxel_size=voxel_size,
        verbose=verbose,
        snap_method="kdtree",
        snap_kwargs=_snap_settings(),
        **axis_orders,
        **growing_options,
    )
+""" + +from functools import reduce +import logging +import multiprocessing as mp +from typing import Callable, Iterable +from datetime import datetime +from collections import defaultdict, deque + +import fastremap +import numpy as np +from tqdm import tqdm +from pychunkedgraph.graph import ChunkedGraph, cache as cache_utils +from pychunkedgraph.graph.attributes import Connectivity +from pychunkedgraph.graph.chunks.utils import chunks_overlapping_bbox, get_neighbors +from pychunkedgraph.graph.cutting_sv import ( + build_kdtrees_by_label, + pairwise_min_distance_two_sets, + split_supervoxel_helper, +) +from pychunkedgraph.graph.attributes import Hierarchy, OperationLogs +from pychunkedgraph.graph.edges import Edges +from pychunkedgraph.graph.types import empty_2d +from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.graph.utils import get_local_segmentation +from pychunkedgraph.graph.utils.serializers import serialize_uint64 +from pychunkedgraph.io.edges import get_chunk_edges + + +def _get_whole_sv( + cg: ChunkedGraph, node: basetypes.NODE_ID, min_coord, max_coord +) -> set: + cx_edges = [empty_2d] + explored_chunks = set() + explored_nodes = set([node]) + queue = deque([node]) + + while len(queue) > 0: + vertex = queue.popleft() + chunk = cg.get_chunk_coordinates(vertex) + chunks = get_neighbors(chunk, min_coord=min_coord, max_coord=max_coord) + + unexplored_chunks = [] + for _chunk in chunks: + if tuple(_chunk) not in explored_chunks: + unexplored_chunks.append(tuple(_chunk)) + + edges = get_chunk_edges(cg.meta.data_source.EDGES, unexplored_chunks) + explored_chunks.update(unexplored_chunks) + _cx_edges = edges["cross"].get_pairs() + cx_edges.append(_cx_edges) + _cx_edges = np.concatenate(cx_edges) + + mask = _cx_edges[:, 0] == vertex + neighbors = _cx_edges[mask][:, 1] + + neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) + min_mask = (neighbor_coords >= min_coord).all(axis=1) + max_mask = (neighbor_coords < max_coord).all(axis=1) 
def _voxel_crop(bbs, bbe, bbs_, bbe_):
    """
    Return the slices that cut the padded box down to the requested box.

    `bbs`/`bbe` are the start/end of the requested box, `bbs_`/`bbe_` the
    start/end of the padded box that was actually read. The result indexes
    an array covering the padded box.

    Each side is cropped by the real offset between the two boxes instead of
    a fixed single voxel. Where the padded box does not extend past the
    requested one, nothing is cropped on that side. This happens when padding
    was clipped away at a volume boundary, or when the requested bound lies
    outside the padded box.

    Returns
    -------
    tuple of slice
        One slice per axis.
    """
    start_offsets = np.asarray(bbs) - np.asarray(bbs_)
    end_offsets = np.asarray(bbe_) - np.asarray(bbe)

    # A negative start would silently wrap around from the end of the axis.
    starts = [max(int(d), 0) for d in start_offsets]
    # `None` keeps everything up to the end; `-d` drops exactly d padded voxels.
    ends = [-int(d) if d > 0 else None for d in end_offsets]

    voxel_overlap_crop = tuple(slice(s, e) for s, e in zip(starts, ends))
    logging.info(f"voxel_overlap_crop: {voxel_overlap_crop}")
    return voxel_overlap_crop
def _get_new_edges(
    edges_info: tuple,
    sv_ids: np.ndarray,
    old_new_map: dict,
    distances: np.ndarray,
    dist_vec: Callable,
    new_dist_vec: Callable,
):
    """
    Build the edges that attach the fragments of every split supervoxel.

    For each `old -> {new ids}` entry of `old_new_map`:
    - Edges from `old` to partners NOT in `sv_ids` ("inactive") are copied to
      every new fragment with their affinity and area.
    - Partners in `sv_ids` ("active") are first remapped through
      `old_new_map` in case they were split themselves. Each one is then
      connected to all fragments whose `distances` entry is below
      `THRESHOLD`, or to the single closest fragment if none is.
    - All fragments of the same old supervoxel are linked pairwise with
      affinity 0.001 and area 0.

    Parameters
    ----------
    edges_info : tuple
        `(edges, affinities, areas)`; `edges` is indexed as an (N, 2) array.
    sv_ids : np.ndarray
        IDs considered active.
    old_new_map : dict
        old supervoxel ID -> iterable of new IDs.
    distances : np.ndarray
        Matrix indexed by `[new_dist_vec(new_ids)][:, dist_vec(partners)]`.
        NOTE(review): units are presumably voxels - confirm against caller.
    dist_vec, new_dist_vec : Callable
        Vectorised ID -> column / row index lookups into `distances`.

    Returns
    -------
    tuple
        `(edges, affinities, areas)` with duplicate edge rows removed.
        `edges` always has shape (M, 2), also when M == 0.
    """
    THRESHOLD = 10
    new_edges, new_affs, new_areas = [], [], []
    edges, affinities, areas = edges_info

    for old, new in old_new_map.items():
        logging.info(f"old and new {old, new}")
        new_ids = np.array(list(new), dtype=basetypes.NODE_ID)
        edges_m = np.any(edges == old, axis=1)
        selected_edges = edges[edges_m]
        sel_m = selected_edges != old
        # every selected edge must have exactly one endpoint that is not `old`
        assert np.all(np.sum(sel_m, axis=1) == 1)

        partners = selected_edges[sel_m]
        active_m = np.isin(partners, sv_ids)

        logging.info(f"sv_ids: {np.sum(sv_ids > 0)}")
        logging.info(f"edges: {edges.shape} {np.sum(edges_m)} {np.sum(sel_m)}")
        logging.info(f"selected_edges: {selected_edges.shape}")

        # Attributes aligned with `partners`; computed once per old supervoxel.
        row_m = np.any(sel_m, axis=1)
        partner_affs = affinities[edges_m][row_m]
        partner_areas = areas[edges_m][row_m]

        # inactive
        inactive_partners = partners[~active_m]
        inactive_affs = partner_affs[~active_m]
        inactive_areas = partner_areas[~active_m]
        for new_id in new_ids:
            _a = [[new_id] * len(inactive_partners), inactive_partners]
            new_edges.extend(np.array(_a, dtype=np.uint64).T)
            new_affs.extend(inactive_affs)
            new_areas.extend(inactive_areas)

        # active
        active_partners_ = partners[active_m]
        active_affs_ = partner_affs[active_m]
        active_areas_ = partner_areas[active_m]

        logging.info(f"partners: {partners.shape} {active_partners_.shape}")

        active_partners = []
        active_affs = []
        active_areas = []
        for partner, aff, area in zip(active_partners_, active_affs_, active_areas_):
            remapped_ = old_new_map.get(partner, [partner])
            active_partners.extend(remapped_)
            active_affs.extend([aff] * len(remapped_))
            active_areas.extend([area] * len(remapped_))

        logging.info(f"new_ids, active_partners: {new_ids, len(active_partners)}")
        # The lookups built by `_update_edges` are `np.vectorize` objects, which
        # raise on size-0 input, so only call them when there is something to map.
        if active_partners:
            logging.info(f"new_dist_vec(new_ids): {new_dist_vec(new_ids)}")
            logging.info(f"dist_vec(active_partners): {dist_vec(active_partners)}")
            distances_ = distances[new_dist_vec(new_ids)][
                :, dist_vec(active_partners)
            ].T
            for i, partner in enumerate(active_partners):
                new_ids_ = new_ids[distances_[i] < THRESHOLD]
                if len(new_ids_):
                    _a = [new_ids_, [partner] * len(new_ids_)]
                    new_edges.extend(np.array(_a, dtype=np.uint64).T)
                    new_affs.extend([active_affs[i]] * len(new_ids_))
                    new_areas.extend([active_areas[i]] * len(new_ids_))
                else:
                    # nothing within threshold: keep the partner attached to
                    # the nearest fragment so the edge is not lost
                    close_new_sv_id = new_ids[np.argmin(distances_[i])]
                    _a = [close_new_sv_id, partner]
                    new_edges.append(np.array(_a, dtype=np.uint64))
                    new_affs.append(active_affs[i])
                    new_areas.append(active_areas[i])

        # edges between split fragments
        for i in range(len(new_ids)):
            for j in range(i + 1, len(new_ids)):  # includes no selfedges
                _a = [new_ids[i], new_ids[j]]
                new_edges.append(np.array(_a, dtype=np.uint64))
                new_affs.append(0.001)
                new_areas.append(0)

    if not new_edges:
        # keep the (M, 2) edge shape callers index into, even with no edges
        return (
            np.empty((0, 2), dtype=basetypes.NODE_ID),
            np.empty(0, dtype=basetypes.EDGE_AFFINITY),
            np.empty(0, dtype=basetypes.EDGE_AREA),
        )

    out_affinities = np.array(new_affs, dtype=basetypes.EDGE_AFFINITY)
    out_areas = np.array(new_areas, dtype=basetypes.EDGE_AREA)
    out_edges = np.array(new_edges, dtype=basetypes.NODE_ID).reshape(-1, 2)
    out_edges, idx = np.unique(out_edges, return_index=True, axis=0)
    return out_edges, out_affinities[idx], out_areas[idx]
cg.get_subgraph(root_id, bbox, bbox_is_coordinate=True) + edges_ = reduce(lambda x, y: x + y, edges_tuple, Edges([], [])) + + edges = edges_.get_pairs() + affinities = edges_.affinities + areas = edges_.areas + + edges = np.sort(edges, axis=1) + _, edges_idx = np.unique(edges, axis=0, return_index=True) + edges_idx = edges_idx[edges[edges_idx, 0] != edges[edges_idx, 1]] + + edges = edges[edges_idx] + affinities = affinities[edges_idx] + areas = areas[edges_idx] + logging.info(f"edges.shape, affinities.shape {edges.shape, affinities.shape}") + + new_ids = np.array(list(set.union(*old_new_map.values())), dtype=basetypes.NODE_ID) + new_kdtrees = [kdtrees[k] for k in new_ids] + new_disance_map = dict(zip(new_ids, np.arange(len(new_ids)))) + new_dist_vec = np.vectorize(new_disance_map.get) + distances = pairwise_min_distance_two_sets(new_kdtrees, list(kdtrees.values())) + return _get_new_edges( + (edges, affinities, areas), + sv_ids, + old_new_map, + distances, + dist_vec, + new_dist_vec, + ) + + +def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = None): + edges_, affinites_, areas_ = edges_tuple + logging.info(f"new edges: {edges_.shape}") + + nodes = fastremap.unique(edges_) + chunks = cg.get_chunk_ids_from_node_ids(cg.get_parents(nodes)) + node_chunks = dict(zip(nodes, chunks)) + + edges = np.r_[edges_, edges_[:, ::-1]] + affinites = np.r_[affinites_, affinites_] + areas = np.r_[areas_, areas_] + + rows = [] + chunks_arr = fastremap.remap(edges, node_chunks) + for chunk_id in np.unique(chunks): + val_dict = {} + mask = chunks_arr[:, 0] == chunk_id + val_dict[Connectivity.SplitEdges] = edges[mask] + val_dict[Connectivity.Affinity] = affinites[mask] + val_dict[Connectivity.Area] = areas[mask] + rows.append( + cg.client.mutate_row( + serialize_uint64(chunk_id, fake_edges=True), + val_dict=val_dict, + time_stamp=time_stamp, + ) + ) + logging.info(f"writing {edges[mask].shape} edges to {chunk_id}") + return rows + + +def split_supervoxel( + 
cg: ChunkedGraph, + sv_id: basetypes.NODE_ID, + source_coords: np.ndarray, + sink_coords: np.ndarray, + operation_id: int, + verbose: bool = True, + time_stamp: datetime = None, +) -> dict[int, set]: + """ + Lookups coordinates of given supervoxel in segmentation. + Finds its counterparts split by chunk boundaries and splits them as a whole. + Updates the segmentation with new IDs. + """ + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + chunk_size = cg.meta.graph_config.CHUNK_SIZE + _coords = np.concatenate([source_coords, sink_coords]) + _padding = np.array([64] * 3) / cg.meta.resolution + + bbs = np.clip((np.min(_coords, 0) - _padding).astype(int), vol_start, vol_end) + bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) + chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) + bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size + logging.info(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}") + logging.info(f"{chunk_size}; {_padding}; {(bbs, bbe)}; {(chunk_min, chunk_max)}") + + cut_supervoxels = _get_whole_sv(cg, sv_id, min_coord=chunk_min, max_coord=chunk_max) + supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) + logging.info(f"{sv_id} -> {cut_supervoxels}") + + # one voxel overlap for neighbors + bbs_ = np.clip(bbs - 1, vol_start, vol_end) + bbe_ = np.clip(bbe + 1, vol_start, vol_end) + seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() + binary_seg = np.isin(seg, supervoxel_ids) + logging.info(f"{seg.shape}; {binary_seg.shape}; {bbs, bbe}; {bbs_, bbe_}") + + voxel_overlap_crop = _voxel_crop(bbs, bbe, bbs_, bbe_) + split_result = split_supervoxel_helper( + binary_seg[voxel_overlap_crop], + source_coords - bbs, + sink_coords - bbs, + cg.meta.resolution, + verbose=verbose, + ) + logging.info(f"split_result: {split_result.shape}") + + chunks_bbox_map = chunks_overlapping_bbox(bbs, bbe, cg.meta.graph_config.CHUNK_SIZE) + tasks = [ + 
(cg.graph_id, *item, seg[voxel_overlap_crop], split_result, bbs) + for item in chunks_bbox_map.items() + ] + logging.info(f"tasks count: {len(tasks)}") + with mp.Pool() as pool: + results = [*tqdm(pool.imap_unordered(_update_chunk, tasks), total=len(tasks))] + seg_cropped = seg[voxel_overlap_crop].copy() + new_seg, old_new_map, slices = _parse_results(results, seg_cropped, bbs, bbe) + + seg_roots = seg.copy() + sv_ids = fastremap.unique(seg) + roots = cg.get_roots(sv_ids) + seg_roots = fastremap.remap(seg_roots, dict(zip(sv_ids, roots)), in_place=True) + + root = cg.get_root(sv_id) + logging.info(f"root {root}") + + seg_masked = seg.copy() + seg_masked[seg_roots != root] = 0 + sv_ids = fastremap.unique(seg_masked) + + seg_masked[voxel_overlap_crop] = new_seg + edges_tuple = _update_edges( + cg, sv_ids, root, np.array([bbs, bbe]), seg_masked, old_new_map + ) + + rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) + rows1 = _add_new_edges(cg, edges_tuple, time_stamp=time_stamp) + rows = rows0 + rows1 + logging.info(f"{operation_id}: writing {len(rows)} new rows") + + cg.client.write(rows) + cg.meta.ws_ocdbt[slices] = new_seg[..., np.newaxis] + return old_new_map, edges_tuple + + +def copy_parents_and_add_lineage( + cg: ChunkedGraph, + operation_id: int, + old_new_map: dict, +) -> list: + """ + Copy parents column from `old_id` to each of `new_ids`. + This makes it easy to get old hierarchy with `new_ids` using an older timestamp. + Link `old_id` and `new_ids` to create a lineage at supervoxel layer. + Replace `old_id` with its `new_ids` in the children of each affected parent. + Returns a list of mutations to be persisted.
+ """ + result = [] + parents = set() + old_new_map = {k: list(v) for k, v in old_new_map.items()} + parent_cells_map = cg.client.read_nodes( + node_ids=list(old_new_map.keys()), properties=Hierarchy.Parent + ) + for old_id, new_ids in old_new_map.items(): + for new_id in new_ids: + val_dict = { + Hierarchy.FormerIdentity: np.array([old_id], dtype=basetypes.NODE_ID), + OperationLogs.OperationID: operation_id, + } + result.append(cg.client.mutate_row(serialize_uint64(new_id), val_dict)) + for cell in parent_cells_map[old_id]: + cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) + parents.add(cell.value) + result.append( + cg.client.mutate_row( + serialize_uint64(new_id), + {Hierarchy.Parent: cell.value}, + time_stamp=cell.timestamp, + ) + ) + val_dict = {Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID)} + result.append(cg.client.mutate_row(serialize_uint64(old_id), val_dict)) + + children_cells_map = cg.client.read_nodes( + node_ids=list(parents), properties=Hierarchy.Child + ) + for parent, children_cells in children_cells_map.items(): + assert len(children_cells) == 1, children_cells + for cell in children_cells: + logging.info(f"{parent}: {cell.value}") + mask = np.isin(cell.value, list(old_new_map.keys())) + replace = np.concatenate([old_new_map[x] for x in cell.value[mask]]) + children = np.concatenate([cell.value[~mask], replace]) + logging.info(f"{parent}: {children}") + cg.cache.children_cache[parent] = children + result.append( + cg.client.mutate_row( + serialize_uint64(parent), + {Hierarchy.Child: children}, + time_stamp=cell.timestamp, + ) + ) + return result diff --git a/pychunkedgraph/graph/types.py b/pychunkedgraph/graph/types.py index f6d4395d9..fb7789cf1 100644 --- a/pychunkedgraph/graph/types.py +++ b/pychunkedgraph/graph/types.py @@ -7,7 +7,8 @@ empty_1d = np.empty(0, dtype=basetypes.NODE_ID) empty_2d = np.empty((0, 2), dtype=basetypes.NODE_ID) - +empty_affinities = np.empty(0, dtype=basetypes.EDGE_AFFINITY) 
+empty_areas = np.empty(0, dtype=basetypes.EDGE_AREA) """ An Agglomeration is syntactic sugar for representing diff --git a/pychunkedgraph/graph/utils/__init__.py b/pychunkedgraph/graph/utils/__init__.py index e69de29bb..c1d56e0fe 100644 --- a/pychunkedgraph/graph/utils/__init__.py +++ b/pychunkedgraph/graph/utils/__init__.py @@ -0,0 +1 @@ +from .generic import get_local_segmentation \ No newline at end of file diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index d48da9cf2..0b5cf5c5c 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -151,3 +151,15 @@ def get_parents_at_timestamp(nodes, parents_ts_map, time_stamp, unique: bool = F except KeyError: skipped_nodes.append(node) return list(parents), skipped_nodes + + + +def get_local_segmentation(meta, bbox_start, bbox_end) -> np.ndarray: + """ + Read the segmentation sliced as `bbox_start:bbox_end` along each of the three axes. + Reads from `meta.ws_ocdbt` when `meta.ocdbt_seg` is set, otherwise from `meta.cv`. + The result is returned as read; callers squeeze it themselves. + """ + result = None + xL, yL, zL = bbox_start + xH, yH, zH = bbox_end + if meta.ocdbt_seg: + result = meta.ws_ocdbt[xL:xH, yL:yH, zL:zH].read().result() + else: + result = meta.cv[xL:xH, yL:yH, zL:zH] + return result diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index 5cbc3c061..43faf2160 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -10,6 +10,7 @@ import numpy as np from pychunkedgraph.graph import basetypes +from .generic import get_local_segmentation from ..meta import ChunkedGraphMeta from ..chunks import utils as chunk_utils @@ -140,10 +141,7 @@ def get_atomic_ids_from_coords( ] ) - local_sv_seg = meta.cv[ - bbox[0, 0] : bbox[1, 0], bbox[0, 1] : bbox[1, 1], bbox[0, 2] : bbox[1, 2] - ].squeeze() - + local_sv_seg = get_local_segmentation(meta, bbox[0], bbox[1]).squeeze() # limit get_roots calls to the relevant areas of the data lower_bs = np.floor( (np.array(coordinates_nm) - max_dist_nm) / np.array(meta.resolution) - bbox[0] diff --git a/pychunkedgraph/meshing/meshgen_utils.py
b/pychunkedgraph/meshing/meshgen_utils.py index 2c150a785..8fbe237c3 100644 --- a/pychunkedgraph/meshing/meshgen_utils.py +++ b/pychunkedgraph/meshing/meshgen_utils.py @@ -1,19 +1,13 @@ import re -import multiprocessing as mp -from time import time -from typing import List -from typing import Dict -from typing import Tuple from typing import Sequence from functools import lru_cache import numpy as np -from cloudvolume import CloudVolume from cloudvolume.lib import Vec -from multiwrapper import multiprocessing_utils as mu from pychunkedgraph.graph.basetypes import NODE_ID # noqa from ..graph.types import empty_1d +from pychunkedgraph.graph.utils import get_local_segmentation def str_to_slice(slice_str: str): @@ -157,9 +151,7 @@ def get_json_info(cg): def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): - cv = CloudVolume(cg.meta.cv.cloudpath, mip=mip, fill_missing=True) mip_diff = mip - cg.meta.cv.mip - mip_chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) / np.array( [2**mip_diff, 2**mip_diff, 1] ) @@ -175,11 +167,5 @@ def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): cg.meta.cv.mip_voxel_offset(mip), cg.meta.cv.mip_voxel_offset(mip) + cg.meta.cv.mip_volume_size(mip), ) - - ws_seg = cv[ - chunk_start[0] : chunk_end[0], - chunk_start[1] : chunk_end[1], - chunk_start[2] : chunk_end[2], - ].squeeze() - + ws_seg = get_local_segmentation(cg.meta, chunk_start, chunk_end).squeeze() return ws_seg diff --git a/requirements.in b/requirements.in index 3343123b1..2d8112537 100644 --- a/requirements.in +++ b/requirements.in @@ -14,6 +14,9 @@ pyyaml cachetools werkzeug tensorstore +edt +connected-components-3d +scikit-image # PyPI only: cloud-files>=6.0.0 From cff32d29fd7811736aab77ebfa19ce29b3a8c8df Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:40:37 +0000 Subject: [PATCH 04/16] feat(sv_split): sv split in frontend --- pychunkedgraph/app/segmentation/common.py | 95 ++++++++++++++++------- 
pychunkedgraph/graph/cutting.py | 20 ++--- pychunkedgraph/graph/exceptions.py | 18 +++++ pychunkedgraph/graph/operation.py | 11 ++- 4 files changed, 105 insertions(+), 39 deletions(-) diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 293b46981..61466dc54 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -9,16 +9,13 @@ import numpy as np import pandas as pd +import fastremap from flask import current_app, g, jsonify, make_response, request from pytz import UTC from pychunkedgraph import __version__ from pychunkedgraph.app import app_utils -from pychunkedgraph.graph import ( - attributes, - cutting, - segmenthistory, -) +from pychunkedgraph.graph import attributes, cutting, segmenthistory, ChunkedGraph from pychunkedgraph.graph import ( edges as cg_edges, ) @@ -26,6 +23,8 @@ exceptions as cg_exceptions, ) from pychunkedgraph.graph.analysis import pathing +from pychunkedgraph.graph.attributes import OperationLogs +from pychunkedgraph.graph.edits_sv import split_supervoxel from pychunkedgraph.graph.misc import get_contact_sites from pychunkedgraph.graph.operation import GraphEditOperation from pychunkedgraph.graph import basetypes @@ -396,7 +395,7 @@ def handle_merge(table_id, allow_same_segment_merge=False): current_app.operation_id = ret.operation_id if ret.new_root_ids is None: raise cg_exceptions.InternalServerError( - "Could not merge selected " "supervoxel." + f"{ret.operation_id}: Could not merge selected supervoxels." 
) current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) @@ -410,24 +409,10 @@ def handle_merge(table_id, allow_same_segment_merge=False): ### SPLIT ---------------------------------------------------------------------- -def handle_split(table_id): - current_app.table_id = table_id - user_id = str(g.auth_user.get("id", current_app.user_id)) - - data = json.loads(request.data) - is_priority = request.args.get("priority", True, type=str2bool) - remesh = request.args.get("remesh", True, type=str2bool) - mincut = request.args.get("mincut", True, type=str2bool) - +def _get_sources_and_sinks(cg: ChunkedGraph, data): current_app.logger.debug(data) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id, skip_cache=True) node_idents = [] - node_ident_map = { - "sources": 0, - "sinks": 1, - } + node_ident_map = {"sources": 0, "sinks": 1} coords = [] node_ids = [] @@ -440,18 +425,74 @@ def handle_split(table_id): node_ids = np.array(node_ids, dtype=np.uint64) coords = np.array(coords) node_idents = np.array(node_idents) + + start = time.time() sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + current_app.logger.info(f"SV lookup took {time.time() - start}s.") current_app.logger.debug( {"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents} ) + source_ids = sv_ids[node_idents == 0] + sink_ids = sv_ids[node_idents == 1] + source_coords = coords[node_idents == 0] + sink_coords = coords[node_idents == 1] + return (source_ids, sink_ids, source_coords, sink_coords) + + +def handle_split(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + data = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + mincut = request.args.get("mincut", True, type=str2bool) + + cg = app_utils.get_cg(table_id, skip_cache=True) + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) try: ret 
= cg.remove_edges( user_id=user_id, - source_ids=sv_ids[node_idents == 0], - sink_ids=sv_ids[node_idents == 1], - source_coords=coords[node_idents == 0], - sink_coords=coords[node_idents == 1], + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, + mincut=mincut, + ) + except cg_exceptions.SupervoxelSplitRequiredError as e: + current_app.logger.info(e) + sources_remapped = fastremap.remap( + sources, + e.sv_remapping, + preserve_missing_labels=True, + in_place=False, + ) + sinks_remapped = fastremap.remap( + sinks, + e.sv_remapping, + preserve_missing_labels=True, + in_place=False, + ) + overlap_mask = np.isin(sources_remapped, sinks_remapped) + for sv_to_split in np.unique(sources_remapped[overlap_mask]): + _mask0 = sources_remapped[sources_remapped == sv_to_split] + _mask1 = sinks_remapped[sinks_remapped == sv_to_split] + split_supervoxel( + cg, + sv_to_split, + source_coords[_mask0], + sink_coords[_mask1], + e.operation_id, + ) + + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + ret = cg.remove_edges( + user_id=user_id, + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, mincut=mincut, do_sanity_check=True, ) @@ -463,7 +504,7 @@ def handle_split(table_id): current_app.operation_id = ret.operation_id if ret.new_root_ids is None: raise cg_exceptions.InternalServerError( - "Could not split selected segment groups." + f"{ret.operation_id}: Could not split selected segment groups." 
) current_app.logger.debug(("after split:", ret.new_root_ids)) diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py index c5c24cf51..bd236397c 100644 --- a/pychunkedgraph/graph/cutting.py +++ b/pychunkedgraph/graph/cutting.py @@ -1,15 +1,11 @@ -import collections import fastremap import numpy as np import itertools -import logging import time import graph_tool import graph_tool.flow -from typing import Dict from typing import Tuple -from typing import Optional from typing import Sequence from typing import Iterable @@ -17,7 +13,7 @@ from pychunkedgraph.graph import basetypes from .utils.generic import get_bounding_box from .edges import Edges -from .exceptions import PreconditionError +from .exceptions import PreconditionError, SupervoxelSplitRequiredError from .exceptions import PostconditionError DEBUG_MODE = False @@ -116,6 +112,10 @@ def __init__( self.cross_chunk_edge_remapping, ) = merge_cross_chunk_edges_graph_tool(cg_edges, cg_affs) + # save this representative mapping for supervoxel splitting + # passed along with SupervoxelSplitRequiredError + self.sv_remapping = dict(complete_mapping) + dt = time.time() - time_start if logger is not None: logger.debug("Cross edge merging: %.2fms" % (dt * 1000)) @@ -233,9 +233,10 @@ def _augment_mincut_capacity(self): self.source_graph_ids, ) except AssertionError: - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Paths between source or sink points irreparably overlap other labels from other side. " - "Check that labels are correct and consider spreading points out farther." 
+ "Check that labels are correct and consider spreading points out farther.", + self.sv_remapping ) paths_e_s_no, paths_e_y_no, do_check = flatgraph.remove_overlapping_edges( @@ -586,11 +587,12 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): # but return a flag to return a message to the user illegal_split = True else: - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Failed to find a cut that separated the sources from the sinks. " "Please try another cut that partitions the sets cleanly if possible. " "If there is a clear path between all the supervoxels in each set, " - "that helps the mincut algorithm." + "that helps the mincut algorithm.", + self.sv_remapping ) except IsolatingCutException as e: if self.split_preview: diff --git a/pychunkedgraph/graph/exceptions.py b/pychunkedgraph/graph/exceptions.py index f41cc2971..496f55e4f 100644 --- a/pychunkedgraph/graph/exceptions.py +++ b/pychunkedgraph/graph/exceptions.py @@ -83,3 +83,21 @@ class GatewayTimeout(ServerError): """Exception mapping a ``504 Gateway Timeout`` response.""" status_code = http_client.GATEWAY_TIMEOUT + + +class SupervoxelSplitRequiredError(ChunkedGraphError): + """ + Raised when supervoxel splitting is necessary. + Edit process should catch this error and retry after supervoxel has been split. + Saves remapping required for detecting which supervoxels need to be split. 
+ """ + + def __init__( + self, + message: str, + sv_remapping: dict, + operation_id: int | None = None, + ): + super().__init__(message) + self.sv_remapping = sv_remapping + self.operation_id = operation_id diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 14e5f7715..5bf221e01 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -29,7 +29,7 @@ from pychunkedgraph.graph import serializers from .cache import CacheService from .cutting import run_multicut -from .exceptions import PreconditionError +from .exceptions import PreconditionError, SupervoxelSplitRequiredError from .exceptions import PostconditionError from .utils.generic import get_bounding_box as get_bbox from pychunkedgraph.graph import get_valid_timestamp @@ -460,6 +460,10 @@ def execute( new_lvl2_ids=new_lvl2_ids, old_root_ids=root_ids, ) + except SupervoxelSplitRequiredError as err: + raise SupervoxelSplitRequiredError( + str(err), err.sv_remapping, operation_id=lock.operation_id + ) from err except PreconditionError as err: self.cg.cache = None raise PreconditionError(err) from err @@ -889,9 +893,10 @@ def __init__( self.disallow_isolating_cut = disallow_isolating_cut self.do_sanity_check = do_sanity_check if np.any(np.isin(self.sink_ids, self.source_ids)): - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Supervoxels exist in both sink and source, " - "try placing the points further apart." 
+ "try placing the points further apart.", + None, ) ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID) From 9ecdb025eb0d8e2fcc8d87cf27aaf24a38e979c8 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 26 Feb 2026 23:38:23 +0000 Subject: [PATCH 05/16] fix(sv_split): update multicut tests, add other tests --- pychunkedgraph/app/segmentation/common.py | 4 +- pychunkedgraph/graph/edits_sv.py | 9 +- .../tests/graph/test_chunks_utils.py | 98 +++ pychunkedgraph/tests/graph/test_exceptions.py | 24 + .../tests/graph/test_graph_build.py | 8 +- pychunkedgraph/tests/graph/test_meta.py | 73 +- pychunkedgraph/tests/graph/test_multicut.py | 2 +- pychunkedgraph/tests/graph/test_operation.py | 8 +- .../tests/graph/test_utils_generic.py | 28 + .../tests/graph/test_utils_id_helpers.py | 1 + pychunkedgraph/tests/test_cutting_sv.py | 729 ++++++++++++++++++ pychunkedgraph/tests/test_edits_sv.py | 220 ++++++ pychunkedgraph/tests/test_ocdbt.py | 141 ++++ 13 files changed, 1329 insertions(+), 16 deletions(-) create mode 100644 pychunkedgraph/tests/test_cutting_sv.py create mode 100644 pychunkedgraph/tests/test_edits_sv.py create mode 100644 pychunkedgraph/tests/test_ocdbt.py diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 61466dc54..0a9c1789f 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -476,8 +476,8 @@ def handle_split(table_id): ) overlap_mask = np.isin(sources_remapped, sinks_remapped) for sv_to_split in np.unique(sources_remapped[overlap_mask]): - _mask0 = sources_remapped[sources_remapped == sv_to_split] - _mask1 = sinks_remapped[sinks_remapped == sv_to_split] + _mask0 = sources_remapped == sv_to_split + _mask1 = sinks_remapped == sv_to_split split_supervoxel( cg, sv_to_split, diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index bb50505b0..4ac3a40f7 100644 --- a/pychunkedgraph/graph/edits_sv.py 
+++ b/pychunkedgraph/graph/edits_sv.py @@ -56,10 +56,11 @@ def _get_whole_sv( mask = _cx_edges[:, 0] == vertex neighbors = _cx_edges[mask][:, 1] - neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) - min_mask = (neighbor_coords >= min_coord).all(axis=1) - max_mask = (neighbor_coords < max_coord).all(axis=1) - neighbors = neighbors[min_mask & max_mask] + if len(neighbors) > 0: + neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) + min_mask = (neighbor_coords >= min_coord).all(axis=1) + max_mask = (neighbor_coords < max_coord).all(axis=1) + neighbors = neighbors[min_mask & max_mask] for neighbor in neighbors: if neighbor in explored_nodes: diff --git a/pychunkedgraph/tests/graph/test_chunks_utils.py b/pychunkedgraph/tests/graph/test_chunks_utils.py index 1d7764561..5ff14e417 100644 --- a/pychunkedgraph/tests/graph/test_chunks_utils.py +++ b/pychunkedgraph/tests/graph/test_chunks_utils.py @@ -125,9 +125,107 @@ def test_none(self, gen_graph): assert chunk_utils.normalize_bounding_box(graph.meta, None, False) is None +class TestChunksOverlappingBbox: + def test_single_chunk(self): + result = chunk_utils.chunks_overlapping_bbox( + bbox_min=[0, 0, 0], bbox_max=[63, 63, 63], chunk_size=[64, 64, 64] + ) + assert (0, 0, 0) in result + assert len(result) == 1 + + def test_multiple_chunks(self): + result = chunk_utils.chunks_overlapping_bbox( + bbox_min=[0, 0, 0], bbox_max=[128, 64, 64], chunk_size=[64, 64, 64] + ) + assert (0, 0, 0) in result + assert (1, 0, 0) in result + assert len(result) >= 2 + + def test_clipping(self): + result = chunk_utils.chunks_overlapping_bbox( + bbox_min=[10, 10, 10], bbox_max=[60, 60, 60], chunk_size=[64, 64, 64] + ) + assert (0, 0, 0) in result + bbox = result[(0, 0, 0)] + np.testing.assert_array_equal(bbox[0], [10, 10, 10]) + np.testing.assert_array_equal(bbox[1], [60, 60, 60]) + + def test_multi_chunk_clipping(self): + result = chunk_utils.chunks_overlapping_bbox( + bbox_min=[30, 0, 0], bbox_max=[100, 64, 64], 
chunk_size=[64, 64, 64] + ) + # chunk (0,0,0): min clipped to 30, max clipped to 64 + assert (0, 0, 0) in result + assert (1, 0, 0) in result + np.testing.assert_array_equal(result[(0, 0, 0)][0], [30, 0, 0]) + + +class TestGetNeighbors: + def test_inclusive(self): + neighbors = chunk_utils.get_neighbors([1, 1, 1], inclusive=True) + # 3^3 = 27 including the center + assert len(neighbors) == 27 + + def test_exclusive(self): + neighbors = chunk_utils.get_neighbors([1, 1, 1], inclusive=False) + assert len(neighbors) == 26 + # Center should not be in neighbors + has_center = any(np.array_equal(n, [1, 1, 1]) for n in neighbors) + assert not has_center + + def test_min_coord_clipping(self): + neighbors = chunk_utils.get_neighbors( + [0, 0, 0], inclusive=True, min_coord=[0, 0, 0] + ) + # Only non-negative coordinates; the 0,0,0 center has offsets going to -1,-1,-1 + for n in neighbors: + assert np.all(n >= 0) + + def test_max_coord_clipping(self): + neighbors = chunk_utils.get_neighbors( + [5, 5, 5], inclusive=True, max_coord=[5, 5, 5] + ) + for n in neighbors: + assert np.all(n <= 5) + + def test_corner_with_bounds(self): + neighbors = chunk_utils.get_neighbors( + [0, 0, 0], inclusive=True, min_coord=[0, 0, 0], max_coord=[2, 2, 2] + ) + # Should only include non-negative neighbors + for n in neighbors: + assert np.all(n >= 0) + assert np.all(n <= 2) + + +class TestGetL2ChunkIdsAlongBoundary: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=5) + coord_a = (0, 0, 0) + coord_b = (1, 0, 0) + chunk_utils.get_l2chunkids_along_boundary.cache_clear() + ids_a, ids_b = chunk_utils.get_l2chunkids_along_boundary( + graph.meta, 3, coord_a, coord_b + ) + assert len(ids_a) > 0 + assert len(ids_b) > 0 + + def test_with_padding(self, gen_graph): + graph = gen_graph(n_layers=5) + coord_a = (0, 0, 0) + coord_b = (1, 0, 0) + chunk_utils.get_l2chunkids_along_boundary.cache_clear() + ids_a, ids_b = chunk_utils.get_l2chunkids_along_boundary( + graph.meta, 3, coord_a, coord_b, 
padding=1 + ) + assert len(ids_a) > 0 + assert len(ids_b) > 0 + + class TestGetBoundingChildrenChunks: def test_basic(self, gen_graph): graph = gen_graph(n_layers=5) + chunk_utils.get_bounding_children_chunks.cache_clear() result = chunk_utils.get_bounding_children_chunks(graph.meta, 3, (0, 0, 0), 2) assert len(result) > 0 assert result.shape[1] == 3 diff --git a/pychunkedgraph/tests/graph/test_exceptions.py b/pychunkedgraph/tests/graph/test_exceptions.py index 82de4c063..1320360f4 100644 --- a/pychunkedgraph/tests/graph/test_exceptions.py +++ b/pychunkedgraph/tests/graph/test_exceptions.py @@ -19,6 +19,7 @@ ServerError, InternalServerError, GatewayTimeout, + SupervoxelSplitRequiredError, ) @@ -69,3 +70,26 @@ def test_internal_server_error(self): def test_gateway_timeout(self): assert GatewayTimeout.status_code == GATEWAY_TIMEOUT + + +class TestSupervoxelSplitRequiredError: + def test_inherits_chunkedgraph_error(self): + assert issubclass(SupervoxelSplitRequiredError, ChunkedGraphError) + + def test_stores_sv_remapping(self): + remap = {1: 10, 2: 20} + err = SupervoxelSplitRequiredError("split needed", remap) + assert err.sv_remapping == remap + assert str(err) == "split needed" + + def test_stores_operation_id(self): + err = SupervoxelSplitRequiredError("msg", {}, operation_id=42) + assert err.operation_id == 42 + + def test_operation_id_default_none(self): + err = SupervoxelSplitRequiredError("msg", {}) + assert err.operation_id is None + + def test_can_be_caught_as_chunkedgraph_error(self): + with pytest.raises(ChunkedGraphError): + raise SupervoxelSplitRequiredError("test", {1: 2}) diff --git a/pychunkedgraph/tests/graph/test_graph_build.py b/pychunkedgraph/tests/graph/test_graph_build.py index 575141abb..e773d1af3 100644 --- a/pychunkedgraph/tests/graph/test_graph_build.py +++ b/pychunkedgraph/tests/graph/test_graph_build.py @@ -49,7 +49,7 @@ def test_build_single_node(self, gen_graph): assert len(children) == 1 and children[0] == to_label(cg, 1, 0, 0, 0, 0) # 
Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 1 + 1 + 1 + 1 + 1 + assert len(res.rows) == 1 + 1 + 2 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_edge(self, gen_graph): @@ -104,7 +104,7 @@ def test_build_single_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 2 + 1 + 1 + 1 + 1 + assert len(res.rows) == 2 + 1 + 2 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_across_edge(self, gen_graph): @@ -212,7 +212,7 @@ def test_build_single_across_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 + assert len(res.rows) == 2 + 2 + 1 + 5 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_edge_and_single_across_edge(self, gen_graph): @@ -326,7 +326,7 @@ def test_build_single_edge_and_single_across_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 + assert len(res.rows) == 3 + 2 + 1 + 5 + 1 + 1 @pytest.mark.timeout(120) def test_build_big_graph(self, gen_graph): diff --git a/pychunkedgraph/tests/graph/test_meta.py b/pychunkedgraph/tests/graph/test_meta.py index f94b7d792..999db2234 100644 --- a/pychunkedgraph/tests/graph/test_meta.py +++ b/pychunkedgraph/tests/graph/test_meta.py @@ -442,7 +442,9 @@ def test_ws_cv_redis_cached(self, mock_get_redis, mock_cv_cls): result = meta.ws_cv assert result is mock_cv_instance - mock_cv_cls.assert_called_once_with("gs://bucket/ws", info=cached_info) + mock_cv_cls.assert_called_once_with( + "gs://bucket/ws", info=cached_info, progress=False + ) @patch("pychunkedgraph.graph.meta.CloudVolume") @patch("pychunkedgraph.graph.meta.get_redis_connection") @@ -462,7 +464,7 @@ def test_ws_cv_redis_failure_fallback(self, 
mock_get_redis, mock_cv_cls): assert result is mock_cv_instance # Should have been called without info kwarg (fallback) - mock_cv_cls.assert_called_with("gs://bucket/ws") + mock_cv_cls.assert_called_with("gs://bucket/ws", progress=False) @patch("pychunkedgraph.graph.meta.CloudVolume") @patch("pychunkedgraph.graph.meta.get_redis_connection") @@ -485,7 +487,7 @@ def test_ws_cv_caches_to_redis(self, mock_get_redis, mock_cv_cls): assert result is mock_cv_instance # The fallback CloudVolume call (no info= kwarg) - mock_cv_cls.assert_called_with("gs://bucket/ws") + mock_cv_cls.assert_called_with("gs://bucket/ws", progress=False) # Should try to cache in redis mock_redis.set.assert_called_once() @@ -568,6 +570,71 @@ def test_bitmasks_cached_after_first_access(self): assert bm1 is bm2 +class TestOcdbtSeg: + """Test ocdbt_seg property.""" + + def test_ocdbt_seg_false_by_default(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + assert meta.ocdbt_seg is False + + def test_ocdbt_seg_true_from_custom_data(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) + assert meta.ocdbt_seg is True + + def test_ocdbt_seg_false_explicit(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": False}}) + assert meta.ocdbt_seg is False + + def test_ocdbt_seg_cached(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) + val1 = meta.ocdbt_seg + val2 = meta.ocdbt_seg + assert val1 is val2 + + def test_ws_ocdbt_asserts_when_not_ocdbt(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + with pytest.raises(AssertionError, match="ocdbt"): + _ = meta.ws_ocdbt + + 
@patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") + def test_ws_ocdbt_returns_destination(self, mock_get_ocdbt): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) + + mock_src = MagicMock() + mock_dst = MagicMock() + mock_get_ocdbt.return_value = (mock_src, mock_dst) + + result = meta.ws_ocdbt + assert result is mock_dst + mock_get_ocdbt.assert_called_once_with("gs://bucket/ws") + + @patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") + def test_ws_ocdbt_cached(self, mock_get_ocdbt): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) + + mock_dst = MagicMock() + mock_get_ocdbt.return_value = (MagicMock(), mock_dst) + + result1 = meta.ws_ocdbt + result2 = meta.ws_ocdbt + assert result1 is result2 + mock_get_ocdbt.assert_called_once() + + class TestLayerChunkBoundsComputed: """Test layer_chunk_bounds property computation.""" diff --git a/pychunkedgraph/tests/graph/test_multicut.py b/pychunkedgraph/tests/graph/test_multicut.py index 19507465e..87408a654 100644 --- a/pychunkedgraph/tests/graph/test_multicut.py +++ b/pychunkedgraph/tests/graph/test_multicut.py @@ -67,5 +67,5 @@ def test_path_augmented_multicut(self, sv_data): cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) assert cut_edges_aug.shape[0] == 350 - with pytest.raises(exceptions.PreconditionError): + with pytest.raises(exceptions.SupervoxelSplitRequiredError): run_multicut(edges, sv_sources, sv_sinks, path_augment=False) diff --git a/pychunkedgraph/tests/graph/test_operation.py b/pychunkedgraph/tests/graph/test_operation.py index db5878842..328ceb425 100644 --- a/pychunkedgraph/tests/graph/test_operation.py +++ b/pychunkedgraph/tests/graph/test_operation.py @@ -21,7 +21,11 @@ 
RedoOperation, UndoOperation, ) -from ...graph.exceptions import PreconditionError, PostconditionError +from ...graph.exceptions import ( + PreconditionError, + PostconditionError, + SupervoxelSplitRequiredError, +) from ...ingest.create.parent_layer import add_parent_chunk @@ -498,7 +502,7 @@ def test_split_self_loop_raises(self, gen_graph): def test_multicut_overlapping_ids_raises(self, gen_graph): """source_ids overlapping sink_ids should raise PreconditionError (line 872).""" cg, _, sv0, sv1 = _build_cross_chunk(gen_graph) - with pytest.raises(PreconditionError, match="both sink and source"): + with pytest.raises(SupervoxelSplitRequiredError, match="both sink and source"): MulticutOperation( cg, user_id="test_user", diff --git a/pychunkedgraph/tests/graph/test_utils_generic.py b/pychunkedgraph/tests/graph/test_utils_generic.py index 58444c838..248e3cb68 100644 --- a/pychunkedgraph/tests/graph/test_utils_generic.py +++ b/pychunkedgraph/tests/graph/test_utils_generic.py @@ -104,6 +104,34 @@ def test_unique(self): assert len(parents) == 1 +class TestGetLocalSegmentation: + def test_ocdbt_path(self): + from unittest.mock import MagicMock + from pychunkedgraph.graph.utils.generic import get_local_segmentation + + meta = MagicMock() + meta.ocdbt_seg = True + expected = np.ones((10, 10, 10), dtype=np.uint64) + mock_slice = MagicMock() + mock_slice.read.return_value.result.return_value = expected + meta.ws_ocdbt.__getitem__ = MagicMock(return_value=mock_slice) + + result = get_local_segmentation(meta, [0, 0, 0], [10, 10, 10]) + np.testing.assert_array_equal(result, expected) + + def test_cv_path(self): + from unittest.mock import MagicMock + from pychunkedgraph.graph.utils.generic import get_local_segmentation + + meta = MagicMock() + meta.ocdbt_seg = False + expected = np.ones((10, 10, 10), dtype=np.uint64) + meta.cv.__getitem__ = MagicMock(return_value=expected) + + result = get_local_segmentation(meta, [0, 0, 0], [10, 10, 10]) + np.testing.assert_array_equal(result, 
expected) + + class TestComputeIndicesPandas: def test_basic(self): data = np.array([1, 2, 1, 2, 3]) diff --git a/pychunkedgraph/tests/graph/test_utils_id_helpers.py b/pychunkedgraph/tests/graph/test_utils_id_helpers.py index f1b78c37e..ab4afa60d 100644 --- a/pychunkedgraph/tests/graph/test_utils_id_helpers.py +++ b/pychunkedgraph/tests/graph/test_utils_id_helpers.py @@ -178,6 +178,7 @@ def test_higher_layer_with_mock_cv(self): meta = MagicMock() meta.data_source.CV_MIP = 0 meta.resolution = np.array([8, 8, 40]) + meta.ocdbt_seg = False parent_id = np.uint64(100) sv1 = np.uint64(10) diff --git a/pychunkedgraph/tests/test_cutting_sv.py b/pychunkedgraph/tests/test_cutting_sv.py new file mode 100644 index 000000000..a2b29ac74 --- /dev/null +++ b/pychunkedgraph/tests/test_cutting_sv.py @@ -0,0 +1,729 @@ +"""Tests for pychunkedgraph.graph.cutting_sv""" + +import numpy as np +import pytest +from scipy.spatial import cKDTree + +from pychunkedgraph.graph.cutting_sv import ( + _cc_label_26, + _largest_component_id, + _to_zyx_sampling, + _to_internal_zyx_volume, + _from_internal_zyx_volume, + _seeds_to_zyx, + _seeds_from_zyx, + _extract_mask_boundary, + _downsample_points, + snap_seeds_to_segment, + _compute_edt, + _upsample_bool, + _upsample_labels, + build_kdtrees_by_label, + pairwise_min_distance_two_sets, + split_supervoxel_growing, + connect_both_seeds_via_ridge, + split_supervoxel_helper, +) + + +# ============================================================ +# Helper: create a simple 3D binary mask with two seed regions +# ============================================================ +def _make_dumbbell_mask(shape=(20, 30, 30)): + """ + Create a dumbbell-shaped mask: two blobs connected by a thin bridge. + Returns (mask, seeds_a_zyx, seeds_b_zyx) all in ZYX order. 
+ """ + mask = np.zeros(shape, dtype=bool) + Z, Y, X = shape + # blob A: centered at (Z//2, Y//4, X//4) + cz, cy, cx = Z // 2, Y // 4, X // 4 + r = min(Z, Y, X) // 5 + for z in range(Z): + for y in range(Y): + for x in range(X): + if (z - cz) ** 2 + (y - cy) ** 2 + (x - cx) ** 2 <= r**2: + mask[z, y, x] = True + + # blob B: centered at (Z//2, 3*Y//4, 3*X//4) + cz2, cy2, cx2 = Z // 2, 3 * Y // 4, 3 * X // 4 + for z in range(Z): + for y in range(Y): + for x in range(X): + if (z - cz2) ** 2 + (y - cy2) ** 2 + (x - cx2) ** 2 <= r**2: + mask[z, y, x] = True + + # bridge between the two + mid_y = Y // 2 + mid_x = X // 2 + mask[cz - 1 : cz + 2, cy : cy2 + 1, mid_x - 1 : mid_x + 2] = True + + seeds_a = np.array([[cz, cy, cx]]) + seeds_b = np.array([[cz2, cy2, cx2]]) + return mask, seeds_a, seeds_b + + +# ============================================================ +# Tests: CC label and largest component +# ============================================================ +class TestCCLabel26: + def test_single_component(self): + mask = np.zeros((5, 5, 5), dtype=bool) + mask[1:4, 1:4, 1:4] = True + lbl, ncomp = _cc_label_26(mask) + assert ncomp == 1 + assert lbl[2, 2, 2] > 0 + assert lbl[0, 0, 0] == 0 + + def test_two_components(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[1:3, 1:3, 1:3] = True + mask[7:9, 7:9, 7:9] = True + lbl, ncomp = _cc_label_26(mask) + assert ncomp == 2 + assert lbl[2, 2, 2] != lbl[7, 7, 7] + + def test_empty(self): + mask = np.zeros((3, 3, 3), dtype=bool) + lbl, ncomp = _cc_label_26(mask) + assert ncomp == 0 + + def test_full(self): + mask = np.ones((4, 4, 4), dtype=bool) + lbl, ncomp = _cc_label_26(mask) + assert ncomp == 1 + + +class TestLargestComponentId: + def test_single_component(self): + lbl = np.zeros((5, 5, 5), dtype=np.int32) + lbl[1:4, 1:4, 1:4] = 1 + assert _largest_component_id(lbl) == 1 + + def test_two_components_picks_largest(self): + lbl = np.zeros((10, 10, 10), dtype=np.int32) + lbl[0:2, 0:2, 0:2] = 1 # 8 voxels + lbl[3:8, 
3:8, 3:8] = 2 # 125 voxels + assert _largest_component_id(lbl) == 2 + + def test_all_background(self): + lbl = np.zeros((3, 3, 3), dtype=np.int32) + assert _largest_component_id(lbl) == 0 + + +# ============================================================ +# Tests: Order/utility helpers +# ============================================================ +class TestToZyxSampling: + def test_xyz_order(self): + result = _to_zyx_sampling((8.0, 8.0, 40.0), "xyz") + assert result == (40.0, 8.0, 8.0) + + def test_zyx_order(self): + result = _to_zyx_sampling((40.0, 8.0, 8.0), "zyx") + assert result == (40.0, 8.0, 8.0) + + def test_invalid_order_raises(self): + with pytest.raises(ValueError, match="vox_order"): + _to_zyx_sampling((1, 1, 1), "abc") + + +class TestToInternalZyxVolume: + def test_zyx_passthrough(self): + vol = np.zeros((3, 4, 5)) + result, transposed = _to_internal_zyx_volume(vol, "zyx") + assert result is vol + assert not transposed + + def test_xyz_transpose(self): + vol = np.zeros((5, 4, 3)) # X=5, Y=4, Z=3 + result, transposed = _to_internal_zyx_volume(vol, "xyz") + assert result.shape == (3, 4, 5) + assert transposed + + def test_invalid_raises(self): + with pytest.raises(ValueError, match="vol_order"): + _to_internal_zyx_volume(np.zeros((3, 3, 3)), "abc") + + +class TestFromInternalZyxVolume: + def test_zyx_passthrough(self): + vol = np.zeros((3, 4, 5)) + result = _from_internal_zyx_volume(vol, "zyx") + assert result is vol + + def test_xyz_transpose(self): + vol = np.zeros((3, 4, 5)) # Z=3, Y=4, X=5 + result = _from_internal_zyx_volume(vol, "xyz") + assert result.shape == (5, 4, 3) + + def test_invalid_raises(self): + with pytest.raises(ValueError, match="vol_order"): + _from_internal_zyx_volume(np.zeros((3, 3, 3)), "abc") + + +class TestSeedsToZyx: + def test_xyz_to_zyx(self): + seeds = np.array([[10, 20, 30]]) # x, y, z + result = _seeds_to_zyx(seeds, "xyz") + np.testing.assert_array_equal(result, [[30, 20, 10]]) + + def test_zyx_passthrough(self): + 
seeds = np.array([[30, 20, 10]]) # z, y, x + result = _seeds_to_zyx(seeds, "zyx") + np.testing.assert_array_equal(result, [[30, 20, 10]]) + + def test_invalid_raises(self): + with pytest.raises(ValueError, match="seed_order"): + _seeds_to_zyx(np.array([[1, 2, 3]]), "abc") + + +class TestSeedsFromZyx: + def test_xyz_output(self): + seeds = np.array([[30, 20, 10]]) # z, y, x + result = _seeds_from_zyx(seeds, "xyz") + np.testing.assert_array_equal(result, [[10, 20, 30]]) + + def test_zyx_passthrough(self): + seeds = np.array([[30, 20, 10]]) + result = _seeds_from_zyx(seeds, "zyx") + np.testing.assert_array_equal(result, [[30, 20, 10]]) + + def test_invalid_raises(self): + with pytest.raises(ValueError, match="seed_order"): + _seeds_from_zyx(np.array([[1, 2, 3]]), "abc") + + def test_roundtrip(self): + original = np.array([[10, 20, 30], [40, 50, 60]]) + zyx = _seeds_to_zyx(original, "xyz") + recovered = _seeds_from_zyx(zyx, "xyz") + np.testing.assert_array_equal(original, recovered) + + +# ============================================================ +# Tests: Snapping (KDTree-based) +# ============================================================ +class TestExtractMaskBoundary: + def test_basic_boundary(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[2:8, 2:8, 2:8] = True + boundary = _extract_mask_boundary(mask, erosion_iters=1) + # Interior should not be boundary + assert boundary[5, 5, 5] == False + # Edge should be boundary + assert boundary[2, 2, 2] == True + # Boundary must be subset of mask + assert np.all(boundary <= mask) + + def test_zero_erosion_returns_copy(self): + mask = np.ones((5, 5, 5), dtype=bool) + result = _extract_mask_boundary(mask, erosion_iters=0) + np.testing.assert_array_equal(result, mask) + + def test_thin_structure_all_boundary(self): + mask = np.zeros((5, 5, 5), dtype=bool) + mask[2, :, :] = True # single slice - all boundary + boundary = _extract_mask_boundary(mask, erosion_iters=1) + # For a single-voxel-thick structure, all 
voxels are boundary + assert boundary.sum() > 0 + + +class TestDownsamplePoints: + def test_stride(self): + pts = np.arange(30).reshape(10, 3) + result = _downsample_points(pts, mode="stride", stride=2) + assert len(result) == 5 + np.testing.assert_array_equal(result[0], pts[0]) + np.testing.assert_array_equal(result[1], pts[2]) + + def test_random(self): + rng = np.random.default_rng(42) + pts = np.arange(300).reshape(100, 3) + result = _downsample_points(pts, mode="random", target=10, rng=rng) + assert len(result) == 10 + + def test_random_target_larger_than_n(self): + pts = np.arange(15).reshape(5, 3) + result = _downsample_points(pts, mode="random", target=50) + assert len(result) == 5 + + def test_empty_returns_empty(self): + pts = np.empty((0, 3)) + result = _downsample_points(pts, mode="stride") + assert len(result) == 0 + + def test_invalid_mode_raises(self): + pts = np.arange(9).reshape(3, 3) + with pytest.raises(ValueError, match="downsample mode"): + _downsample_points(pts, mode="invalid") + + +class TestSnapSeedsToSegment: + def test_basic_snap(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + seeds = np.array([[0.0, 0.0, 0.0]]) # far outside + result = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + use_boundary=False, + downsample=False, + ) + # Snapped seed should be on the mask + x, y, z = result[0] + assert mask[z, y, x] == True + + def test_seed_inside_mask(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + seeds = np.array([[5.0, 5.0, 5.0]]) # inside the mask + result = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + use_boundary=False, + downsample=False, + ) + x, y, z = result[0] + assert mask[z, y, x] == True + + def test_with_boundary_and_downsample(self): + mask = np.zeros((20, 20, 20), dtype=bool) + mask[5:15, 5:15, 5:15] = True + seeds = np.array([[0.0, 0.0, 0.0]]) + result = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + 
use_boundary=True, + downsample=True, + downsample_mode="stride", + downsample_stride=2, + ) + x, y, z = result[0] + assert mask[z, y, x] == True + + def test_xyz_mask_order(self): + # mask_order='xyz' means shape is (X, Y, Z) + mask_xyz = np.zeros((10, 12, 8), dtype=bool) + mask_xyz[3:7, 3:9, 2:6] = True + seeds = np.array([[5.0, 6.0, 4.0]]) # xyz coords + result = snap_seeds_to_segment( + seeds, + mask_xyz, + mask_order="xyz", + use_boundary=False, + downsample=False, + ) + x, y, z = result[0] + assert mask_xyz[x, y, z] == True + + def test_empty_mask_raises(self): + mask = np.zeros((5, 5, 5), dtype=bool) + seeds = np.array([[2.0, 2.0, 2.0]]) + with pytest.raises(ValueError, match="no True voxels"): + snap_seeds_to_segment( + seeds, mask, mask_order="zyx", use_boundary=False, downsample=False + ) + + def test_non_3d_mask_raises(self): + mask = np.zeros((5, 5), dtype=bool) + seeds = np.array([[2.0, 2.0]]) + with pytest.raises(ValueError, match="3D"): + snap_seeds_to_segment(seeds, mask, mask_order="zyx") + + def test_return_index(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[5, 5, 5] = True + seeds = np.array([[0.0, 0.0, 0.0]]) + result, idx = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + use_boundary=False, + downsample=False, + return_index=True, + ) + assert idx.shape[0] == 1 + + def test_multiple_seeds(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[2:8, 2:8, 2:8] = True + seeds = np.array([[0.0, 0.0, 0.0], [9.0, 9.0, 9.0]]) + result = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + use_boundary=False, + downsample=False, + ) + assert result.shape == (2, 3) + for i in range(2): + x, y, z = result[i] + assert mask[z, y, x] == True + + def test_voxel_size_anisotropic(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + seeds = np.array([[0.0, 0.0, 0.0]]) + result = snap_seeds_to_segment( + seeds, + mask, + mask_order="zyx", + voxel_size=(8.0, 8.0, 40.0), + use_boundary=False, + 
downsample=False, + ) + x, y, z = result[0] + assert mask[z, y, x] == True + + def test_invalid_mask_order(self): + mask = np.zeros((5, 5, 5), dtype=bool) + mask[2, 2, 2] = True + with pytest.raises(ValueError, match="mask_order"): + snap_seeds_to_segment( + np.array([[2, 2, 2]]), + mask, + mask_order="bad", + use_boundary=False, + downsample=False, + ) + + +# ============================================================ +# Tests: EDT +# ============================================================ +class TestComputeEdt: + def test_basic(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + dist = _compute_edt(mask, (1.0, 1.0, 1.0)) + assert dist.shape == mask.shape + # Center should have highest distance + assert dist[5, 5, 5] > dist[3, 3, 3] + # Outside mask should be zero + assert dist[0, 0, 0] == 0.0 + + def test_anisotropic_sampling(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + dist = _compute_edt(mask, (40.0, 8.0, 8.0)) + assert dist.shape == mask.shape + assert dist[5, 5, 5] > 0 + + +# ============================================================ +# Tests: Upsampling +# ============================================================ +class TestUpsample: + def test_upsample_bool(self): + mask = np.array([[[True, False], [False, True]]]) # shape (1, 2, 2) + result = _upsample_bool(mask, (2, 2, 2), (2, 4, 4)) + assert result.shape == (2, 4, 4) + assert result[0, 0, 0] == True + assert result[0, 0, 2] == False + + def test_upsample_labels(self): + lbl = np.array([[[1, 2], [3, 0]]]) # shape (1, 2, 2) + result = _upsample_labels(lbl, (2, 2, 2), (2, 4, 4)) + assert result.shape == (2, 4, 4) + assert result[0, 0, 0] == 1 + assert result[0, 0, 2] == 2 + + def test_upsample_with_trimming(self): + mask = np.ones((2, 2, 2), dtype=bool) + result = _upsample_bool(mask, (3, 3, 3), (5, 5, 5)) + assert result.shape == (5, 5, 5) + + +# ============================================================ +# Tests: 
build_kdtrees_by_label +# ============================================================ +class TestBuildKdtreesByLabel: + def test_basic(self): + vol = np.zeros((5, 5, 5), dtype=int) + vol[1, 1, 1] = 1 + vol[3, 3, 3] = 2 + vol[3, 3, 4] = 2 + trees, counts = build_kdtrees_by_label(vol) + assert 1 in trees + assert 2 in trees + assert 0 not in trees + assert counts[1] == 1 + assert counts[2] == 2 + + def test_empty_volume(self): + vol = np.zeros((3, 3, 3), dtype=int) + trees, counts = build_kdtrees_by_label(vol) + assert len(trees) == 0 + assert len(counts) == 0 + + def test_non_zero_background(self): + vol = np.full((5, 5, 5), 99, dtype=int) + vol[2, 2, 2] = 1 + trees, counts = build_kdtrees_by_label(vol, background=99) + assert 1 in trees + assert 99 not in trees + + def test_min_points_filter(self): + vol = np.zeros((5, 5, 5), dtype=int) + vol[1, 1, 1] = 1 # 1 voxel + vol[2:4, 2:4, 2:4] = 2 # 8 voxels + trees, counts = build_kdtrees_by_label(vol, min_points=5) + assert 1 not in trees + assert 2 in trees + + def test_non_3d_raises(self): + vol = np.zeros((5, 5), dtype=int) + with pytest.raises(ValueError, match="3D"): + build_kdtrees_by_label(vol) + + def test_uint64_labels(self): + vol = np.zeros((5, 5, 5), dtype=np.uint64) + vol[1, 1, 1] = np.uint64(2**60) + trees, counts = build_kdtrees_by_label(vol) + assert int(2**60) in trees + + +# ============================================================ +# Tests: pairwise_min_distance_two_sets +# ============================================================ +class TestPairwiseMinDistanceTwoSets: + def _make_tree(self, points): + return cKDTree(np.array(points, dtype=np.float32)) + + def test_basic_exact(self): + tA = self._make_tree([[0, 0, 0]]) + tB = self._make_tree([[3, 4, 0]]) + D = pairwise_min_distance_two_sets([tA], [tB]) + assert D.shape == (1, 1) + assert D[0, 0] == pytest.approx(5.0) + + def test_multiple_trees(self): + tA1 = self._make_tree([[0, 0, 0]]) + tA2 = self._make_tree([[10, 10, 10]]) + tB1 = 
self._make_tree([[1, 0, 0]]) + D = pairwise_min_distance_two_sets([tA1, tA2], [tB1]) + assert D.shape == (2, 1) + assert D[0, 0] < D[1, 0] + + def test_empty_sets(self): + D = pairwise_min_distance_two_sets([], []) + assert D.shape == (0, 0) + + def test_one_empty(self): + tA = self._make_tree([[0, 0, 0]]) + D = pairwise_min_distance_two_sets([tA], []) + assert D.shape == (1, 0) + + def test_cutoff_mode(self): + tA = self._make_tree([[0, 0, 0]]) + tB = self._make_tree([[100, 100, 100]]) + D = pairwise_min_distance_two_sets([tA], [tB], max_distance=5.0) + assert D[0, 0] == np.inf + + def test_cutoff_mode_within_range(self): + tA = self._make_tree([[0, 0, 0]]) + tB = self._make_tree([[1, 0, 0]]) + D = pairwise_min_distance_two_sets([tA], [tB], max_distance=5.0) + assert D[0, 0] == pytest.approx(1.0) + + def test_multi_point_trees(self): + tA = self._make_tree([[0, 0, 0], [10, 10, 10]]) + tB = self._make_tree([[1, 0, 0], [11, 10, 10]]) + D = pairwise_min_distance_two_sets([tA], [tB]) + assert D.shape == (1, 1) + assert D[0, 0] == pytest.approx(1.0) + + def test_asymmetric_tree_sizes(self): + # tA has many points, tB has few + tA = self._make_tree(np.random.default_rng(0).random((100, 3)) * 10) + tB = self._make_tree([[5, 5, 5]]) + D = pairwise_min_distance_two_sets([tA], [tB]) + assert D.shape == (1, 1) + assert D[0, 0] >= 0 + + +# ============================================================ +# Tests: split_supervoxel_growing +# ============================================================ +class TestSplitSupervoxelGrowing: + def test_basic_split_xyz(self): + """Split a dumbbell into two labels.""" + mask, seeds_a_zyx, seeds_b_zyx = _make_dumbbell_mask(shape=(20, 30, 30)) + # Convert to xyz + mask_xyz = np.transpose(mask, (2, 1, 0)) + seeds_a_xyz = seeds_a_zyx[:, [2, 1, 0]] + seeds_b_xyz = seeds_b_zyx[:, [2, 1, 0]] + + result = split_supervoxel_growing( + mask_xyz, + seeds_a_xyz, + seeds_b_xyz, + voxel_size=(1.0, 1.0, 1.0), + vol_order="xyz", + vox_order="xyz", + 
seed_order="xyz", + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + enforce_single_cc=True, + raise_if_multi_cc=False, + ) + assert result.shape == mask_xyz.shape + # Should contain labels 1 and 2 + assert np.any(result == 1) + assert np.any(result == 2) + # Labels should only be where mask is True + assert np.all((result > 0) <= mask_xyz) + + def test_basic_split_zyx(self): + """Split using ZYX order.""" + mask, seeds_a, seeds_b = _make_dumbbell_mask(shape=(20, 30, 30)) + result = split_supervoxel_growing( + mask, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + vol_order="zyx", + vox_order="zyx", + seed_order="zyx", + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + enforce_single_cc=True, + raise_if_multi_cc=False, + ) + assert result.shape == mask.shape + assert np.any(result == 1) + assert np.any(result == 2) + + def test_empty_seeds_returns_label1(self): + """With no seeds on one side, the entire mask gets label 1.""" + mask = np.zeros((10, 10, 10), dtype=bool) + mask[3:7, 3:7, 3:7] = True + seeds_a = np.array([[5, 5, 5]]) + seeds_b = np.empty((0, 3), dtype=int) + result = split_supervoxel_growing( + mask, + seeds_a, + seeds_b, + vol_order="zyx", + vox_order="zyx", + seed_order="zyx", + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + ) + assert np.all(result[mask] == 1) + + def test_with_downsample_geodesic(self): + """Test downsampled geodesic computation.""" + mask, seeds_a, seeds_b = _make_dumbbell_mask(shape=(20, 30, 30)) + result = split_supervoxel_growing( + mask, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + vol_order="zyx", + vox_order="zyx", + seed_order="zyx", + downsample_geodesic=(1, 2, 2), + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + enforce_single_cc=True, + raise_if_multi_cc=False, + ) + assert result.shape == mask.shape + assert np.any(result == 1) + assert np.any(result == 2) + + +# 
============================================================ +# Tests: connect_both_seeds_via_ridge +# ============================================================ +class TestConnectBothSeedsViaRidge: + def test_basic_connection(self): + mask, seeds_a_zyx, seeds_b_zyx = _make_dumbbell_mask(shape=(20, 30, 30)) + mask_xyz = np.transpose(mask, (2, 1, 0)) + seeds_a_xyz = seeds_a_zyx[:, [2, 1, 0]] + seeds_b_xyz = seeds_b_zyx[:, [2, 1, 0]] + + A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( + mask_xyz, + seeds_a_xyz, + seeds_b_xyz, + voxel_size=(1.0, 1.0, 1.0), + vol_order="xyz", + vox_order="xyz", + seed_order="xyz", + downsample=(1, 1, 1), + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + ) + assert okA + assert okB + # Augmented seeds should be at least as many as originals + assert len(A_aug) >= len(seeds_a_xyz) + assert len(B_aug) >= len(seeds_b_xyz) + + def test_single_seed_per_team(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[2:8, 2:8, 2:8] = True + mask_xyz = np.transpose(mask, (2, 1, 0)) + seeds_a = np.array([[4, 4, 4]]) + seeds_b = np.array([[6, 6, 6]]) + + A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( + mask_xyz, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + vol_order="xyz", + seed_order="xyz", + downsample=(1, 1, 1), + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + ) + assert okA + assert okB + + def test_empty_seeds(self): + mask = np.zeros((10, 10, 10), dtype=bool) + mask[2:8, 2:8, 2:8] = True + mask_xyz = np.transpose(mask, (2, 1, 0)) + seeds_a = np.empty((0, 3), dtype=int) + seeds_b = np.array([[4, 4, 4]]) + + A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( + mask_xyz, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + vol_order="xyz", + seed_order="xyz", + downsample=(1, 1, 1), + verbose=False, + snap_kwargs=dict(use_boundary=False, downsample=False), + ) + assert not okA + + +# ============================================================ +# Tests: 
split_supervoxel_helper +# ============================================================ +class TestSplitSupervoxelHelper: + def test_basic_split(self): + mask, seeds_a_zyx, seeds_b_zyx = _make_dumbbell_mask(shape=(20, 30, 30)) + mask_xyz = np.transpose(mask, (2, 1, 0)) + seeds_a_xyz = seeds_a_zyx[:, [2, 1, 0]] + seeds_b_xyz = seeds_b_zyx[:, [2, 1, 0]] + + result = split_supervoxel_helper( + mask_xyz, + seeds_a_xyz, + seeds_b_xyz, + voxel_size=(1.0, 1.0, 1.0), + verbose=False, + ) + assert result.shape == mask_xyz.shape + assert np.any(result == 1) + assert np.any(result == 2) diff --git a/pychunkedgraph/tests/test_edits_sv.py b/pychunkedgraph/tests/test_edits_sv.py new file mode 100644 index 000000000..861fa8baf --- /dev/null +++ b/pychunkedgraph/tests/test_edits_sv.py @@ -0,0 +1,220 @@ +"""Tests for pychunkedgraph.graph.edits_sv""" + +import numpy as np +import pytest +from collections import defaultdict +from unittest.mock import MagicMock, patch + +from pychunkedgraph.graph.edits_sv import ( + _voxel_crop, + _parse_results, + _get_new_edges, +) +from pychunkedgraph.graph.utils import basetypes + + +# ============================================================ +# Tests: _voxel_crop +# ============================================================ +class TestVoxelCrop: + def test_no_overlap(self): + bbs = np.array([10, 20, 30]) + bbe = np.array([20, 30, 40]) + bbs_ = np.array([10, 20, 30]) + bbe_ = np.array([20, 30, 40]) + crop = _voxel_crop(bbs, bbe, bbs_, bbe_) + # No offset and no clipping + assert crop == np.s_[0:None, 0:None, 0:None] + + def test_with_padding(self): + bbs = np.array([10, 20, 30]) + bbe = np.array([20, 30, 40]) + bbs_ = np.array([9, 19, 29]) + bbe_ = np.array([21, 31, 41]) + crop = _voxel_crop(bbs, bbe, bbs_, bbe_) + # Start offset = bbs - bbs_ = (1, 1, 1) + # End offset: bbe_ - bbe = (1,1,1) != 0, so end = -1 + assert crop == np.s_[1:-1, 1:-1, 1:-1] + + def test_partial_padding(self): + bbs = np.array([10, 20, 30]) + bbe = np.array([20, 30, 
40]) + bbs_ = np.array([9, 20, 30]) + bbe_ = np.array([21, 30, 40]) + crop = _voxel_crop(bbs, bbe, bbs_, bbe_) + # Only x has offset + assert crop == np.s_[1:-1, 0:None, 0:None] + + +# ============================================================ +# Tests: _parse_results +# ============================================================ +class TestParseResults: + def test_basic(self): + seg = np.array([[[100, 100], [100, 200]]], dtype=basetypes.NODE_ID) + bbs = np.array([0, 0, 0]) + bbe = np.array([1, 2, 2]) + # result: (indices, old_values, new_values) + indices = np.array([[0, 0, 0], [0, 0, 1]]) + old_values = np.array([100, 100], dtype=basetypes.NODE_ID) + new_values = np.array([300, 301], dtype=basetypes.NODE_ID) + results = [(indices, old_values, new_values)] + + updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) + assert updated_seg[0, 0, 0] == 300 + assert updated_seg[0, 0, 1] == 301 + assert 300 in old_new_map[100] + assert 301 in old_new_map[100] + + def test_none_result_skipped(self): + seg = np.array([[[100]]], dtype=basetypes.NODE_ID) + bbs = np.array([0, 0, 0]) + bbe = np.array([1, 1, 1]) + results = [None] + updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) + # No changes + assert updated_seg[0, 0, 0] == 100 + assert len(old_new_map) == 0 + + def test_multiple_results(self): + seg = np.array([[[100, 200]]], dtype=basetypes.NODE_ID) + bbs = np.array([0, 0, 0]) + bbe = np.array([1, 1, 2]) + result1 = ( + np.array([[0, 0, 0]]), + np.array([100], dtype=basetypes.NODE_ID), + np.array([300], dtype=basetypes.NODE_ID), + ) + result2 = ( + np.array([[0, 0, 1]]), + np.array([200], dtype=basetypes.NODE_ID), + np.array([400], dtype=basetypes.NODE_ID), + ) + results = [result1, result2] + + updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) + assert updated_seg[0, 0, 0] == 300 + assert updated_seg[0, 0, 1] == 400 + assert 300 in old_new_map[100] + assert 400 in old_new_map[200] + + +# 
============================================================ +# Tests: _get_new_edges +# ============================================================ +class TestGetNewEdges: + def test_with_active_and_inactive_partners(self): + """Test with both active partners (in sv_ids) and inactive (not in sv_ids).""" + old_sv = np.uint64(10) + new_sv1 = np.uint64(101) + new_sv2 = np.uint64(102) + active_partner = np.uint64(50) # in sv_ids -> active + inactive_partner = np.uint64(99) # not in sv_ids -> inactive + + edges = np.array( + [ + [10, 50], + [10, 99], + ], + dtype=basetypes.NODE_ID, + ) + affinities = np.array([0.9, 0.5], dtype=basetypes.EDGE_AFFINITY) + areas = np.array([100, 200], dtype=basetypes.EDGE_AREA) + + old_new_map = {old_sv: {new_sv1, new_sv2}} + sv_ids = np.array([10, 50, 101, 102], dtype=basetypes.NODE_ID) + + # distance_map: maps each label to its column index in the distance matrix + distance_map = { + np.uint64(10): 0, + np.uint64(50): 1, + np.uint64(101): 2, + np.uint64(102): 3, + } + dist_vec = np.vectorize(distance_map.get) + new_distance_map = {np.uint64(101): 0, np.uint64(102): 1} + new_dist_vec = np.vectorize(new_distance_map.get) + + # Distances: (new_ids x all_ids) + distances = np.array( + [ + [5.0, 3.0, 0.0, 8.0], # new_sv1 + [6.0, 4.0, 8.0, 0.0], # new_sv2 + ] + ) + + result_edges, result_affs, result_areas = _get_new_edges( + (edges, affinities, areas), + sv_ids, + old_new_map, + distances, + dist_vec, + new_dist_vec, + ) + # Should have: + # - Inactive edges: new_sv1->99, new_sv2->99 + # - Active edges: new_ids -> 50 based on distance + # - Fragment edges: new_sv1 <-> new_sv2 + assert len(result_edges) >= 3 + + def test_edge_between_split_fragments(self): + """Split fragments should have edges between them with low affinity.""" + old_sv = np.uint64(10) + new_sv1 = np.uint64(101) + new_sv2 = np.uint64(102) + partner = np.uint64(50) + + edges = np.array([[10, 50]], dtype=basetypes.NODE_ID) + affinities = np.array([0.9], 
dtype=basetypes.EDGE_AFFINITY) + areas = np.array([100], dtype=basetypes.EDGE_AREA) + + old_new_map = {old_sv: {new_sv1, new_sv2}} + sv_ids = np.array([10, 50, 101, 102], dtype=basetypes.NODE_ID) + + distance_map = { + np.uint64(10): 0, + np.uint64(50): 1, + np.uint64(101): 2, + np.uint64(102): 3, + } + dist_vec = np.vectorize(distance_map.get) + new_distance_map = {np.uint64(101): 0, np.uint64(102): 1} + new_dist_vec = np.vectorize(new_distance_map.get) + distances = np.array( + [ + [5.0, 3.0, 0.0, 8.0], + [6.0, 4.0, 8.0, 0.0], + ] + ) + + result_edges, result_affs, result_areas = _get_new_edges( + (edges, affinities, areas), + sv_ids, + old_new_map, + distances, + dist_vec, + new_dist_vec, + ) + # Check that a fragment-to-fragment edge exists + fragment_edge_found = False + for e in result_edges: + if set(e) == {new_sv1, new_sv2}: + fragment_edge_found = True + break + assert fragment_edge_found + + def test_empty_old_new_map(self): + """Empty old_new_map should return empty results.""" + edges = np.array([[10, 50]], dtype=basetypes.NODE_ID) + affinities = np.array([0.9], dtype=basetypes.EDGE_AFFINITY) + areas = np.array([100], dtype=basetypes.EDGE_AREA) + + result_edges, result_affs, result_areas = _get_new_edges( + (edges, affinities, areas), + np.array([10], dtype=basetypes.NODE_ID), + {}, + np.zeros((0, 0)), + np.vectorize(lambda x: x), + np.vectorize(lambda x: x), + ) + assert len(result_edges) == 0 diff --git a/pychunkedgraph/tests/test_ocdbt.py b/pychunkedgraph/tests/test_ocdbt.py new file mode 100644 index 000000000..a554e23b1 --- /dev/null +++ b/pychunkedgraph/tests/test_ocdbt.py @@ -0,0 +1,141 @@ +"""Tests for pychunkedgraph.graph.ocdbt""" + +import numpy as np +import pytest +from unittest.mock import MagicMock, patch + + +class TestGetSegSourceAndDestinationOcdbt: + @patch("pychunkedgraph.graph.ocdbt.ts") + def test_returns_src_dst_tuple(self, mock_ts): + from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt + + mock_src = 
MagicMock() + mock_schema = MagicMock() + mock_schema.rank = 4 + mock_schema.dtype = "uint64" + mock_schema.codec = None + mock_schema.domain = None + mock_schema.shape = [256, 256, 256, 1] + mock_schema.chunk_layout = None + mock_schema.dimension_units = None + mock_src.schema = mock_schema + + mock_dst = MagicMock() + + # ts.open returns a future-like with .result() + mock_ts.open.side_effect = [ + MagicMock(result=MagicMock(return_value=mock_src)), + MagicMock(result=MagicMock(return_value=mock_dst)), + ] + + src, dst = get_seg_source_and_destination_ocdbt("gs://bucket/ws") + assert src is mock_src + assert dst is mock_dst + assert mock_ts.open.call_count == 2 + + @patch("pychunkedgraph.graph.ocdbt.ts") + def test_create_flag(self, mock_ts): + from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt + + mock_src = MagicMock() + mock_schema = MagicMock() + mock_schema.rank = 4 + mock_schema.dtype = "uint64" + mock_schema.codec = None + mock_schema.domain = None + mock_schema.shape = [256, 256, 256, 1] + mock_schema.chunk_layout = None + mock_schema.dimension_units = None + mock_src.schema = mock_schema + + mock_dst = MagicMock() + mock_ts.open.side_effect = [ + MagicMock(result=MagicMock(return_value=mock_src)), + MagicMock(result=MagicMock(return_value=mock_dst)), + ] + + src, dst = get_seg_source_and_destination_ocdbt("gs://bucket/ws", create=True) + + # Second ts.open call should have create=True and delete_existing=True + second_call = mock_ts.open.call_args_list[1] + assert second_call.kwargs.get("create") == True + assert second_call.kwargs.get("delete_existing") == True + + +class TestCopyWsChunk: + def test_basic_copy(self): + from pychunkedgraph.graph.ocdbt import copy_ws_chunk + + mock_source = MagicMock() + mock_destination = MagicMock() + + # Simulate source read + data = np.ones((64, 64, 64), dtype=np.uint64) + mock_source.__getitem__ = MagicMock( + return_value=MagicMock( + read=MagicMock( + 
return_value=MagicMock(result=MagicMock(return_value=data)) + ) + ) + ) + mock_destination.__getitem__ = MagicMock( + return_value=MagicMock( + write=MagicMock( + return_value=MagicMock(result=MagicMock(return_value=None)) + ) + ) + ) + + voxel_bounds = np.array([[0, 256], [0, 256], [0, 256]]) + copy_ws_chunk( + mock_source, + mock_destination, + chunk_size=(64, 64, 64), + coords=[0, 0, 0], + voxel_bounds=voxel_bounds, + ) + # Should have read from source and written to destination + mock_source.__getitem__.assert_called_once() + mock_destination.__getitem__.assert_called_once() + + def test_boundary_clipping(self): + from pychunkedgraph.graph.ocdbt import copy_ws_chunk + + mock_source = MagicMock() + mock_destination = MagicMock() + + data = np.ones((32, 64, 64), dtype=np.uint64) + mock_source.__getitem__ = MagicMock( + return_value=MagicMock( + read=MagicMock( + return_value=MagicMock(result=MagicMock(return_value=data)) + ) + ) + ) + mock_destination.__getitem__ = MagicMock( + return_value=MagicMock( + write=MagicMock( + return_value=MagicMock(result=MagicMock(return_value=None)) + ) + ) + ) + + # Volume ends at 224 in x, so last chunk (192-256) is clipped to (192-224) + voxel_bounds = np.array([[0, 224], [0, 256], [0, 256]]) + copy_ws_chunk( + mock_source, + mock_destination, + chunk_size=(64, 64, 64), + coords=[3, 0, 0], + voxel_bounds=voxel_bounds, + ) + mock_source.__getitem__.assert_called_once() + + +class TestOcdbtConstants: + def test_compression_level(self): + from pychunkedgraph.graph.ocdbt import OCDBT_SEG_COMPRESSION_LEVEL + + assert OCDBT_SEG_COMPRESSION_LEVEL == 17 + assert isinstance(OCDBT_SEG_COMPRESSION_LEVEL, int) From 286912a6f8eeeffd29cfafd71ea31089870acba0 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 4 Mar 2026 18:13:17 +0000 Subject: [PATCH 06/16] use kvdbclient, organize tests --- pychunkedgraph/graph/edits_sv.py | 18 +++++++++++------- .../tests/{ => graph}/test_cutting_sv.py | 0 .../tests/{ => graph}/test_edits_sv.py | 2 
+- pychunkedgraph/tests/{ => graph}/test_ocdbt.py | 0 4 files changed, 12 insertions(+), 8 deletions(-) rename pychunkedgraph/tests/{ => graph}/test_cutting_sv.py (100%) rename pychunkedgraph/tests/{ => graph}/test_edits_sv.py (99%) rename pychunkedgraph/tests/{ => graph}/test_ocdbt.py (100%) diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index 4ac3a40f7..3b5395b05 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -23,9 +23,9 @@ from pychunkedgraph.graph.attributes import Hierarchy, OperationLogs from pychunkedgraph.graph.edges import Edges from pychunkedgraph.graph.types import empty_2d -from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.graph import basetypes +from pychunkedgraph.graph import serializers from pychunkedgraph.graph.utils import get_local_segmentation -from pychunkedgraph.graph.utils.serializers import serialize_uint64 from pychunkedgraph.io.edges import get_chunk_edges @@ -286,7 +286,7 @@ def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = val_dict[Connectivity.Area] = areas[mask] rows.append( cg.client.mutate_row( - serialize_uint64(chunk_id, fake_edges=True), + serializers.serialize_uint64(chunk_id, fake_edges=True), val_dict=val_dict, time_stamp=time_stamp, ) @@ -404,19 +404,23 @@ def copy_parents_and_add_lineage( Hierarchy.FormerIdentity: np.array([old_id], dtype=basetypes.NODE_ID), OperationLogs.OperationID: operation_id, } - result.append(cg.client.mutate_row(serialize_uint64(new_id), val_dict)) + result.append( + cg.client.mutate_row(serializers.serialize_uint64(new_id), val_dict) + ) for cell in parent_cells_map[old_id]: cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) parents.add(cell.value) result.append( cg.client.mutate_row( - serialize_uint64(new_id), + serializers.serialize_uint64(new_id), {Hierarchy.Parent: cell.value}, time_stamp=cell.timestamp, ) ) val_dict = {Hierarchy.NewIdentity: 
np.array(new_ids, dtype=basetypes.NODE_ID)} - result.append(cg.client.mutate_row(serialize_uint64(old_id), val_dict)) + result.append( + cg.client.mutate_row(serializers.serialize_uint64(old_id), val_dict) + ) children_cells_map = cg.client.read_nodes( node_ids=list(parents), properties=Hierarchy.Child @@ -432,7 +436,7 @@ def copy_parents_and_add_lineage( cg.cache.children_cache[parent] = children result.append( cg.client.mutate_row( - serialize_uint64(parent), + serializers.serialize_uint64(parent), {Hierarchy.Child: children}, time_stamp=cell.timestamp, ) diff --git a/pychunkedgraph/tests/test_cutting_sv.py b/pychunkedgraph/tests/graph/test_cutting_sv.py similarity index 100% rename from pychunkedgraph/tests/test_cutting_sv.py rename to pychunkedgraph/tests/graph/test_cutting_sv.py diff --git a/pychunkedgraph/tests/test_edits_sv.py b/pychunkedgraph/tests/graph/test_edits_sv.py similarity index 99% rename from pychunkedgraph/tests/test_edits_sv.py rename to pychunkedgraph/tests/graph/test_edits_sv.py index 861fa8baf..ced51e68f 100644 --- a/pychunkedgraph/tests/test_edits_sv.py +++ b/pychunkedgraph/tests/graph/test_edits_sv.py @@ -10,7 +10,7 @@ _parse_results, _get_new_edges, ) -from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.graph import basetypes # ============================================================ diff --git a/pychunkedgraph/tests/test_ocdbt.py b/pychunkedgraph/tests/graph/test_ocdbt.py similarity index 100% rename from pychunkedgraph/tests/test_ocdbt.py rename to pychunkedgraph/tests/graph/test_ocdbt.py From 2081e4dbbf211b9fd20cbe1b074ad210b020e310 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 4 Mar 2026 19:02:03 +0000 Subject: [PATCH 07/16] regenrate requirements --- requirements.txt | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 72b58e8be..5df78e9f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,6 +55,8 @@ 
cloud-volume==12.10.0 # via -r requirements.in compressed-segmentation==2.3.2 # via cloud-volume +connected-components-3d==3.26.1 + # via -r requirements.in crc32c==2.8 # via cloud-files croniter==6.0.0 @@ -73,6 +75,8 @@ dracopy==1.7.0 # via # -r requirements.in # cloud-volume +edt==3.1.1 + # via -r requirements.in fasteners==0.20 # via cloud-files fastremap==1.17.7 @@ -165,6 +169,8 @@ grpcio-status==1.78.0 # google-cloud-pubsub idna==3.11 # via requests +imageio==2.37.2 + # via scikit-image importlib-metadata==8.7.1 # via opentelemetry-api inflection==0.5.1 @@ -191,6 +197,8 @@ jsonschema-specifications==2025.9.1 # via jsonschema kvdbclient==0.4.0 # via -r requirements.in +lazy-loader==0.4 + # via scikit-image markdown==3.10.2 # via python-jsonschema-objects markupsafe==3.0.3 @@ -215,13 +223,17 @@ networkx==3.6.1 # -r requirements.in # cloud-volume # osteoid + # scikit-image numpy==2.4.2 # via # -r requirements.in # cloud-files # cloud-volume # compressed-segmentation + # connected-components-3d + # edt # fastremap + # imageio # kvdbclient # messagingclient # microviewer @@ -229,9 +241,12 @@ numpy==2.4.2 # multiwrapper # osteoid # pandas + # scikit-image + # scipy # simplejpeg # task-queue # tensorstore + # tifffile # zmesh opentelemetry-api==1.39.1 # via @@ -251,7 +266,10 @@ orjson==3.11.7 osteoid==0.6.0 # via cloud-volume packaging==26.0 - # via pytest + # via + # lazy-loader + # pytest + # scikit-image pandas==3.0.1 # via -r requirements.in pathos==0.3.5 @@ -261,6 +279,10 @@ pathos==0.3.5 # task-queue pbr==7.0.3 # via task-queue +pillow==12.1.1 + # via + # imageio + # scikit-image pluggy==1.6.0 # via pytest posix-ipc==1.3.2 @@ -352,6 +374,10 @@ rsa==4.9.1 # google-auth s3transfer==0.16.0 # via boto3 +scikit-image==0.26.0 + # via -r requirements.in +scipy==1.17.1 + # via scikit-image simplejpeg==1.9.0 # via cloud-volume six==1.17.0 @@ -372,6 +398,8 @@ tenacity==9.1.4 # task-queue tensorstore==0.1.81 # via -r requirements.in +tifffile==2026.3.3 + # via 
scikit-image tqdm==4.67.3 # via # cloud-files From c6782aabe172cc5c7a2f94168eff9809b6170988 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 5 Mar 2026 20:56:42 +0000 Subject: [PATCH 08/16] fix: use registered attributes module --- pychunkedgraph/graph/chunkedgraph.py | 4 ++-- pychunkedgraph/graph/edits_sv.py | 30 +++++++++++++++------------- pychunkedgraph/ingest/cli.py | 4 +++- pychunkedgraph/ingest/cluster.py | 12 +++++++++-- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 4dbdcdac9..3a6b1461d 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -672,7 +672,7 @@ def get_subgraph_leaves( self, node_id_or_ids, bbox, bbox_is_coordinate, False, True ) - def get_edited_edges( + def get_edges_from_edits( self, chunk_ids: np.ndarray, time_stamp: datetime.datetime = None ) -> typing.Dict: """ @@ -748,7 +748,7 @@ def get_l2_agglomerations( if self.mock_edges is None: edges_d = self.read_chunk_edges(chunk_ids) - edited_edges = self.get_edited_edges(chunk_ids) + edited_edges = self.get_edges_from_edits(chunk_ids) all_chunk_edges = reduce( lambda x, y: x + y, chain(edges_d.values(), edited_edges.values()), diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index 3b5395b05..15199b403 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -5,22 +5,20 @@ from functools import reduce import logging import multiprocessing as mp -from typing import Callable, Iterable +from typing import Callable from datetime import datetime from collections import defaultdict, deque import fastremap import numpy as np from tqdm import tqdm -from pychunkedgraph.graph import ChunkedGraph, cache as cache_utils -from pychunkedgraph.graph.attributes import Connectivity +from pychunkedgraph.graph import attributes, ChunkedGraph, cache as cache_utils from pychunkedgraph.graph.chunks.utils 
import chunks_overlapping_bbox, get_neighbors from pychunkedgraph.graph.cutting_sv import ( build_kdtrees_by_label, pairwise_min_distance_two_sets, split_supervoxel_helper, ) -from pychunkedgraph.graph.attributes import Hierarchy, OperationLogs from pychunkedgraph.graph.edges import Edges from pychunkedgraph.graph.types import empty_2d from pychunkedgraph.graph import basetypes @@ -281,9 +279,9 @@ def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = for chunk_id in np.unique(chunks): val_dict = {} mask = chunks_arr[:, 0] == chunk_id - val_dict[Connectivity.SplitEdges] = edges[mask] - val_dict[Connectivity.Affinity] = affinites[mask] - val_dict[Connectivity.Area] = areas[mask] + val_dict[attributes.Connectivity.SplitEdges] = edges[mask] + val_dict[attributes.Connectivity.Affinity] = affinites[mask] + val_dict[attributes.Connectivity.Area] = areas[mask] rows.append( cg.client.mutate_row( serializers.serialize_uint64(chunk_id, fake_edges=True), @@ -396,13 +394,15 @@ def copy_parents_and_add_lineage( parents = set() old_new_map = {k: list(v) for k, v in old_new_map.items()} parent_cells_map = cg.client.read_nodes( - node_ids=list(old_new_map.keys()), properties=Hierarchy.Parent + node_ids=list(old_new_map.keys()), properties=attributes.Hierarchy.Parent ) for old_id, new_ids in old_new_map.items(): for new_id in new_ids: val_dict = { - Hierarchy.FormerIdentity: np.array([old_id], dtype=basetypes.NODE_ID), - OperationLogs.OperationID: operation_id, + attributes.Hierarchy.FormerIdentity: np.array( + [old_id], dtype=basetypes.NODE_ID + ), + attributes.OperationLogs.OperationID: operation_id, } result.append( cg.client.mutate_row(serializers.serialize_uint64(new_id), val_dict) @@ -413,17 +413,19 @@ def copy_parents_and_add_lineage( result.append( cg.client.mutate_row( serializers.serialize_uint64(new_id), - {Hierarchy.Parent: cell.value}, + {attributes.Hierarchy.Parent: cell.value}, time_stamp=cell.timestamp, ) ) - val_dict = 
{Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID)} + val_dict = { + attributes.Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID) + } result.append( cg.client.mutate_row(serializers.serialize_uint64(old_id), val_dict) ) children_cells_map = cg.client.read_nodes( - node_ids=list(parents), properties=Hierarchy.Child + node_ids=list(parents), properties=attributes.Hierarchy.Child ) for parent, children_cells in children_cells_map.items(): assert len(children_cells) == 1, children_cells @@ -437,7 +439,7 @@ def copy_parents_and_add_lineage( result.append( cg.client.mutate_row( serializers.serialize_uint64(parent), - {Hierarchy.Child: children}, + {attributes.Hierarchy.Child: children}, time_stamp=cell.timestamp, ) ) diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index 8d44bf276..94a362d8c 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -5,6 +5,7 @@ """ import logging +import os import click import yaml @@ -70,9 +71,10 @@ def ingest_graph( if not retry: cg.create() + get_seg_source_and_destination_ocdbt(cg.meta.data_source.WATERSHED, create=True) imanager = IngestionManager(ingest_config, meta) enqueue_l2_tasks(imanager, create_atomic_chunk) - get_seg_source_and_destination_ocdbt(cg.meta, create=True) + os._exit(0) @ingest_cli.command("imanager") diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 473a61b22..5514f3b04 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -143,8 +143,16 @@ def create_atomic_chunk(coords: Sequence[int]): for k, v in chunk_edges_active.items(): logging.debug(f"active_{k}: {len(v)}") - src, dst = get_seg_source_and_destination_ocdbt(imanager.cg.meta) - copy_ws_chunk(imanager.cg, coords, src, dst) + src, dst = get_seg_source_and_destination_ocdbt( + imanager.cg.meta.data_source.WATERSHED + ) + copy_ws_chunk( + src, + dst, + imanager.cg.meta.graph_config.CHUNK_SIZE, + coords, + 
imanager.cg.meta.voxel_bounds, + ) _post_task_completion(imanager, 2, coords) From 8d8226d69137944dabf2e0bcc36a11eb9cf2869e Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 5 Mar 2026 21:11:30 +0000 Subject: [PATCH 09/16] feat(ingest): make ocdbt seg optional --- .gitignore | 1 + pychunkedgraph/ingest/cli.py | 10 ++++++---- pychunkedgraph/ingest/cluster.py | 15 ++++++++------- pychunkedgraph/ingest/manager.py | 2 ++ 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 498253791..044d3b64c 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,7 @@ venv.bak/ # local dev stuff +.claude/ .devcontainer/ *.ipynb *.rdb diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index 94a362d8c..3287c6040 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -47,12 +47,13 @@ def flush_redis(): @ingest_cli.command("graph") @click.argument("graph_id", type=str) @click.argument("dataset", type=click.Path(exists=True)) +@click.option("--ocdbt", is_flag=True, help="Precomputed supervoxel seg into ocdbt.") @click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") -@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @click.option("--retry", is_flag=True, help="Rerun without creating a new table.") +@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @job_type_guard(group_name) def ingest_graph( - graph_id: str, dataset: click.Path, raw: bool, test: bool, retry: bool + graph_id: str, dataset: click.Path, ocdbt: bool, raw: bool, retry: bool, test: bool ): """ Main ingest command. 
@@ -71,8 +72,9 @@ def ingest_graph( if not retry: cg.create() - get_seg_source_and_destination_ocdbt(cg.meta.data_source.WATERSHED, create=True) - imanager = IngestionManager(ingest_config, meta) + if ocdbt: + get_seg_source_and_destination_ocdbt(cg.meta.data_source.WATERSHED, create=True) + imanager = IngestionManager(ingest_config, meta, ocdbt_seg=ocdbt) enqueue_l2_tasks(imanager, create_atomic_chunk) os._exit(0) diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 5514f3b04..72a3c081e 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -146,13 +146,14 @@ def create_atomic_chunk(coords: Sequence[int]): src, dst = get_seg_source_and_destination_ocdbt( imanager.cg.meta.data_source.WATERSHED ) - copy_ws_chunk( - src, - dst, - imanager.cg.meta.graph_config.CHUNK_SIZE, - coords, - imanager.cg.meta.voxel_bounds, - ) + if imanager.ocdbt_seg: + copy_ws_chunk( + src, + dst, + imanager.cg.meta.graph_config.CHUNK_SIZE, + coords, + imanager.cg.meta.voxel_bounds, + ) _post_task_completion(imanager, 2, coords) diff --git a/pychunkedgraph/ingest/manager.py b/pychunkedgraph/ingest/manager.py index c23c3cca4..3ba6e972c 100644 --- a/pychunkedgraph/ingest/manager.py +++ b/pychunkedgraph/ingest/manager.py @@ -15,6 +15,7 @@ def __init__( self, config: IngestConfig, chunkedgraph_meta: ChunkedGraphMeta, + ocdbt_seg: bool = False, _from_pickle: bool = False, ): self._config = config @@ -23,6 +24,7 @@ def __init__( self._redis = None self._task_queues = {} self._from_pickle = _from_pickle + self.ocdbt_seg = ocdbt_seg if not _from_pickle: # initiate redis and store serialized state From a2651b4f67a174235afd0439a169d148689786b1 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 18 Mar 2026 19:23:40 +0000 Subject: [PATCH 10/16] add ocdbt flag to ingest cli --- pychunkedgraph/ingest/cli.py | 2 ++ pychunkedgraph/ingest/cluster.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index 3287c6040..5d448814a 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -73,6 +73,8 @@ def ingest_graph( cg.create() if ocdbt: + cg.meta.custom_data["seg"] = {"ocdbt": True} + cg.update_meta(cg.meta, overwrite=True) get_seg_source_and_destination_ocdbt(cg.meta.data_source.WATERSHED, create=True) imanager = IngestionManager(ingest_config, meta, ocdbt_seg=ocdbt) enqueue_l2_tasks(imanager, create_atomic_chunk) diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 72a3c081e..6233c9d46 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -143,10 +143,10 @@ def create_atomic_chunk(coords: Sequence[int]): for k, v in chunk_edges_active.items(): logging.debug(f"active_{k}: {len(v)}") - src, dst = get_seg_source_and_destination_ocdbt( - imanager.cg.meta.data_source.WATERSHED - ) if imanager.ocdbt_seg: + src, dst = get_seg_source_and_destination_ocdbt( + imanager.cg.meta.data_source.WATERSHED + ) copy_ws_chunk( src, dst, From cc6784fa2b74093966c49eaa76d8ae6f80ff5c45 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 18 Mar 2026 19:26:04 +0000 Subject: [PATCH 11/16] bugfix: use labels for cx edges after split with inf affinity --- pychunkedgraph/app/segmentation/common.py | 10 +- pychunkedgraph/graph/edits_sv.py | 105 ++++++++++++--- pychunkedgraph/graph/operation.py | 1 + pychunkedgraph/tests/graph/test_edits_sv.py | 137 +++++++++++++++++++- 4 files changed, 220 insertions(+), 33 deletions(-) diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 0a9c1789f..1f3a6ae06 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -384,7 +384,6 @@ def handle_merge(table_id, allow_same_segment_merge=False): source_coords=coords[:1], sink_coords=coords[1:], allow_same_segment_merge=allow_same_segment_merge, - 
do_sanity_check=True, ) except cg_exceptions.LockingError as e: @@ -410,7 +409,6 @@ def handle_merge(table_id, allow_same_segment_merge=False): def _get_sources_and_sinks(cg: ChunkedGraph, data): - current_app.logger.debug(data) node_idents = [] node_ident_map = {"sources": 0, "sinks": 1} coords = [] @@ -426,13 +424,7 @@ def _get_sources_and_sinks(cg: ChunkedGraph, data): coords = np.array(coords) node_idents = np.array(node_idents) - start = time.time() sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) - current_app.logger.info(f"SV lookup took {time.time() - start}s.") - current_app.logger.debug( - {"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents} - ) - source_ids = sv_ids[node_idents == 0] sink_ids = sv_ids[node_idents == 1] source_coords = coords[node_idents == 0] @@ -450,6 +442,7 @@ def handle_split(table_id): mincut = request.args.get("mincut", True, type=str2bool) cg = app_utils.get_cg(table_id, skip_cache=True) + current_app.logger.debug(data) sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) try: ret = cg.remove_edges( @@ -494,7 +487,6 @@ def handle_split(table_id): source_coords=source_coords, sink_coords=sink_coords, mincut=mincut, - do_sanity_check=True, ) except cg_exceptions.LockingError as e: raise cg_exceptions.InternalServerError(e) diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index 15199b403..287b44963 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -95,6 +95,7 @@ def _update_chunk(args): _indices = [] _old_values = [] _new_values = [] + _label_id_map = {} for _id in labels: _mask = chunk_seg == _id if np.any(_mask): @@ -104,12 +105,14 @@ def _update_chunk(args): _indices.append(_index) _ones = np.ones(len(_index), dtype=basetypes.NODE_ID) _old_values.append(_ones * _og_value) - _new_values.append(_ones * cg.id_client.create_node_id(chunk_id)) + new_id = cg.id_client.create_node_id(chunk_id) + 
_new_values.append(_ones * new_id) + _label_id_map[int(_id)] = new_id _indices = np.concatenate(_indices) + (chunk_bbox[0] - bb_start) _old_values = np.concatenate(_old_values) _new_values = np.concatenate(_new_values) - return (_indices, _old_values, _new_values) + return (_indices, _old_values, _new_values, _label_id_map) def _voxel_crop(bbs, bbe, bbs_, bbe_): @@ -122,17 +125,62 @@ def _voxel_crop(bbs, bbe, bbs_, bbe_): def _parse_results(results, seg, bbs, bbe): old_new_map = defaultdict(set) + new_id_label_map = {} for result in results: if result: - indexer, old_values, new_values = result + indexer, old_values, new_values, label_id_map = result seg[tuple(indexer.T)] = new_values for old_sv, new_sv in zip(old_values, new_values): old_new_map[old_sv].add(new_sv) + for label, new_id in label_id_map.items(): + new_id_label_map[new_id] = label assert np.all(seg.shape == bbe - bbs), f"{seg.shape} != {bbe - bbs}" slices = tuple(slice(start, end) for start, end in zip(bbs, bbe)) + (slice(None),) logging.info(f"slices {slices}") - return seg, old_new_map, slices + return seg, old_new_map, slices, new_id_label_map + + +def _match_by_label(new_ids, partner, aff, area, new_id_label_map, distances_row): + """For inf-affinity (cross-chunk) edges: connect fragments with matching split label.""" + partner_label = new_id_label_map[partner] + matching = np.array( + [nid for nid in new_ids if new_id_label_map.get(nid) == partner_label], + dtype=basetypes.NODE_ID, + ) + if len(matching): + edges = np.column_stack( + [matching, np.full(len(matching), partner, dtype=np.uint64)] + ) + affs = np.full(len(matching), aff, dtype=basetypes.EDGE_AFFINITY) + areas = np.full(len(matching), area, dtype=basetypes.EDGE_AREA) + return edges, affs, areas + # fallback: closest fragment + close = new_ids[np.argmin(distances_row)] + return ( + np.array([[close, partner]], dtype=np.uint64), + np.array([aff], dtype=basetypes.EDGE_AFFINITY), + np.array([area], dtype=basetypes.EDGE_AREA), + ) + + +def 
_match_by_proximity(new_ids, partner, aff, area, distances_row, threshold): + """For regular edges: connect fragments within distance threshold.""" + close_mask = distances_row < threshold + nearby = new_ids[close_mask] + if len(nearby): + edges = np.column_stack( + [nearby, np.full(len(nearby), partner, dtype=np.uint64)] + ) + affs = np.full(len(nearby), aff, dtype=basetypes.EDGE_AFFINITY) + areas = np.full(len(nearby), area, dtype=basetypes.EDGE_AREA) + return edges, affs, areas + close = new_ids[np.argmin(distances_row)] + return ( + np.array([[close, partner]], dtype=np.uint64), + np.array([aff], dtype=basetypes.EDGE_AFFINITY), + np.array([area], dtype=basetypes.EDGE_AREA), + ) def _get_new_edges( @@ -142,6 +190,7 @@ def _get_new_edges( distances: np.ndarray, dist_vec: Callable, new_dist_vec: Callable, + new_id_label_map: dict = None, ): THRESHOLD = 10 new_edges, new_affs, new_areas = [], [], [] @@ -189,19 +238,29 @@ def _get_new_edges( logging.info(f"new_dist_vec(new_ids): {new_dist_vec(new_ids)}") logging.info(f"dist_vec(active_partners): {dist_vec(active_partners)}") distances_ = distances[new_dist_vec(new_ids)][:, dist_vec(active_partners)].T - for i, _ in enumerate(active_partners): - new_ids_ = new_ids[distances_[i] < THRESHOLD] - if len(new_ids_): - _a = [new_ids_, [active_partners[i]] * len(new_ids_)] - new_edges.extend(np.array(_a, dtype=np.uint64).T) - new_affs.extend([active_affs[i]] * len(new_ids_)) - new_areas.extend([active_areas[i]] * len(new_ids_)) + for i, partner in enumerate(active_partners): + aff = active_affs[i] + if np.isinf(aff) and new_id_label_map and partner in new_id_label_map: + e, a, ar = _match_by_label( + new_ids, + partner, + aff, + active_areas[i], + new_id_label_map, + distances_[i], + ) else: - close_new_sv_id = new_ids[np.argmin(distances_[i])] - _a = [close_new_sv_id, active_partners[i]] - new_edges.append(np.array(_a, dtype=np.uint64)) - new_affs.append(active_affs[i]) - new_areas.append(active_areas[i]) + e, a, ar = 
_match_by_proximity( + new_ids, + partner, + aff, + active_areas[i], + distances_[i], + THRESHOLD, + ) + new_edges.extend(e) + new_affs.extend(a) + new_areas.extend(ar) # edges between split fragments for i in range(len(new_ids)): @@ -225,6 +284,7 @@ def _update_edges( bbox: np.ndarray, new_seg: np.ndarray, old_new_map: dict, + new_id_label_map: dict = None, ): old_new_map = dict(old_new_map) kdtrees, _ = build_kdtrees_by_label(new_seg) @@ -259,6 +319,7 @@ def _update_edges( distances, dist_vec, new_dist_vec, + new_id_label_map, ) @@ -350,7 +411,9 @@ def split_supervoxel( with mp.Pool() as pool: results = [*tqdm(pool.imap_unordered(_update_chunk, tasks), total=len(tasks))] seg_cropped = seg[voxel_overlap_crop].copy() - new_seg, old_new_map, slices = _parse_results(results, seg_cropped, bbs, bbe) + new_seg, old_new_map, slices, new_id_label_map = _parse_results( + results, seg_cropped, bbs, bbe + ) seg_roots = seg.copy() sv_ids = fastremap.unique(seg) @@ -366,7 +429,13 @@ def split_supervoxel( seg_masked[voxel_overlap_crop] = new_seg edges_tuple = _update_edges( - cg, sv_ids, root, np.array([bbs, bbe]), seg_masked, old_new_map + cg, + sv_ids, + root, + np.array([bbs, bbe]), + seg_masked, + old_new_map, + new_id_label_map, ) rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 5bf221e01..2066bdba0 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -461,6 +461,7 @@ def execute( old_root_ids=root_ids, ) except SupervoxelSplitRequiredError as err: + # no need for self.cg.cache = None, the cache must be retained after sv split raise SupervoxelSplitRequiredError( str(err), err.sv_remapping, operation_id=lock.operation_id ) from err diff --git a/pychunkedgraph/tests/graph/test_edits_sv.py b/pychunkedgraph/tests/graph/test_edits_sv.py index ced51e68f..c0b0f7d73 100644 --- a/pychunkedgraph/tests/graph/test_edits_sv.py +++ 
b/pychunkedgraph/tests/graph/test_edits_sv.py @@ -9,6 +9,8 @@ _voxel_crop, _parse_results, _get_new_edges, + _match_by_label, + _match_by_proximity, ) from pychunkedgraph.graph import basetypes @@ -54,27 +56,33 @@ def test_basic(self): seg = np.array([[[100, 100], [100, 200]]], dtype=basetypes.NODE_ID) bbs = np.array([0, 0, 0]) bbe = np.array([1, 2, 2]) - # result: (indices, old_values, new_values) indices = np.array([[0, 0, 0], [0, 0, 1]]) old_values = np.array([100, 100], dtype=basetypes.NODE_ID) new_values = np.array([300, 301], dtype=basetypes.NODE_ID) - results = [(indices, old_values, new_values)] + label_id_map = {1: np.uint64(300), 2: np.uint64(301)} + results = [(indices, old_values, new_values, label_id_map)] - updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) + updated_seg, old_new_map, slices, new_id_label_map = _parse_results( + results, seg, bbs, bbe + ) assert updated_seg[0, 0, 0] == 300 assert updated_seg[0, 0, 1] == 301 assert 300 in old_new_map[100] assert 301 in old_new_map[100] + assert new_id_label_map[np.uint64(300)] == 1 + assert new_id_label_map[np.uint64(301)] == 2 def test_none_result_skipped(self): seg = np.array([[[100]]], dtype=basetypes.NODE_ID) bbs = np.array([0, 0, 0]) bbe = np.array([1, 1, 1]) results = [None] - updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) - # No changes + updated_seg, old_new_map, slices, new_id_label_map = _parse_results( + results, seg, bbs, bbe + ) assert updated_seg[0, 0, 0] == 100 assert len(old_new_map) == 0 + assert len(new_id_label_map) == 0 def test_multiple_results(self): seg = np.array([[[100, 200]]], dtype=basetypes.NODE_ID) @@ -84,15 +92,19 @@ def test_multiple_results(self): np.array([[0, 0, 0]]), np.array([100], dtype=basetypes.NODE_ID), np.array([300], dtype=basetypes.NODE_ID), + {1: np.uint64(300)}, ) result2 = ( np.array([[0, 0, 1]]), np.array([200], dtype=basetypes.NODE_ID), np.array([400], dtype=basetypes.NODE_ID), + {1: np.uint64(400)}, ) 
results = [result1, result2] - updated_seg, old_new_map, slices = _parse_results(results, seg, bbs, bbe) + updated_seg, old_new_map, slices, new_id_label_map = _parse_results( + results, seg, bbs, bbe + ) assert updated_seg[0, 0, 0] == 300 assert updated_seg[0, 0, 1] == 400 assert 300 in old_new_map[100] @@ -218,3 +230,116 @@ def test_empty_old_new_map(self): np.vectorize(lambda x: x), ) assert len(result_edges) == 0 + + def test_inf_affinity_uses_label_matching(self): + """Inf-affinity (cross-chunk) edges should connect only same-label fragments.""" + old_sv = np.uint64(10) + new_sv1 = np.uint64(101) # label 1 + new_sv2 = np.uint64(102) # label 2 + # partner is a cross-chunk fragment also from the split, label 1 + partner = np.uint64(201) + + edges = np.array([[10, 201]], dtype=basetypes.NODE_ID) + affinities = np.array([np.inf], dtype=basetypes.EDGE_AFFINITY) + areas = np.array([0], dtype=basetypes.EDGE_AREA) + + old_new_map = {old_sv: {new_sv1, new_sv2}} + sv_ids = np.array([10, 101, 102, 201], dtype=basetypes.NODE_ID) + + distance_map = { + np.uint64(10): 0, + np.uint64(101): 1, + np.uint64(102): 2, + np.uint64(201): 3, + } + dist_vec = np.vectorize(distance_map.get) + new_distance_map = {np.uint64(101): 0, np.uint64(102): 1} + new_dist_vec = np.vectorize(new_distance_map.get) + + # new_sv2 (label 2) is closer to partner 201, but label doesn't match + distances = np.array( + [ + [5.0, 0.0, 8.0, 9.0], # new_sv1 (label 1) — far from partner + [6.0, 8.0, 0.0, 2.0], # new_sv2 (label 2) — close to partner + ] + ) + + new_id_label_map = { + np.uint64(101): 1, + np.uint64(102): 2, + np.uint64(201): 1, # same label as new_sv1 + } + + result_edges, result_affs, result_areas = _get_new_edges( + (edges, affinities, areas), + sv_ids, + old_new_map, + distances, + dist_vec, + new_dist_vec, + new_id_label_map, + ) + + # The inf-affinity edge should connect new_sv1 (label 1) to partner 201 (label 1) + # NOT new_sv2 (label 2) even though it's closer + inf_edges = 
result_edges[np.isinf(result_affs)] + for e in inf_edges: + assert ( + new_sv2 not in e + ), f"Inf-affinity edge {e} should not connect label-2 fragment to label-1 partner" + # Verify new_sv1 <-> 201 inf edge exists + found = any(set(e) == {new_sv1, partner} for e in inf_edges) + assert found, "Expected inf-affinity edge between same-label fragments" + + +# ============================================================ +# Tests: _match_by_label / _match_by_proximity +# ============================================================ +class TestMatchByLabel: + def test_matching_label(self): + new_ids = np.array([101, 102], dtype=basetypes.NODE_ID) + new_id_label_map = {np.uint64(101): 1, np.uint64(102): 2, np.uint64(201): 1} + distances_row = np.array([9.0, 2.0]) # 102 is closer + + edges, affs, areas = _match_by_label( + new_ids, np.uint64(201), np.inf, 0, new_id_label_map, distances_row + ) + # Should pick 101 (label 1) not 102 (label 2, closer) + assert all(np.uint64(101) in e for e in edges) + assert np.uint64(102) not in edges.flatten() + + def test_fallback_to_closest(self): + new_ids = np.array([101, 102], dtype=basetypes.NODE_ID) + # partner label 3 doesn't match any new_id + new_id_label_map = {np.uint64(101): 1, np.uint64(102): 2, np.uint64(201): 3} + distances_row = np.array([9.0, 2.0]) + + edges, affs, areas = _match_by_label( + new_ids, np.uint64(201), np.inf, 0, new_id_label_map, distances_row + ) + # Fallback: closest = 102 + assert np.uint64(102) in edges.flatten() + + +class TestMatchByProximity: + def test_within_threshold(self): + new_ids = np.array([101, 102], dtype=basetypes.NODE_ID) + distances_row = np.array([3.0, 15.0]) + + edges, affs, areas = _match_by_proximity( + new_ids, np.uint64(50), 0.9, 100, distances_row, threshold=10 + ) + # Only 101 is within threshold + assert len(edges) == 1 + assert np.uint64(101) in edges[0] + + def test_fallback_to_closest(self): + new_ids = np.array([101, 102], dtype=basetypes.NODE_ID) + distances_row = 
np.array([15.0, 20.0]) # both outside threshold + + edges, affs, areas = _match_by_proximity( + new_ids, np.uint64(50), 0.9, 100, distances_row, threshold=10 + ) + # Fallback: closest = 101 + assert len(edges) == 1 + assert np.uint64(101) in edges[0] From 185293d590e6863569723fb3569077ea33813155 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 19 Mar 2026 01:03:19 +0000 Subject: [PATCH 12/16] optimize whole sv lookup, use rep sv from source --- pychunkedgraph/app/segmentation/common.py | 2 +- pychunkedgraph/graph/edits_sv.py | 69 ++++++++++---------- pychunkedgraph/graph/meta.py | 5 +- pychunkedgraph/graph/ocdbt.py | 2 +- pychunkedgraph/graph/operation.py | 6 -- pychunkedgraph/ingest/cli.py | 19 +++++- pychunkedgraph/tests/graph/test_ocdbt.py | 8 --- pychunkedgraph/tests/graph/test_operation.py | 15 ----- workers/mesh_worker.py | 4 +- 9 files changed, 60 insertions(+), 70 deletions(-) diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 1f3a6ae06..da037fc34 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -473,7 +473,7 @@ def handle_split(table_id): _mask1 = sinks_remapped == sv_to_split split_supervoxel( cg, - sv_to_split, + sources[_mask0][0], source_coords[_mask0], sink_coords[_mask1], e.operation_id, diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index 287b44963..dc5641ca9 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -30,29 +30,23 @@ def _get_whole_sv( cg: ChunkedGraph, node: basetypes.NODE_ID, min_coord, max_coord ) -> set: - cx_edges = [empty_2d] - explored_chunks = set() + all_chunks = [ + (x, y, z) + for x in range(min_coord[0], max_coord[0]) + for y in range(min_coord[1], max_coord[1]) + for z in range(min_coord[2], max_coord[2]) + ] + edges = get_chunk_edges(cg.meta.data_source.EDGES, all_chunks) + cx_edges = edges["cross"].get_pairs() + if len(cx_edges) == 0: + 
return {node} + explored_nodes = set([node]) queue = deque([node]) - - while len(queue) > 0: + while queue: vertex = queue.popleft() - chunk = cg.get_chunk_coordinates(vertex) - chunks = get_neighbors(chunk, min_coord=min_coord, max_coord=max_coord) - - unexplored_chunks = [] - for _chunk in chunks: - if tuple(_chunk) not in explored_chunks: - unexplored_chunks.append(tuple(_chunk)) - - edges = get_chunk_edges(cg.meta.data_source.EDGES, unexplored_chunks) - explored_chunks.update(unexplored_chunks) - _cx_edges = edges["cross"].get_pairs() - cx_edges.append(_cx_edges) - _cx_edges = np.concatenate(cx_edges) - - mask = _cx_edges[:, 0] == vertex - neighbors = _cx_edges[mask][:, 1] + mask = cx_edges[:, 0] == vertex + neighbors = cx_edges[mask][:, 1] if len(neighbors) > 0: neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) @@ -61,10 +55,9 @@ def _get_whole_sv( neighbors = neighbors[min_mask & max_mask] for neighbor in neighbors: - if neighbor in explored_nodes: - continue - explored_nodes.add(neighbor) - queue.append(neighbor) + if neighbor not in explored_nodes: + explored_nodes.add(neighbor) + queue.append(neighbor) return explored_nodes @@ -191,8 +184,8 @@ def _get_new_edges( dist_vec: Callable, new_dist_vec: Callable, new_id_label_map: dict = None, + threshold: int = 10, ): - THRESHOLD = 10 new_edges, new_affs, new_areas = [], [], [] edges, affinities, areas = edges_info @@ -256,7 +249,7 @@ def _get_new_edges( aff, active_areas[i], distances_[i], - THRESHOLD, + threshold, ) new_edges.extend(e) new_affs.extend(a) @@ -270,9 +263,15 @@ def _get_new_edges( new_affs.append(0.001) new_areas.append(0) + if len(new_edges) == 0: + return ( + np.array([], dtype=basetypes.NODE_ID), + np.array([], dtype=basetypes.EDGE_AFFINITY), + np.array([], dtype=basetypes.EDGE_AREA), + ) affinites = np.array(new_affs, dtype=basetypes.EDGE_AFFINITY) areas = np.array(new_areas, dtype=basetypes.EDGE_AREA) - edges = np.array(new_edges, dtype=basetypes.NODE_ID) + edges = 
np.sort(np.array(new_edges, dtype=basetypes.NODE_ID), axis=1) edges, idx = np.unique(edges, return_index=True, axis=0) return edges, affinites[idx], areas[idx] @@ -320,6 +319,7 @@ def _update_edges( dist_vec, new_dist_vec, new_id_label_map, + threshold=cg.meta.sv_split_threshold, ) @@ -372,18 +372,21 @@ def split_supervoxel( vol_end = cg.meta.voxel_bounds[:, 1] chunk_size = cg.meta.graph_config.CHUNK_SIZE _coords = np.concatenate([source_coords, sink_coords]) - _padding = np.array([64] * 3) / cg.meta.resolution + _padding = np.array([cg.meta.resolution[-1] * 2] * 3) / cg.meta.resolution bbs = np.clip((np.min(_coords, 0) - _padding).astype(int), vol_start, vol_end) bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size - logging.info(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}") - logging.info(f"{chunk_size}; {_padding}; {(bbs, bbe)}; {(chunk_min, chunk_max)}") + logging.info( + f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}" + ) + logging.info(f"chunk and padding {chunk_size}; {_padding}") + logging.info(f"bbox and chunk min max {(bbs, bbe)}; {(chunk_min, chunk_max)}") cut_supervoxels = _get_whole_sv(cg, sv_id, min_coord=chunk_min, max_coord=chunk_max) supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) - logging.info(f"{sv_id} -> {cut_supervoxels}") + logging.info(f"whole sv {sv_id} -> {cut_supervoxels}") # one voxel overlap for neighbors bbs_ = np.clip(bbs - 1, vol_start, vol_end) @@ -421,7 +424,7 @@ def split_supervoxel( seg_roots = fastremap.remap(seg_roots, dict(zip(sv_ids, roots)), in_place=True) root = cg.get_root(sv_id) - logging.info(f"root {root}") + logging.info(f"{sv_id} root = {root}") seg_masked = seg.copy() seg_masked[seg_roots != root] = 0 @@ -443,8 +446,8 @@ def split_supervoxel( rows = rows0 + rows1 logging.info(f"{operation_id}: writing 
{len(rows)} new rows") - cg.client.write(rows) cg.meta.ws_ocdbt[slices] = new_seg[..., np.newaxis] + cg.client.write(rows) return old_new_map, edges_tuple diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py index 6a938f802..40968c697 100644 --- a/pychunkedgraph/graph/meta.py +++ b/pychunkedgraph/graph/meta.py @@ -14,7 +14,6 @@ from .chunks.utils import get_chunks_boundary from ..utils.redis import get_redis_connection - _datasource_fields = ("EDGES", "COMPONENTS", "WATERSHED", "DATA_VERSION", "CV_MIP") _datasource_defaults = (None, None, None, None, 0) DataSource = namedtuple( @@ -244,6 +243,10 @@ def edge_dtype(self): def READ_ONLY(self): return self.custom_data.get("READ_ONLY", False) + @property + def sv_split_threshold(self) -> int: + return self._custom_data.get("seg", {}).get("sv_split_threshold", 10) + @property def split_bounding_offset(self): return self.custom_data.get( diff --git a/pychunkedgraph/graph/ocdbt.py b/pychunkedgraph/graph/ocdbt.py index 03c6d9b65..f715bc101 100644 --- a/pychunkedgraph/graph/ocdbt.py +++ b/pychunkedgraph/graph/ocdbt.py @@ -2,7 +2,7 @@ import numpy as np import tensorstore as ts -OCDBT_SEG_COMPRESSION_LEVEL = 17 +OCDBT_SEG_COMPRESSION_LEVEL = 12 def get_seg_source_and_destination_ocdbt(ws_path: str, create: bool = False) -> tuple: diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 2066bdba0..0d91e3990 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -893,12 +893,6 @@ def __init__( self.path_augment = path_augment self.disallow_isolating_cut = disallow_isolating_cut self.do_sanity_check = do_sanity_check - if np.any(np.isin(self.sink_ids, self.source_ids)): - raise SupervoxelSplitRequiredError( - "Supervoxels exist in both sink and source, " - "try placing the points further apart.", - None, - ) ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID) layers = self.cg.get_chunk_layers(ids) diff --git 
a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index 5d448814a..ca958c354 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -48,12 +48,24 @@ def flush_redis(): @click.argument("graph_id", type=str) @click.argument("dataset", type=click.Path(exists=True)) @click.option("--ocdbt", is_flag=True, help="Precomputed supervoxel seg into ocdbt.") +@click.option( + "--sv-split-threshold", + type=int, + default=10, + help="Distance threshold for SV split edge matching.", +) @click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") @click.option("--retry", is_flag=True, help="Rerun without creating a new table.") @click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @job_type_guard(group_name) def ingest_graph( - graph_id: str, dataset: click.Path, ocdbt: bool, raw: bool, retry: bool, test: bool + graph_id: str, + dataset: click.Path, + ocdbt: bool, + sv_split_threshold: int, + raw: bool, + retry: bool, + test: bool, ): """ Main ingest command. 
@@ -73,7 +85,10 @@ def ingest_graph( cg.create() if ocdbt: - cg.meta.custom_data["seg"] = {"ocdbt": True} + cg.meta.custom_data["seg"] = { + "ocdbt": True, + "sv_split_threshold": sv_split_threshold, + } cg.update_meta(cg.meta, overwrite=True) get_seg_source_and_destination_ocdbt(cg.meta.data_source.WATERSHED, create=True) imanager = IngestionManager(ingest_config, meta, ocdbt_seg=ocdbt) diff --git a/pychunkedgraph/tests/graph/test_ocdbt.py b/pychunkedgraph/tests/graph/test_ocdbt.py index a554e23b1..9f31e5f6f 100644 --- a/pychunkedgraph/tests/graph/test_ocdbt.py +++ b/pychunkedgraph/tests/graph/test_ocdbt.py @@ -131,11 +131,3 @@ def test_boundary_clipping(self): voxel_bounds=voxel_bounds, ) mock_source.__getitem__.assert_called_once() - - -class TestOcdbtConstants: - def test_compression_level(self): - from pychunkedgraph.graph.ocdbt import OCDBT_SEG_COMPRESSION_LEVEL - - assert OCDBT_SEG_COMPRESSION_LEVEL == 17 - assert isinstance(OCDBT_SEG_COMPRESSION_LEVEL, int) diff --git a/pychunkedgraph/tests/graph/test_operation.py b/pychunkedgraph/tests/graph/test_operation.py index 328ceb425..fa916ae0d 100644 --- a/pychunkedgraph/tests/graph/test_operation.py +++ b/pychunkedgraph/tests/graph/test_operation.py @@ -498,21 +498,6 @@ def test_split_self_loop_raises(self, gen_graph): sink_coords=None, ) - @pytest.mark.timeout(30) - def test_multicut_overlapping_ids_raises(self, gen_graph): - """source_ids overlapping sink_ids should raise PreconditionError (line 872).""" - cg, _, sv0, sv1 = _build_cross_chunk(gen_graph) - with pytest.raises(SupervoxelSplitRequiredError, match="both sink and source"): - MulticutOperation( - cg, - user_id="test_user", - source_ids=[sv0, sv1], - sink_ids=[sv1], - source_coords=[[0, 0, 0], [1, 0, 0]], - sink_coords=[[1, 0, 0]], - bbox_offset=[240, 240, 24], - ) - # =========================================================================== # NEW: Empty coords / affinities normalization (lines 82, 86, 593) diff --git a/workers/mesh_worker.py 
b/workers/mesh_worker.py index b8f1e0024..cb81b687d 100644 --- a/workers/mesh_worker.py +++ b/workers/mesh_worker.py @@ -10,10 +10,9 @@ from messagingclient import MessagingClient from pychunkedgraph.graph import ChunkedGraph -from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.graph import basetypes from pychunkedgraph.meshing import meshgen - PCG_CACHE = {} @@ -55,7 +54,6 @@ def callback(payload): cg.meta.data_source.WATERSHED, mesh_dir, cv_unsharded_mesh_dir ) - logging.log(INFO_HIGH, f"remeshing {l2ids}; graph {table_id} operation {op_id}.") meshgen.remeshing( cg, From 1802ced6de9165ba248196dadc0fa80606d6ea17 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 19 Mar 2026 19:21:27 +0000 Subject: [PATCH 13/16] fix: use relative ocdbt path in info --- pychunkedgraph/graph/meta.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py index 40968c697..23ff8d35d 100644 --- a/pychunkedgraph/graph/meta.py +++ b/pychunkedgraph/graph/meta.py @@ -67,6 +67,7 @@ def __init__( self._layer_count = None self._bitmasks = None self._ocdbt_seg = None + self._ocdbt_path = None @property def graph_config(self): @@ -108,6 +109,14 @@ def ocdbt_seg(self) -> bool: self._ocdbt_seg = self._custom_data.get("seg", {}).get("ocdbt", False) return self._ocdbt_seg + @property + def ocdbt_path(self) -> bool: + if self._ocdbt_path is None: + self._ocdbt_path = self._custom_data.get("seg", {}).get( + "ocdbt_path", "ocdbt/base" + ) + return self._ocdbt_path + @property def ws_ocdbt(self): assert self.ocdbt_seg, "make sure this pcg has segmentation in ocdbt format" @@ -260,11 +269,7 @@ def dataset_info(self) -> Dict: info.update( { "chunks_start_at_voxel_offset": True, - "data_dir": ( - self.ws_ocdbt.kvstore.base.url - if self.ocdbt_seg - else self.data_source.WATERSHED - ), + "data_dir": self.data_source.WATERSHED, "graph": { "chunk_size": self.graph_config.CHUNK_SIZE, 
"bounding_box": [2048, 2048, 512], @@ -272,6 +277,8 @@ def dataset_info(self) -> Dict: "cv_mip": self.data_source.CV_MIP, "n_layers": self.layer_count, "spatial_bit_masks": self.bitmasks, + "ocdbt_seg": self.ocdbt_seg, + "ocdbt_path": self.ocdbt_path, }, } ) From 3d1e97ae53ff2aa80952de600a7eebf421987cc8 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Fri, 20 Mar 2026 16:43:07 +0000 Subject: [PATCH 14/16] fix: inf edges to only closest partner; mesh worker pubsub version, always sv lookup from seg for ocdbt --- pychunkedgraph/app/__init__.py | 4 + pychunkedgraph/app/app_utils.py | 5 + pychunkedgraph/app/segmentation/common.py | 4 + pychunkedgraph/graph/chunkedgraph.py | 13 +- pychunkedgraph/graph/cutting_sv.py | 5 +- pychunkedgraph/graph/edits_sv.py | 183 +++++++++--------- pychunkedgraph/graph/utils/generic.py | 10 +- pychunkedgraph/graph/utils/id_helpers.py | 7 +- pychunkedgraph/ingest/upgrade/atomic_layer.py | 6 + .../tests/graph/test_utils_id_helpers.py | 1 + requirements.in | 2 +- requirements.txt | 2 +- uwsgi.ini | 26 +-- 13 files changed, 135 insertions(+), 133 deletions(-) diff --git a/pychunkedgraph/app/__init__.py b/pychunkedgraph/app/__init__.py index 262849258..042fa7ff1 100644 --- a/pychunkedgraph/app/__init__.py +++ b/pychunkedgraph/app/__init__.py @@ -99,6 +99,10 @@ def configure_app(app): app.logger.setLevel(app.config["LOGGING_LEVEL"]) app.logger.propagate = False + # Also configure root logger so logging.info() calls in library code are captured + logging.root.addHandler(handler) + logging.root.setLevel(logging.INFO) + if app.config["USE_REDIS_JOBS"]: app.redis = redis.Redis.from_url(app.config["REDIS_URL"]) app.test_q = Queue("test", connection=app.redis) diff --git a/pychunkedgraph/app/app_utils.py b/pychunkedgraph/app/app_utils.py index 061f60115..9d69c3650 100644 --- a/pychunkedgraph/app/app_utils.py +++ b/pychunkedgraph/app/app_utils.py @@ -16,6 +16,7 @@ from pychunkedgraph.graph import ChunkedGraph from pychunkedgraph.graph import 
get_default_client_info from pychunkedgraph.graph import exceptions as cg_exceptions +from pychunkedgraph.graph.utils.generic import lookup_svs_from_seg PCG_CACHE = {} @@ -238,6 +239,10 @@ def ccs(coordinates_nm_): f"{coordinates} - Validation stage." ) + # Fast path: all node_ids are L1 and OCDBT — single seg read for all coords + if cg.meta.ocdbt_seg and np.all(cg.get_chunk_layers(np.unique(node_ids)) == 1): + return lookup_svs_from_seg(cg.meta, coordinates) + atomic_ids = np.zeros(len(coordinates), dtype=np.uint64) for node_id in np.unique(node_ids): node_id_m = node_ids == node_id diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index da037fc34..695d5aef9 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -444,6 +444,7 @@ def handle_split(table_id): cg = app_utils.get_cg(table_id, skip_cache=True) current_app.logger.debug(data) sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + current_app.logger.info(f"sv_lookup pre-split: sources={sources}, sinks={sinks}") try: ret = cg.remove_edges( user_id=user_id, @@ -480,6 +481,9 @@ def handle_split(table_id): ) sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + current_app.logger.info( + f"sv_lookup post-split: sources={sources}, sinks={sinks}" + ) ret = cg.remove_edges( user_id=user_id, source_ids=sources, diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 3a6b1461d..c320c1bde 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -183,16 +183,21 @@ def get_atomic_ids_from_coords( :param max_dist_nm: max distance explored :return: supervoxel ids; returns None if no solution was found """ - if self.get_chunk_layer(parent_id) == 1: + if self.get_chunk_layer(parent_id) == 1 and not self.meta.ocdbt_seg: return np.array([parent_id] * len(coordinates), dtype=np.uint64) - # Enable 
search with old parent by using its timestamp and map to parents - parent_ts = self.get_node_timestamps([parent_id], return_numpy=False)[0] + layer = self.get_chunk_layer(parent_id) + # L1 nodes don't have children, skip timestamp lookup + parent_ts = ( + None + if layer == 1 + else self.get_node_timestamps([parent_id], return_numpy=False)[0] + ) return id_helpers.get_atomic_ids_from_coords( self.meta, coordinates, parent_id, - self.get_chunk_layer(parent_id), + layer, parent_ts, self.get_roots, max_dist_nm, diff --git a/pychunkedgraph/graph/cutting_sv.py b/pychunkedgraph/graph/cutting_sv.py index 5f9ba58c5..bafd86c68 100644 --- a/pychunkedgraph/graph/cutting_sv.py +++ b/pychunkedgraph/graph/cutting_sv.py @@ -4,7 +4,6 @@ from typing import Dict, Tuple, Optional, Sequence from scipy.spatial import cKDTree - # EDT backends: prefer Seung-Lab edt, fallback to scipy.ndimage try: from edt import edt as _edt_fast @@ -385,7 +384,7 @@ def connect_both_seeds_via_ridge( refine_fullres_when_fail: bool = True, snap_method: str = "kdtree", snap_kwargs: dict | None = None, - verbose: bool = True, + verbose: bool = False, ): def log(msg: str): if verbose: @@ -726,7 +725,7 @@ def split_supervoxel_growing( snap_method: str = "kdtree", snap_kwargs: dict | None = None, # logging - verbose: bool = True, + verbose: bool = False, ): def log(msg: str): if verbose: diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index dc5641ca9..2b917c230 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -72,11 +72,6 @@ def _update_chunk(args): x, y, z = chunk_coord chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) - # TODO: remove these 3 lines, testing only - rr = cg.range_read_chunk(chunk_id) - max_node_id = max(rr.keys()) - cg.id_client.set_max_node_id(chunk_id, max_node_id) - _s, _e = chunk_bbox - bb_start og_chunk_seg = seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] chunk_seg = result_seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : 
_e[2]] @@ -91,16 +86,15 @@ def _update_chunk(args): _label_id_map = {} for _id in labels: _mask = chunk_seg == _id - if np.any(_mask): - _idx = np.unravel_index(np.flatnonzero(_mask)[0], og_chunk_seg.shape) - _og_value = og_chunk_seg[_idx] - _index = np.argwhere(_mask) - _indices.append(_index) - _ones = np.ones(len(_index), dtype=basetypes.NODE_ID) - _old_values.append(_ones * _og_value) - new_id = cg.id_client.create_node_id(chunk_id) - _new_values.append(_ones * new_id) - _label_id_map[int(_id)] = new_id + _idx = np.unravel_index(np.flatnonzero(_mask)[0], og_chunk_seg.shape) + _og_value = og_chunk_seg[_idx] + _index = np.argwhere(_mask) + _indices.append(_index) + _ones = np.ones(len(_index), dtype=basetypes.NODE_ID) + _old_values.append(_ones * _og_value) + new_id = cg.id_client.create_node_id(chunk_id) + _new_values.append(_ones * new_id) + _label_id_map[int(_id)] = new_id _indices = np.concatenate(_indices) + (chunk_bbox[0] - bb_start) _old_values = np.concatenate(_old_values) @@ -112,7 +106,6 @@ def _voxel_crop(bbs, bbe, bbs_, bbe_): xS, yS, zS = bbs - bbs_ xE, yE, zE = (None if i == 0 else -1 for i in bbe_ - bbe) voxel_overlap_crop = np.s_[xS:xE, yS:yE, zS:zE] - logging.info(f"voxel_overlap_crop: {voxel_overlap_crop}") return voxel_overlap_crop @@ -130,7 +123,6 @@ def _parse_results(results, seg, bbs, bbe): assert np.all(seg.shape == bbe - bbs), f"{seg.shape} != {bbe - bbs}" slices = tuple(slice(start, end) for start, end in zip(bbs, bbe)) + (slice(None),) - logging.info(f"slices {slices}") return seg, old_new_map, slices, new_id_label_map @@ -176,6 +168,42 @@ def _match_by_proximity(new_ids, partner, aff, area, distances_row, threshold): ) +def _match_inf_unsplit(new_ids, partner, aff, area, distances_row): + """Inf-affinity edge to an unsplit partner: assign to closest fragment only. + Connecting all fragments would create an uncuttable bridge between source/sink sides. 
+ """ + closest = new_ids[np.argmin(distances_row)] + return ( + np.array([[closest, partner]], dtype=np.uint64), + np.array([aff], dtype=basetypes.EDGE_AFFINITY), + np.array([area], dtype=basetypes.EDGE_AREA), + ) + + +def _match_partner( + new_ids, partner, aff, area, distances_row, new_id_label_map, threshold +): + """Route a single old edge to the appropriate new fragment(s).""" + if np.isinf(aff): + if new_id_label_map and partner in new_id_label_map: + return _match_by_label( + new_ids, partner, aff, area, new_id_label_map, distances_row + ) + return _match_inf_unsplit(new_ids, partner, aff, area, distances_row) + return _match_by_proximity(new_ids, partner, aff, area, distances_row, threshold) + + +def _expand_partners(active_partners, active_affs, active_areas, old_new_map): + """If a partner was also split, expand it to its new fragment IDs.""" + partners, affs, areas = [], [], [] + for i in range(len(active_partners)): + remapped = old_new_map.get(active_partners[i], [active_partners[i]]) + partners.extend(remapped) + affs.extend([active_affs[i]] * len(remapped)) + areas.extend([active_areas[i]] * len(remapped)) + return partners, affs, areas + + def _get_new_edges( edges_info: tuple, sv_ids: np.ndarray, @@ -190,7 +218,6 @@ def _get_new_edges( edges, affinities, areas = edges_info for old, new in old_new_map.items(): - logging.info(f"old and new {old, new}") new_ids = np.array(list(new), dtype=basetypes.NODE_ID) edges_m = np.any(edges == old, axis=1) selected_edges = edges[edges_m] @@ -198,68 +225,41 @@ def _get_new_edges( assert np.all(np.sum(sel_m, axis=1) == 1) partners = selected_edges[sel_m] + edge_affs = affinities[edges_m] + edge_areas = areas[edges_m] active_m = np.isin(partners, sv_ids) - logging.info(f"sv_ids: {np.sum(sv_ids > 0)}") - logging.info(f"edges: {edges.shape} {np.sum(edges_m)} {np.sum(sel_m)}") - logging.info(f"selected_edges: {selected_edges.shape}") + # Inactive partners (different root, outside distance map): all fragments get the 
edge + for k in np.where(~active_m)[0]: + for new_id in new_ids: + new_edges.append(np.array([new_id, partners[k]], dtype=np.uint64)) + new_affs.append(edge_affs[k]) + new_areas.append(edge_areas[k]) - # inactive - for new_id in new_ids: - _a = [[new_id] * np.sum(~active_m), partners[~active_m]] - new_edges.extend(np.array(_a, dtype=np.uint64).T) - new_affs.extend(affinities[edges_m][np.any(sel_m, axis=1)][~active_m]) - new_areas.extend(areas[edges_m][np.any(sel_m, axis=1)][~active_m]) - - # active - active_partners_ = partners[active_m] - active_affs_ = affinities[edges_m][np.any(sel_m, axis=1)][active_m] - active_areas_ = areas[edges_m][np.any(sel_m, axis=1)][active_m] - - logging.info(f"partners: {partners.shape} {active_partners_.shape}") - - active_partners = [] - active_affs = [] - active_areas = [] - for i in range(len(active_partners_)): - remapped_ = old_new_map.get(active_partners_[i], [active_partners_[i]]) - active_partners.extend(remapped_) - active_affs.extend([active_affs_[i]] * len(remapped_)) - active_areas.extend([active_areas_[i]] * len(remapped_)) - - logging.info(f"new_ids, active_partners: {new_ids, len(active_partners)}") - logging.info(f"new_dist_vec(new_ids): {new_dist_vec(new_ids)}") - logging.info(f"dist_vec(active_partners): {dist_vec(active_partners)}") - distances_ = distances[new_dist_vec(new_ids)][:, dist_vec(active_partners)].T - for i, partner in enumerate(active_partners): - aff = active_affs[i] - if np.isinf(aff) and new_id_label_map and partner in new_id_label_map: - e, a, ar = _match_by_label( - new_ids, - partner, - aff, - active_areas[i], - new_id_label_map, - distances_[i], - ) - else: - e, a, ar = _match_by_proximity( - new_ids, - partner, - aff, - active_areas[i], - distances_[i], - threshold, - ) + # Active partners (same root): route based on affinity type + active_partners, act_affs, act_areas = _expand_partners( + partners[active_m], edge_affs[active_m], edge_areas[active_m], old_new_map + ) + new_id_rows = 
new_dist_vec(new_ids) + act_dists = distances[new_id_rows][:, dist_vec(active_partners)].T + for k, partner in enumerate(active_partners): + e, a, ar = _match_partner( + new_ids, + partner, + act_affs[k], + act_areas[k], + act_dists[k], + new_id_label_map, + threshold, + ) new_edges.extend(e) new_affs.extend(a) new_areas.extend(ar) - # edges between split fragments + # Low-affinity edges between split fragments (cuttable by mincut) for i in range(len(new_ids)): - for j in range(i + 1, len(new_ids)): # includes no selfedges - _a = [new_ids[i], new_ids[j]] - new_edges.append(np.array(_a, dtype=np.uint64)) + for j in range(i + 1, len(new_ids)): + new_edges.append(np.array([new_ids[i], new_ids[j]], dtype=np.uint64)) new_affs.append(0.001) new_areas.append(0) @@ -269,11 +269,11 @@ def _get_new_edges( np.array([], dtype=basetypes.EDGE_AFFINITY), np.array([], dtype=basetypes.EDGE_AREA), ) - affinites = np.array(new_affs, dtype=basetypes.EDGE_AFFINITY) - areas = np.array(new_areas, dtype=basetypes.EDGE_AREA) - edges = np.sort(np.array(new_edges, dtype=basetypes.NODE_ID), axis=1) - edges, idx = np.unique(edges, return_index=True, axis=0) - return edges, affinites[idx], areas[idx] + affinities_ = np.array(new_affs, dtype=basetypes.EDGE_AFFINITY) + areas_ = np.array(new_areas, dtype=basetypes.EDGE_AREA) + edges_ = np.sort(np.array(new_edges, dtype=basetypes.NODE_ID), axis=1) + edges_, idx = np.unique(edges_, return_index=True, axis=0) + return edges_, affinities_[idx], areas_[idx] def _update_edges( @@ -304,8 +304,6 @@ def _update_edges( edges = edges[edges_idx] affinities = affinities[edges_idx] areas = areas[edges_idx] - logging.info(f"edges.shape, affinities.shape {edges.shape, affinities.shape}") - new_ids = np.array(list(set.union(*old_new_map.values())), dtype=basetypes.NODE_ID) new_kdtrees = [kdtrees[k] for k in new_ids] new_disance_map = dict(zip(new_ids, np.arange(len(new_ids)))) @@ -360,7 +358,7 @@ def split_supervoxel( source_coords: np.ndarray, sink_coords: 
np.ndarray, operation_id: int, - verbose: bool = True, + verbose: bool = False, time_stamp: datetime = None, ) -> dict[int, set]: """ @@ -386,15 +384,13 @@ def split_supervoxel( cut_supervoxels = _get_whole_sv(cg, sv_id, min_coord=chunk_min, max_coord=chunk_max) supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) - logging.info(f"whole sv {sv_id} -> {cut_supervoxels}") + logging.info(f"whole sv {sv_id} -> {supervoxel_ids.tolist()}") # one voxel overlap for neighbors bbs_ = np.clip(bbs - 1, vol_start, vol_end) bbe_ = np.clip(bbe + 1, vol_start, vol_end) seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() binary_seg = np.isin(seg, supervoxel_ids) - logging.info(f"{seg.shape}; {binary_seg.shape}; {bbs, bbe}; {bbs_, bbe_}") - voxel_overlap_crop = _voxel_crop(bbs, bbe, bbs_, bbe_) split_result = split_supervoxel_helper( binary_seg[voxel_overlap_crop], @@ -418,25 +414,22 @@ def split_supervoxel( results, seg_cropped, bbs, bbe ) - seg_roots = seg.copy() sv_ids = fastremap.unique(seg) roots = cg.get_roots(sv_ids) - seg_roots = fastremap.remap(seg_roots, dict(zip(sv_ids, roots)), in_place=True) + sv_root_map = dict(zip(sv_ids, roots)) + root = sv_root_map[sv_id] + logging.info(f"{sv_id} -> {root}") - root = cg.get_root(sv_id) - logging.info(f"{sv_id} root = {root}") - - seg_masked = seg.copy() - seg_masked[seg_roots != root] = 0 - sv_ids = fastremap.unique(seg_masked) - - seg_masked[voxel_overlap_crop] = new_seg + root_mask = fastremap.remap(seg, sv_root_map, in_place=False) == root + seg[~root_mask] = 0 + sv_ids = fastremap.unique(seg) + seg[voxel_overlap_crop] = new_seg edges_tuple = _update_edges( cg, sv_ids, root, np.array([bbs, bbe]), - seg_masked, + seg, old_new_map, new_id_label_map, ) @@ -502,11 +495,9 @@ def copy_parents_and_add_lineage( for parent, children_cells in children_cells_map.items(): assert len(children_cells) == 1, children_cells for cell in children_cells: - logging.info(f"{parent}: {cell.value}") mask = 
np.isin(cell.value, list(old_new_map.keys())) replace = np.concatenate([old_new_map[x] for x in cell.value[mask]]) children = np.concatenate([cell.value[~mask], replace]) - logging.info(f"{parent}: {children}") cg.cache.children_cache[parent] = children result.append( cg.client.mutate_row( diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index 0b5cf5c5c..e61356a1e 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -153,7 +153,6 @@ def get_parents_at_timestamp(nodes, parents_ts_map, time_stamp, unique: bool = F return list(parents), skipped_nodes - def get_local_segmentation(meta, bbox_start, bbox_end) -> np.ndarray: result = None xL, yL, zL = bbox_start @@ -163,3 +162,12 @@ def get_local_segmentation(meta, bbox_start, bbox_end) -> np.ndarray: else: result = meta.cv[xL:xH, yL:yH, zL:zH] return result + + +def lookup_svs_from_seg(meta, coordinates): + """Read SV IDs directly from OCDBT segmentation at given coordinates.""" + bbox_start = np.min(coordinates, axis=0) + bbox_end = np.max(coordinates, axis=0) + 1 + seg = get_local_segmentation(meta, bbox_start, bbox_end)[..., 0] + local_coords = coordinates - bbox_start + return np.array([seg[tuple(c)] for c in local_coords], dtype=np.uint64) diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index 43faf2160..7f7d8f927 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -10,7 +10,7 @@ import numpy as np from pychunkedgraph.graph import basetypes -from .generic import get_local_segmentation +from .generic import get_local_segmentation, lookup_svs_from_seg from ..meta import ChunkedGraphMeta from ..chunks import utils as chunk_utils @@ -128,7 +128,10 @@ def get_atomic_ids_from_coords( """ import fastremap - if parent_id_layer == 1: + if parent_id_layer == 1 and meta.ocdbt_seg: + return lookup_svs_from_seg(meta, coordinates) + + if 
parent_id_layer == 1 and not meta.ocdbt_seg: return np.array([parent_id] * len(coordinates), dtype=np.uint64) coordinates_nm = coordinates * np.array(meta.resolution) diff --git a/pychunkedgraph/ingest/upgrade/atomic_layer.py b/pychunkedgraph/ingest/upgrade/atomic_layer.py index 81122c5a8..c767ca124 100644 --- a/pychunkedgraph/ingest/upgrade/atomic_layer.py +++ b/pychunkedgraph/ingest/upgrade/atomic_layer.py @@ -146,6 +146,12 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int]): fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN) return + # Set max node ID for the L1 chunk (needed for SV splitting to create new IDs) + l1_chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) + if CHILDREN: + all_svs = np.concatenate(list(CHILDREN.values())) + cg.id_client.set_max_node_id(l1_chunk_id, np.max(all_svs)) + cg.copy_fake_edges(chunk_id) if len(nodes) == 0: return diff --git a/pychunkedgraph/tests/graph/test_utils_id_helpers.py b/pychunkedgraph/tests/graph/test_utils_id_helpers.py index ab4afa60d..c347b9bbd 100644 --- a/pychunkedgraph/tests/graph/test_utils_id_helpers.py +++ b/pychunkedgraph/tests/graph/test_utils_id_helpers.py @@ -153,6 +153,7 @@ def test_layer1_returns_parent_id(self): meta = MagicMock() meta.data_source.CV_MIP = 0 meta.resolution = np.array([1, 1, 1]) + meta.ocdbt_seg = False parent_id = np.uint64(42) coordinates = np.array([[10, 20, 30], [40, 50, 60]]) diff --git a/requirements.in b/requirements.in index 2d8112537..c6d241ff1 100644 --- a/requirements.in +++ b/requirements.in @@ -26,7 +26,7 @@ middle-auth-client>=3.11.0 zmesh>=1.7.0 fastremap>=1.14.0 task-queue>=2.14.0 -messagingclient +messagingclient>0.3.0 dracopy>=1.5.0 datastoreflex>=0.5.0 kvdbclient>=0.4.0 diff --git a/requirements.txt b/requirements.txt index 5df78e9f8..f5f8872df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -206,7 +206,7 @@ markupsafe==3.0.3 # flask # jinja2 # werkzeug -messagingclient==0.3.0 +messagingclient==0.4.0 # via -r requirements.in microviewer==1.20.0 # via 
cloud-volume diff --git a/uwsgi.ini b/uwsgi.ini index 776e2ff00..9440db38e 100644 --- a/uwsgi.ini +++ b/uwsgi.ini @@ -82,32 +82,8 @@ harakiri-verbose = true ### Logging -# Filter our properly pre-formated app messages and pass them through logger = app stdio -log-route = app ^{.*"source":.*}$ - -# Capture known / most common uWSGI messages -logger = uWSGIdebug stdio -logger = uWSGIwarn stdio - -log-route = uWSGIdebug ^{address space usage -log-route = uWSGIwarn \[warn\] - -log-encoder = json:uWSGIdebug {"source":"uWSGI","time":"${strftime:%Y-%m-%dT%H:%M:%S.000Z}","severity":"debug","message":"${msg}"} -log-encoder = nl:uWSGIdebug -log-encoder = json:uWSGIwarn {"source":"uWSGI","time":"${strftime:%Y-%m-%dT%H:%M:%S.000Z}","severity":"warning","message":"${msg}"} -log-encoder = nl:uWSGIwarn - -# Treat everything else as error message of unknown origin -logger = unknown stdio - -# Creating our own "inverse Regex" using negative lookaheads, which makes this -# log-route rather cryptic and slow... Unclear how to get a simple -# "fall-through" behavior for non-matching messages, otherwise. 
-log-route = unknown ^(?:(?!^{address space usage|\[warn\]|^{.*"source".*}$).)*$ - -log-encoder = json:unknown {"source":"unknown","time":"${strftime:%Y-%m-%dT%H:%M:%S.000Z}","severity":"error","message":"${msg}"} -log-encoder = nl:unknown +log-route = app .* log-4xx = true log-5xx = true From e7db53d3c32e13bf3edc1f6b4cf9c359118e54c8 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Fri, 20 Mar 2026 16:56:58 +0000 Subject: [PATCH 15/16] remove multiwrapper usage and references --- pychunkedgraph/graph/subgraph.py | 28 +++-- pychunkedgraph/meshing/mesh_io.py | 71 ++++++----- pychunkedgraph/meshing/meshengine.py | 155 ++++++++++++++---------- pychunkedgraph/meshing/meshgen.py | 20 ++- pychunkedgraph/meshing/meshlabserver.py | 34 +++--- requirements.in | 1 - requirements.txt | 3 - 7 files changed, 183 insertions(+), 129 deletions(-) diff --git a/pychunkedgraph/graph/subgraph.py b/pychunkedgraph/graph/subgraph.py index 1538b3cc2..4f21f2489 100644 --- a/pychunkedgraph/graph/subgraph.py +++ b/pychunkedgraph/graph/subgraph.py @@ -186,8 +186,8 @@ def _get_subgraph_multiple_nodes( return_flattened: bool = False, ): from collections import ChainMap - from multiwrapper.multiprocessing_utils import n_cpus - from multiwrapper.multiprocessing_utils import multithread_func + import os + from concurrent.futures import ThreadPoolExecutor from .utils.generic import mask_nodes_by_bounding_box @@ -223,20 +223,26 @@ def _get_subgraph_multiple_nodes_threaded( subgraph = SubgraphProgress(cg.meta, node_ids, return_layers, serializable) while not subgraph.done_processing(): - this_n_threads = min([int(len(subgraph.cur_nodes) // 50000) + 1, n_cpus]) - cur_nodes_child_maps = multithread_func( - _get_subgraph_multiple_nodes_threaded, - np.array_split(subgraph.cur_nodes, this_n_threads), - n_threads=this_n_threads, - debug=this_n_threads == 1, + this_n_threads = min( + [int(len(subgraph.cur_nodes) // 50000) + 1, os.cpu_count()] ) + batches = np.array_split(subgraph.cur_nodes, 
this_n_threads) + if this_n_threads == 1: + cur_nodes_child_maps = [ + _get_subgraph_multiple_nodes_threaded(b) for b in batches + ] + else: + with ThreadPoolExecutor(max_workers=this_n_threads) as executor: + cur_nodes_child_maps = list( + executor.map(_get_subgraph_multiple_nodes_threaded, batches) + ) cur_nodes_children = dict(ChainMap(*cur_nodes_child_maps)) subgraph.process_batch_of_children(cur_nodes_children) if return_flattened and len(return_layers) == 1: for node_id in node_ids: - subgraph.node_to_subgraph[ - _get_dict_key(node_id) - ] = subgraph.node_to_subgraph[_get_dict_key(node_id)][return_layers[0]] + subgraph.node_to_subgraph[_get_dict_key(node_id)] = ( + subgraph.node_to_subgraph[_get_dict_key(node_id)][return_layers[0]] + ) return subgraph.node_to_subgraph diff --git a/pychunkedgraph/meshing/mesh_io.py b/pychunkedgraph/meshing/mesh_io.py index 1cf1fed66..4a6eac7c4 100644 --- a/pychunkedgraph/meshing/mesh_io.py +++ b/pychunkedgraph/meshing/mesh_io.py @@ -7,19 +7,23 @@ import networkx as nx import cloudvolume -from multiwrapper import multiprocessing_utils as mu +from concurrent.futures import ProcessPoolExecutor + def read_mesh_h5(): pass + def write_mesh_h5(): pass + def read_obj(path): return Mesh(path) + def _download_meshes_thread(args): - """ Downloads meshes into target directory + """Downloads meshes into target directory :param args: list """ @@ -33,7 +37,7 @@ def _download_meshes_thread(args): def download_meshes(seg_ids, target_dir, cv_path, n_threads=1): - """ Downloads meshes in target directory (parallel) + """Downloads meshes in target directory (parallel) :param seg_ids: list of ints :param target_dir: str @@ -52,12 +56,11 @@ def download_meshes(seg_ids, target_dir, cv_path, n_threads=1): multi_args.append([seg_id_block, cv_path, target_dir]) if n_jobs == 1: - mu.multiprocess_func(_download_meshes_thread, - multi_args, debug=True, - verbose=True, n_threads=1) + for args in multi_args: + _download_meshes_thread(args) else: - 
mu.multisubprocess_func(_download_meshes_thread, - multi_args, n_threads=n_threads) + with ProcessPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_download_meshes_thread, multi_args)) def refine_mesh(): @@ -77,6 +80,7 @@ def mesh(self, filename): return self.filename_dict[filename] + class Mesh(object): def __init__(self, filename): self._vertices = [] @@ -117,8 +121,9 @@ def normals(self): @property def edges(self): if self._edges is None: - self._edges = np.concatenate([self.faces[:, :2], - self.faces[:, 1:3]], axis=0) + self._edges = np.concatenate( + [self.faces[:, :2], self.faces[:, 1:3]], axis=0 + ) return self._edges @property @@ -141,21 +146,23 @@ def load_obj(self): normals = [] for line in open(self.filename, "r"): - if line.startswith('#'): continue + if line.startswith("#"): + continue values = line.split() - if not values: continue - if values[0] == 'v': + if not values: + continue + if values[0] == "v": v = values[1:4] vertices.append(v) - elif values[0] == 'vn': + elif values[0] == "vn": v = map(float, values[1:4]) normals.append(v) - elif values[0] == 'f': + elif values[0] == "f": face = [] texcoords = [] norms = [] for v in values[1:]: - w = v.split('/') + w = v.split("/") face.append(int(w[0])) if len(w) >= 2 and len(w[1]) > 0: texcoords.append(int(w[1])) @@ -191,7 +198,8 @@ def write_vertices_ply(self, out_fname, coords=None): tweaked_array = np.array( list(zip(coords[:, 0], coords[:, 1], coords[:, 2])), - dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")], + ) vertex_element = plyfile.PlyElement.describe(tweaked_array, "vertex") @@ -200,8 +208,15 @@ def write_vertices_ply(self, out_fname, coords=None): plyfile.PlyData([vertex_element]).write(out_fname) - def get_local_view(self, n_points, pc_align=False, center_node_id=None, - center_coord=None, method="kdtree", verbose=False): + def get_local_view( + self, + n_points, + pc_align=False, + center_node_id=None, + center_coord=None, + 
method="kdtree", + verbose=False, + ): if center_node_id is None and center_coord is None: center_node_id = np.random.randint(len(self.vertices)) @@ -215,11 +230,11 @@ def get_local_view(self, n_points, pc_align=False, center_node_id=None, if verbose: print(np.mean(dists), np.max(dists), np.min(dists)) elif method == "graph": - dist_dict = nx.single_source_dijkstra_path_length(self.graph, - center_node_id, - weight="weight") - sorting = np.argsort(np.array(list(dist_dict.values()))) - node_ids = np.array(list(dist_dict.keys()))[sorting[:n_points]] + dist_dict = nx.single_source_dijkstra_path_length( + self.graph, center_node_id, weight="weight" + ) + sorting = np.argsort(np.array(list(dist_dict.values()))) + node_ids = np.array(list(dist_dict.keys()))[sorting[:n_points]] else: raise Exception("unknow method") @@ -236,7 +251,9 @@ def calc_pc_align(self, vertices): return pca.transform(vertices) def create_nx_graph(self): - weights = np.linalg.norm(self.vertices[self.edges[:, 0]] - self.vertices[self.edges[:, 1]], axis=1) + weights = np.linalg.norm( + self.vertices[self.edges[:, 0]] - self.vertices[self.edges[:, 1]], axis=1 + ) print(weights.shape) @@ -244,8 +261,6 @@ def create_nx_graph(self): weighted_graph.add_edges_from(self.edges) for i_edge, edge in enumerate(self.edges): - weighted_graph[edge[0]][edge[1]]['weight'] = weights[i_edge] + weighted_graph[edge[0]][edge[1]]["weight"] = weights[i_edge] return weighted_graph - - diff --git a/pychunkedgraph/meshing/meshengine.py b/pychunkedgraph/meshing/meshengine.py index e852dfa3a..3f86fd7b3 100644 --- a/pychunkedgraph/meshing/meshengine.py +++ b/pychunkedgraph/meshing/meshengine.py @@ -3,19 +3,21 @@ import itertools import random +from concurrent.futures import ProcessPoolExecutor from pychunkedgraph.graph import chunkedgraph -from multiwrapper import multiprocessing_utils as mu from . 
import meshgen class MeshEngine(object): - def __init__(self, - table_id: str, - instance_id: str = "pychunkedgraph", - project_id: str = "neuromancer-seung-import", - mesh_mip: int = 3, - highest_mesh_layer: int = 5): + def __init__( + self, + table_id: str, + instance_id: str = "pychunkedgraph", + project_id: str = "neuromancer-seung-import", + mesh_mip: int = 3, + highest_mesh_layer: int = 5, + ): self._table_id = table_id self._instance_id = instance_id @@ -62,7 +64,8 @@ def cg(self): self._cg = chunkedgraph.ChunkedGraph( table_id=self.table_id, instance_id=self.instance_id, - project_id=self.project_id) + project_id=self.project_id, + ) return self._cg @property @@ -80,8 +83,9 @@ def cv(self): self._cv.info["mesh"] = self.cv_mesh_dir return self._cv - def mesh_multiple_layers(self, layers=None, bounding_box=None, - block_factor=2, n_threads=128): + def mesh_multiple_layers( + self, layers=None, bounding_box=None, block_factor=2, n_threads=128 + ): if layers is None: layers = range(1, int(self.cg.n_layers + 1)) @@ -94,28 +98,30 @@ def mesh_multiple_layers(self, layers=None, bounding_box=None, for layer in layers: print("Now: layer %d" % layer) - self.mesh_single_layer(layer, bounding_box=bounding_box, - block_factor=block_factor, - n_threads=n_threads) - - def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, - n_threads=128): + self.mesh_single_layer( + layer, + bounding_box=bounding_box, + block_factor=block_factor, + n_threads=n_threads, + ) + + def mesh_single_layer( + self, layer, bounding_box=None, block_factor=2, n_threads=128 + ): assert layer <= self.highest_mesh_layer dataset_bounding_box = np.array(self.cv.bounds.to_list()) - block_bounding_box_cg = \ - [np.floor(dataset_bounding_box[:3] / - self.cg.chunk_size).astype(int), - np.ceil(dataset_bounding_box[3:] / - self.cg.chunk_size).astype(int)] + block_bounding_box_cg = [ + np.floor(dataset_bounding_box[:3] / self.cg.chunk_size).astype(int), + np.ceil(dataset_bounding_box[3:] / 
self.cg.chunk_size).astype(int), + ] if bounding_box is not None: - bounding_box_cg = \ - [np.floor(bounding_box[0] / - self.cg.chunk_size).astype(int), - np.ceil(bounding_box[1] / - self.cg.chunk_size).astype(int)] + bounding_box_cg = [ + np.floor(bounding_box[0] / self.cg.chunk_size).astype(int), + np.ceil(bounding_box[1] / self.cg.chunk_size).astype(int), + ] m = block_bounding_box_cg[0] < bounding_box_cg[0] block_bounding_box_cg[0][m] = bounding_box_cg[0][m] @@ -126,31 +132,37 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, block_bounding_box_cg /= 2 ** np.max([0, layer - 2]) block_bounding_box_cg = np.ceil(block_bounding_box_cg) - n_jobs = np.prod(block_bounding_box_cg[1] - - block_bounding_box_cg[0]) / \ - block_factor ** 2 < n_threads + n_jobs = ( + np.prod(block_bounding_box_cg[1] - block_bounding_box_cg[0]) + / block_factor**2 + < n_threads + ) while n_jobs < n_threads and block_factor > 1: block_factor -= 1 - n_jobs = np.prod(block_bounding_box_cg[1] - - block_bounding_box_cg[0]) / \ - block_factor ** 2 < n_threads - - block_iter = itertools.product(np.arange(block_bounding_box_cg[0][0], - block_bounding_box_cg[1][0], - block_factor), - np.arange(block_bounding_box_cg[0][1], - block_bounding_box_cg[1][1], - block_factor), - np.arange(block_bounding_box_cg[0][2], - block_bounding_box_cg[1][2], - block_factor)) + n_jobs = ( + np.prod(block_bounding_box_cg[1] - block_bounding_box_cg[0]) + / block_factor**2 + < n_threads + ) + + block_iter = itertools.product( + np.arange( + block_bounding_box_cg[0][0], block_bounding_box_cg[1][0], block_factor + ), + np.arange( + block_bounding_box_cg[0][1], block_bounding_box_cg[1][1], block_factor + ), + np.arange( + block_bounding_box_cg[0][2], block_bounding_box_cg[1][2], block_factor + ), + ) blocks = np.array(list(block_iter), dtype=int) cg_info = self.cg.get_serialized_info() - del (cg_info['credentials']) + del cg_info["credentials"] multi_args = [] for start_block in blocks: @@ -158,44 
+170,57 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, m = end_block > block_bounding_box_cg[1] end_block[m] = block_bounding_box_cg[1][m] - multi_args.append([cg_info, start_block, end_block, self.cg._cv_path, - self.cv_mesh_dir, self.mesh_mip, layer]) + multi_args.append( + [ + cg_info, + start_block, + end_block, + self.cg._cv_path, + self.cv_mesh_dir, + self.mesh_mip, + layer, + ] + ) random.shuffle(multi_args) random.shuffle(multi_args) # Run parallelizing if n_threads == 1: - mu.multiprocess_func(meshgen._mesh_layer_thread, multi_args, - n_threads=n_threads, verbose=True, - debug=n_threads == 1) + for args in multi_args: + meshgen._mesh_layer_thread(args) else: - mu.multisubprocess_func(meshgen._mesh_layer_thread, multi_args, - n_threads=n_threads, - suffix="%s_%d" % (self.table_id, layer)) + with ProcessPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(meshgen._mesh_layer_thread, multi_args)) def create_manifests_for_higher_layers(self, n_threads=1): root_id_max = self.cg.get_max_node_id( - self.cg.get_chunk_id(layer=int(self.cg.n_layers), - x=int(0), y=int(0), - z=int(0))) + self.cg.get_chunk_id( + layer=int(self.cg.n_layers), x=int(0), y=int(0), z=int(0) + ) + ) - root_id_blocks = np.linspace(1, root_id_max, n_threads*3).astype(int) + root_id_blocks = np.linspace(1, root_id_max, n_threads * 3).astype(int) cg_info = self.cg.get_serialized_info() - del (cg_info['credentials']) + del cg_info["credentials"] multi_args = [] for i_block in range(len(root_id_blocks) - 1): - multi_args.append([cg_info, self.cv_path, self.cv_mesh_dir, - root_id_blocks[i_block], - root_id_blocks[i_block + 1], - self.highest_mesh_layer]) + multi_args.append( + [ + cg_info, + self.cv_path, + self.cv_mesh_dir, + root_id_blocks[i_block], + root_id_blocks[i_block + 1], + self.highest_mesh_layer, + ] + ) # Run parallelizing if n_threads == 1: - mu.multiprocess_func(meshgen._create_manifest_files_thread, - multi_args, n_threads=n_threads, 
verbose=True, - debug=n_threads == 1) + for args in multi_args: + meshgen._create_manifest_files_thread(args) else: - mu.multisubprocess_func(meshgen._create_manifest_files_thread, - multi_args, n_threads=n_threads) + with ProcessPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(meshgen._create_manifest_files_thread, multi_args)) diff --git a/pychunkedgraph/meshing/meshgen.py b/pychunkedgraph/meshing/meshgen.py index f6613f7d2..ea6410dfd 100644 --- a/pychunkedgraph/meshing/meshgen.py +++ b/pychunkedgraph/meshing/meshgen.py @@ -10,7 +10,7 @@ import pytz from scipy import ndimage -from multiwrapper import multiprocessing_utils as mu +from concurrent.futures import ThreadPoolExecutor from cloudfiles import CloudFiles from cloudvolume import CloudVolume from cloudvolume.datasource.precomputed.sharding import ShardingSpecification @@ -23,7 +23,6 @@ from pychunkedgraph.meshing import meshgen_utils # noqa from pychunkedgraph.meshing.manifest.cache import ManifestCache - UTC = pytz.UTC # Change below to true if debugging and want to see results in stdout @@ -263,7 +262,12 @@ def _get_root_ids(args): multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) if n_jobs > 0: - mu.multithread_func(_get_root_ids, multi_args, n_threads=n_threads) + if n_threads == 1: + for args in multi_args: + _get_root_ids(args) + else: + with ThreadPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_get_root_ids, multi_args)) return lx_ids, np.array(root_ids), lx_id_remap @@ -443,7 +447,12 @@ def _get_root_ids(args): multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) if n_jobs > 0: - mu.multithread_func(_get_root_ids, multi_args, n_threads=n_threads) + if n_threads == 1: + for args in multi_args: + _get_root_ids(args) + else: + with ThreadPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_get_root_ids, multi_args)) sv_ids_index = len(node_ids) chunk_ids_index = len(node_ids) + len(sv_ids) @@ -1040,7 +1049,8 @@ def 
get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=F fragment.value for child_fragments_for_node in node_rows for fragment in child_fragments_for_node - ], dtype=object + ], + dtype=object, ) # Filter out node ids that do not have roots (caused by failed ingest tasks) root_ids = cg.get_roots(node_ids, fail_to_zero=True) diff --git a/pychunkedgraph/meshing/meshlabserver.py b/pychunkedgraph/meshing/meshlabserver.py index 3065d6707..65c7439db 100644 --- a/pychunkedgraph/meshing/meshlabserver.py +++ b/pychunkedgraph/meshing/meshlabserver.py @@ -3,7 +3,7 @@ import glob import numpy as np -from multiwrapper import multiprocessing_utils as mu +from concurrent.futures import ProcessPoolExecutor HOME = os.path.expanduser("~") @@ -12,17 +12,19 @@ def run_meshlab_script(script_name, arg_dict): - """ Runs meshlabserver script --headless + """Runs meshlabserver script --headless No X-Server required :param script_name: str :param arg_dict: dict [str: str] """ - arg_string = "".join(["-{0} {1} ".format(k, arg_dict[k]) - for k in arg_dict.keys()]) - command = "xvfb-run --auto-servernum --server-num=1 meshlabserver -s {0}/{1} {2}".\ - format(path_to_scripts, script_name, arg_string) + arg_string = "".join(["-{0} {1} ".format(k, arg_dict[k]) for k in arg_dict.keys()]) + command = ( + "xvfb-run --auto-servernum --server-num=1 meshlabserver -s {0}/{1} {2}".format( + path_to_scripts, script_name, arg_string + ) + ) p = subprocess.Popen(command, shell=True, stderr=subprocess.PIPE) p.wait() @@ -31,8 +33,9 @@ def _run_meshlab_script_on_dir_thread(args): script_name, path_block, out_dir, suffix, arg_dict = args for path in path_block: - out_path = "{}/{}{}.obj".format(out_dir, - "".join(os.path.basename(path).split(".")[:-1]), suffix) + out_path = "{}/{}{}.obj".format( + out_dir, "".join(os.path.basename(path).split(".")[:-1]), suffix + ) this_arg_dict = {"i": path, "o": out_path} this_arg_dict.update(arg_dict) @@ -40,8 +43,9 @@ def 
_run_meshlab_script_on_dir_thread(args): run_meshlab_script(script_name, this_arg_dict) -def run_meshlab_script_on_dir(script_name, in_dir, out_dir, suffix, arg_dict={}, - n_threads=1): +def run_meshlab_script_on_dir( + script_name, in_dir, out_dir, suffix, arg_dict={}, n_threads=1 +): paths = glob.glob(in_dir + "/*.obj") print(len(paths)) @@ -60,10 +64,8 @@ def run_meshlab_script_on_dir(script_name, in_dir, out_dir, suffix, arg_dict={}, multi_args.append([script_name, path_block, out_dir, suffix, arg_dict]) if n_threads == 1: - mu.multiprocess_func(_run_meshlab_script_on_dir_thread, - multi_args, debug=True, - verbose=True, n_threads=1) + for args in multi_args: + _run_meshlab_script_on_dir_thread(args) else: - mu.multisubprocess_func(_run_meshlab_script_on_dir_thread, - multi_args, n_threads=n_threads) - + with ProcessPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_run_meshlab_script_on_dir_thread, multi_args)) diff --git a/requirements.in b/requirements.in index c6d241ff1..3ca5513c5 100644 --- a/requirements.in +++ b/requirements.in @@ -21,7 +21,6 @@ scikit-image # PyPI only: cloud-files>=6.0.0 cloud-volume>=12.0.0 -multiwrapper middle-auth-client>=3.11.0 zmesh>=1.7.0 fastremap>=1.14.0 diff --git a/requirements.txt b/requirements.txt index f5f8872df..33a82701c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -216,8 +216,6 @@ ml-dtypes==0.5.4 # via tensorstore multiprocess==0.70.19 # via pathos -multiwrapper==0.1.1 - # via -r requirements.in networkx==3.6.1 # via # -r requirements.in @@ -238,7 +236,6 @@ numpy==2.4.2 # messagingclient # microviewer # ml-dtypes - # multiwrapper # osteoid # pandas # scikit-image From 048ec916b317ba6185f19e9f77c53b587c908b6a Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Fri, 20 Mar 2026 21:47:41 +0000 Subject: [PATCH 16/16] use a common logger, cleanup old mess --- pychunkedgraph/__init__.py | 28 +++++++++++++++-- pychunkedgraph/app/__init__.py | 7 +++-- pychunkedgraph/graph/edges/stale.py | 9 
++++-- pychunkedgraph/graph/edits.py | 5 ++-- pychunkedgraph/graph/edits_sv.py | 27 +++++++++-------- pychunkedgraph/graph/locks.py | 6 ++-- pychunkedgraph/graph/operation.py | 4 +-- pychunkedgraph/ingest/__init__.py | 5 ++-- pychunkedgraph/ingest/cli.py | 5 ++-- pychunkedgraph/ingest/cli_upgrade.py | 7 +++-- pychunkedgraph/ingest/cluster.py | 19 +++++++----- pychunkedgraph/ingest/upgrade/atomic_layer.py | 14 +++++---- pychunkedgraph/ingest/upgrade/parent_layer.py | 30 +++++++++++-------- pychunkedgraph/ingest/utils.py | 11 ++++--- 14 files changed, 113 insertions(+), 64 deletions(-) diff --git a/pychunkedgraph/__init__.py b/pychunkedgraph/__init__.py index 28c0d26dc..0ade7b18a 100644 --- a/pychunkedgraph/__init__.py +++ b/pychunkedgraph/__init__.py @@ -9,6 +9,26 @@ "ignore", message="Schema id not specified", module="python_jsonschema_objects" ) +# Custom log level between INFO (20) and WARNING (30) +# Use logger.notice() for pychunkedgraph logs that should always show +# even when third-party INFO is suppressed +NOTICE = 25 +stdlib_logging.addLevelName(NOTICE, "NOTICE") + + +class PCGLogger(stdlib_logging.Logger): + def note(self, message, *args, **kwargs): + if self.isEnabledFor(NOTICE): + self._log(NOTICE, message, args, stacklevel=2, **kwargs) + + +stdlib_logging.setLoggerClass(PCGLogger) + + +def get_logger(name: str) -> PCGLogger: + return stdlib_logging.getLogger(name) # type: ignore[return-value] + + # Export logging levels for convenience DEBUG = stdlib_logging.DEBUG INFO = stdlib_logging.INFO @@ -36,7 +56,7 @@ def configure_logging(level=stdlib_logging.INFO, format_str=None, stream=None): pychunkedgraph.configure_logging(pychunkedgraph.DEBUG) # Enable DEBUG level """ if format_str is None: - format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + format_str = "%(asctime)s %(module)s:%(funcName)s:%(lineno)d %(message)s" if stream is None: stream = sys.stdout @@ -54,10 +74,12 @@ def configure_logging(level=stdlib_logging.INFO, 
format_str=None, stream=None): handler = stdlib_logging.StreamHandler(stream) handler.setLevel(level) - handler.setFormatter(stdlib_logging.Formatter(format_str)) + formatter = stdlib_logging.Formatter(format_str) + formatter.default_msec_format = "%s.%03d" + handler.setFormatter(formatter) logger.addHandler(handler) return logger -configure_logging() +configure_logging(level=NOTICE) diff --git a/pychunkedgraph/app/__init__.py b/pychunkedgraph/app/__init__.py index 042fa7ff1..7f5e307e8 100644 --- a/pychunkedgraph/app/__init__.py +++ b/pychunkedgraph/app/__init__.py @@ -14,6 +14,7 @@ from flask_cors import CORS from rq import Queue +from pychunkedgraph import NOTICE, configure_logging from pychunkedgraph.logging import jsonformatter from . import config @@ -99,9 +100,9 @@ def configure_app(app): app.logger.setLevel(app.config["LOGGING_LEVEL"]) app.logger.propagate = False - # Also configure root logger so logging.info() calls in library code are captured - logging.root.addHandler(handler) - logging.root.setLevel(logging.INFO) + # Ensure pychunkedgraph logger always works at NOTICE level + # regardless of app config or environment log level + configure_logging(level=NOTICE) if app.config["USE_REDIS_JOBS"]: app.redis = redis.Redis.from_url(app.config["REDIS_URL"]) diff --git a/pychunkedgraph/graph/edges/stale.py b/pychunkedgraph/graph/edges/stale.py index 17ded90d0..6ff3b8a12 100644 --- a/pychunkedgraph/graph/edges/stale.py +++ b/pychunkedgraph/graph/edges/stale.py @@ -3,8 +3,11 @@ """ import datetime -import logging from os import environ + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) from typing import Iterable import numpy as np @@ -428,7 +431,7 @@ def run(self): ) if _new_edges.size: break - logging.info(f"{_edge}, expanding search with padding {pad+1}.") + logger.note(f"{_edge}, expanding search with padding {pad+1}.") assert ( _new_edges.size ), f"No new edge found {_edge}; {edge_layer}, {self.parent_ts}" @@ -490,7 +493,7 @@ def 
get_latest_edges_wrapper( stale_edge_layers, parent_ts=parent_ts, ) - logging.debug(f"{stale_edges} -> {latest_edges}; {parent_ts}") + logger.debug(f"{stale_edges} -> {latest_edges}; {parent_ts}") _new_cx_edges.append(latest_edges) new_cx_edges_d[layer] = np.concatenate(_new_cx_edges) nodes.append(np.unique(new_cx_edges_d[layer])) diff --git a/pychunkedgraph/graph/edits.py b/pychunkedgraph/graph/edits.py index b29675661..779743740 100644 --- a/pychunkedgraph/graph/edits.py +++ b/pychunkedgraph/graph/edits.py @@ -1,6 +1,6 @@ # pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member -import datetime, logging, random +import datetime, random from typing import Dict from typing import List from typing import Tuple @@ -11,6 +11,7 @@ import fastremap import numpy as np +from pychunkedgraph import get_logger from pychunkedgraph.debug.profiler import HierarchicalProfiler, get_profiler from . import types @@ -25,7 +26,7 @@ from ..utils.general import in2d from ..debug.utils import sanity_check, sanity_check_single -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None): diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index 2b917c230..7e4ab93b5 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -3,8 +3,11 @@ """ from functools import reduce -import logging import multiprocessing as mp + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) from typing import Callable from datetime import datetime from collections import defaultdict, deque @@ -323,7 +326,7 @@ def _update_edges( def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = None): edges_, affinites_, areas_ = edges_tuple - logging.info(f"new edges: {edges_.shape}") + logger.note(f"new edges: {edges_.shape}") nodes = fastremap.unique(edges_) chunks = 
cg.get_chunk_ids_from_node_ids(cg.get_parents(nodes)) @@ -348,7 +351,7 @@ def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = time_stamp=time_stamp, ) ) - logging.info(f"writing {edges[mask].shape} edges to {chunk_id}") + logger.note(f"writing {edges[mask].shape} edges to {chunk_id}") return rows @@ -376,15 +379,13 @@ def split_supervoxel( bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size - logging.info( - f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}" - ) - logging.info(f"chunk and padding {chunk_size}; {_padding}") - logging.info(f"bbox and chunk min max {(bbs, bbe)}; {(chunk_min, chunk_max)}") + logger.note(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}") + logger.note(f"chunk and padding {chunk_size}; {_padding}") + logger.note(f"bbox and chunk min max {(bbs, bbe)}; {(chunk_min, chunk_max)}") cut_supervoxels = _get_whole_sv(cg, sv_id, min_coord=chunk_min, max_coord=chunk_max) supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) - logging.info(f"whole sv {sv_id} -> {supervoxel_ids.tolist()}") + logger.note(f"whole sv {sv_id} -> {supervoxel_ids.tolist()}") # one voxel overlap for neighbors bbs_ = np.clip(bbs - 1, vol_start, vol_end) @@ -399,14 +400,14 @@ def split_supervoxel( cg.meta.resolution, verbose=verbose, ) - logging.info(f"split_result: {split_result.shape}") + logger.note(f"split_result: {split_result.shape}") chunks_bbox_map = chunks_overlapping_bbox(bbs, bbe, cg.meta.graph_config.CHUNK_SIZE) tasks = [ (cg.graph_id, *item, seg[voxel_overlap_crop], split_result, bbs) for item in chunks_bbox_map.items() ] - logging.info(f"tasks count: {len(tasks)}") + logger.note(f"tasks count: {len(tasks)}") with mp.Pool() as pool: results = [*tqdm(pool.imap_unordered(_update_chunk, tasks), total=len(tasks))] 
seg_cropped = seg[voxel_overlap_crop].copy() @@ -418,7 +419,7 @@ def split_supervoxel( roots = cg.get_roots(sv_ids) sv_root_map = dict(zip(sv_ids, roots)) root = sv_root_map[sv_id] - logging.info(f"{sv_id} -> {root}") + logger.note(f"{sv_id} -> {root}") root_mask = fastremap.remap(seg, sv_root_map, in_place=False) == root seg[~root_mask] = 0 @@ -437,7 +438,7 @@ def split_supervoxel( rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) rows1 = _add_new_edges(cg, edges_tuple, time_stamp=time_stamp) rows = rows0 + rows1 - logging.info(f"{operation_id}: writing {len(rows)} new rows") + logger.note(f"{operation_id}: writing {len(rows)} new rows") cg.meta.ws_ocdbt[slices] = new_seg[..., np.newaxis] cg.client.write(rows) diff --git a/pychunkedgraph/graph/locks.py b/pychunkedgraph/graph/locks.py index f7406922f..47a63dacf 100644 --- a/pychunkedgraph/graph/locks.py +++ b/pychunkedgraph/graph/locks.py @@ -1,5 +1,4 @@ from concurrent.futures import ThreadPoolExecutor, as_completed -import logging from typing import Union from typing import Sequence from collections import defaultdict @@ -7,11 +6,14 @@ import networkx as nx import numpy as np +from pychunkedgraph import get_logger + from . import exceptions from .types import empty_1d from .lineage import lineage_graph -logger = logging.getLogger(__name__) +logger = get_logger(__name__) + class RootLock: """Attempts to lock the requested root IDs using a unique operation ID. 
diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 0d91e3990..4c85bd463 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -1,6 +1,5 @@ # pylint: disable=invalid-name, missing-docstring, too-many-lines, protected-access, broad-exception-raised -import logging from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime @@ -16,8 +15,9 @@ from functools import reduce import numpy as np +from pychunkedgraph import get_logger -logger = logging.getLogger(__name__) +logger = get_logger(__name__) from . import locks from . import edits diff --git a/pychunkedgraph/ingest/__init__.py b/pychunkedgraph/ingest/__init__.py index 55c10ca5f..482dfbb5f 100644 --- a/pychunkedgraph/ingest/__init__.py +++ b/pychunkedgraph/ingest/__init__.py @@ -1,7 +1,8 @@ -import logging from collections import namedtuple -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +from pychunkedgraph import configure_logging, NOTICE + +configure_logging(level=NOTICE) _ingestconfig_fields = ( "AGGLOMERATION", diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index ca958c354..ba63e15f8 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -4,9 +4,10 @@ cli for running ingest """ -import logging import os +from pychunkedgraph import configure_logging, DEBUG + import click import yaml from flask.cli import AppGroup @@ -77,7 +78,7 @@ def ingest_graph( config = yaml.safe_load(stream) if test: - logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG) + configure_logging(level=DEBUG) meta, ingest_config, client_info = bootstrap(graph_id, config, raw, test) cg = ChunkedGraph(meta=meta, client_info=client_info) diff --git a/pychunkedgraph/ingest/cli_upgrade.py b/pychunkedgraph/ingest/cli_upgrade.py index d7b7a56dd..4b5ed12c7 100644 --- a/pychunkedgraph/ingest/cli_upgrade.py +++ 
b/pychunkedgraph/ingest/cli_upgrade.py @@ -4,9 +4,12 @@ cli for running upgrade """ -import logging from time import sleep +from pychunkedgraph import get_logger + +logger = get_logger(__name__) + import click import tensorstore as ts from flask.cli import AppGroup @@ -89,7 +92,7 @@ def upgrade_graph(graph_id: str, test: bool, ocdbt: bool): enqueue_l2_tasks(imanager, fn) if ocdbt: - logging.info("All tasks queued. Keep this alive for ocdbt coordinator server.") + logger.note("All tasks queued. Keep this alive for ocdbt coordinator server.") while True: sleep(60) diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 6233c9d46..2736d6819 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -4,8 +4,11 @@ Ingest / create chunkedgraph with workers on a cluster. """ -import logging from os import environ + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) from time import sleep from typing import Callable, Dict, Iterable, Tuple, Sequence @@ -55,7 +58,7 @@ def _post_task_completion( chunk_str += f"_{split}" # mark chunk as completed - "c" imanager.redis.sadd(f"{layer}c", chunk_str) - logging.info(f"{chunk_str} marked as complete") + logger.note(f"{chunk_str} marked as complete") def create_parent_chunk( @@ -139,9 +142,9 @@ def create_atomic_chunk(coords: Sequence[int]): add_atomic_chunk(imanager.cg, coords, chunk_edges_active, isolated=isolated_ids) for k, v in chunk_edges_all.items(): - logging.debug(f"{k}: {len(v)}") + logger.debug(f"{k}: {len(v)}") for k, v in chunk_edges_active.items(): - logging.debug(f"active_{k}: {len(v)}") + logger.debug(f"active_{k}: {len(v)}") if imanager.ocdbt_seg: src, dst = get_seg_source_and_destination_ocdbt( @@ -196,7 +199,7 @@ def convert_to_ocdbt(coords: Sequence[int]): port = imanager.redis.get("OCDBT_COORDINATOR_PORT").decode() environ["OCDBT_COORDINATOR_HOST"] = host environ["OCDBT_COORDINATOR_PORT"] = port - logging.info(f"OCDBT Coordinator 
address {host}:{port}") + logger.note(f"OCDBT Coordinator address {host}:{port}") put_edges( f"{imanager.cg.meta.data_source.EDGES}/ocdbt", @@ -224,7 +227,7 @@ def _queue_tasks(imanager: IngestionManager, chunk_fn: Callable, coords: Iterabl _coords = get_chunks_not_done(imanager, 2, batch) # buffer for optimal use of redis memory while len(q) > max_queue_size: - logging.info( + logger.note( f"Queue has {len(q)} items (limit {max_queue_size}), waiting..." ) sleep(10) @@ -244,7 +247,7 @@ def _queue_tasks(imanager: IngestionManager, chunk_fn: Callable, coords: Iterabl ) ) q.enqueue_many(job_datas) - logging.info(f"Queued {len(job_datas)} chunks.") + logger.note(f"Queued {len(job_datas)} chunks.") def enqueue_l2_tasks(imanager: IngestionManager, chunk_fn: Callable): @@ -257,5 +260,5 @@ def enqueue_l2_tasks(imanager: IngestionManager, chunk_fn: Callable): atomic_chunk_bounds = imanager.cg_meta.layer_chunk_bounds[2] chunk_coords = randomize_grid_points(*atomic_chunk_bounds) chunk_count = imanager.cg_meta.layer_chunk_counts[0] - logging.info(f"Chunk count: {chunk_count}, queuing...") + logger.note(f"Chunk count: {chunk_count}, queuing...") _queue_tasks(imanager, chunk_fn, chunk_coords) diff --git a/pychunkedgraph/ingest/upgrade/atomic_layer.py b/pychunkedgraph/ingest/upgrade/atomic_layer.py index c767ca124..69463d7f6 100644 --- a/pychunkedgraph/ingest/upgrade/atomic_layer.py +++ b/pychunkedgraph/ingest/upgrade/atomic_layer.py @@ -2,7 +2,11 @@ from collections import defaultdict from datetime import datetime, timedelta, timezone -import logging, time, os +import time, os + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) from copy import copy import fastremap @@ -79,7 +83,7 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list: for partner, parents in zip(all_partners, all_parents): for parent, ts in parents: parents_ts_map[partner][ts] = parent - logging.info(f"update_nodes init {len(nodes)}: {time.time() - start}") + 
logger.note(f"update_nodes init {len(nodes)}: {time.time() - start}") rows = [] skipped = [] @@ -142,7 +146,7 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int]): clean_task = os.environ.get("CLEAN_CHUNKS", "false") == "clean" if clean_task: - logging.info(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") + logger.note(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN) return @@ -156,8 +160,8 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int]): if len(nodes) == 0: return - logging.info(f"processing {len(nodes)} nodes.") + logger.note(f"processing {len(nodes)} nodes.") assert len(CHILDREN) > 0, (nodes, CHILDREN) rows = update_nodes(cg, nodes, nodes_ts) cg.client.write(rows) - logging.info(f"mutations: {len(rows)}, time: {time.time() - start}") + logger.note(f"mutations: {len(rows)}, time: {time.time() - start}") diff --git a/pychunkedgraph/ingest/upgrade/parent_layer.py b/pychunkedgraph/ingest/upgrade/parent_layer.py index 773fc9ed0..fd46917e2 100644 --- a/pychunkedgraph/ingest/upgrade/parent_layer.py +++ b/pychunkedgraph/ingest/upgrade/parent_layer.py @@ -1,7 +1,11 @@ # pylint: disable=invalid-name, missing-docstring, c-extension-no-member from math import ceil -import bisect, logging, random, time, os, gc +import bisect, random, time, os, gc + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) import multiprocessing as mp from collections import defaultdict from datetime import datetime, timezone @@ -54,7 +58,7 @@ def _get_cx_edges_at_timestamp(node, response, ts): try: result[key.index].append(cells[idx].value) except IndexError as e: - logging.error(f"{k}, {idx}, {len(cells)}, {asc_ts}") + logger.error(f"{k}, {idx}, {len(cells)}, {asc_ts}") raise IndexError from e for layer, edges in result.items(): result[layer] = np.concatenate(edges) @@ -84,7 +88,7 @@ def _populate_cx_edges_with_timestamps( response = 
cg.client.read_nodes(node_ids=all_children, properties=attrs) timestamps_d = get_parent_timestamps(cg, nodes) end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN, layer=layer) - logging.info(f"_populate_cx_edges_with_timestamps init: {time.time() - start}") + logger.note(f"_populate_cx_edges_with_timestamps init: {time.time() - start}") start = time.time() partners_map = {} @@ -97,7 +101,7 @@ def _populate_cx_edges_with_timestamps( partners = np.unique(np.concatenate([*partners_map.values()])) partner_parent_ts_d = get_parent_timestamps(cg, partners) - logging.info(f"get partners timestamps init: {time.time() - start}") + logger.note(f"get partners timestamps init: {time.time() - start}") rows = [] for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps): @@ -180,7 +184,7 @@ def _update_cross_edges_helper(args): tasks.append((cg, layer, node, node_ts)) if clean_task: - logging.info(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") + logger.note(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN) return @@ -224,31 +228,31 @@ def update_chunk( nodes = _get_split_nodes(cg, chunk_id, split, splits) _populate_nodes_and_children(cg, chunk_id, nodes=nodes) - logging.info(f"_populate_nodes_and_children: {time.time() - start}") + logger.note(f"_populate_nodes_and_children: {time.time() - start}") nodes = list(CHILDREN.keys()) if len(nodes) == 0: return - logging.info(f"processing {len(nodes)} nodes.") + logger.note(f"processing {len(nodes)} nodes.") random.shuffle(nodes) start = time.time() nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True) - logging.info(f"get_node_timestamps: {time.time() - start}") + logger.note(f"get_node_timestamps: {time.time() - start}") start = time.time() _populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts) - logging.info(f"_populate_cx_edges_with_timestamps: {time.time() - start}") + 
logger.note(f"_populate_cx_edges_with_timestamps: {time.time() - start}") if debug: rows = [] stale.PARENTS_CACHE = LRUCache(PARENT_CACHE_LIMIT) stale.CHILDREN_CACHE = LRUCache(1 * 1024) - logging.info(f"processing {len(nodes)} nodes with 1 worker.") + logger.note(f"processing {len(nodes)} nodes with 1 worker.") for node, node_ts in zip(nodes, nodes_ts): rows.extend(update_cross_edges(cg, layer, node, node_ts)) stale.PARENTS_CACHE.clear() stale.CHILDREN_CACHE.clear() - logging.info(f"total elaspsed time: {time.time() - start}") + logger.note(f"total elaspsed time: {time.time() - start}") return task_size = int(os.environ.get("TASK_SIZE", 1)) @@ -263,7 +267,7 @@ def update_chunk( process_multiplier = int(os.environ.get("PROCESS_MULTIPLIER", 5)) processes = min(mp.cpu_count() * process_multiplier, len(tasks)) - logging.info(f"processing {len(nodes)} nodes with {processes} workers.") + logger.note(f"processing {len(nodes)} nodes with {processes} workers.") with mp.Pool(processes) as pool: _ = list( tqdm( @@ -271,4 +275,4 @@ def update_chunk( total=len(tasks), ) ) - logging.info(f"total elaspsed time: {time.time() - start}") + logger.note(f"total elaspsed time: {time.time() - start}") diff --git a/pychunkedgraph/ingest/utils.py b/pychunkedgraph/ingest/utils.py index c41a41a56..d69756104 100644 --- a/pychunkedgraph/ingest/utils.py +++ b/pychunkedgraph/ingest/utils.py @@ -1,7 +1,10 @@ # pylint: disable=invalid-name, missing-docstring -import logging import functools + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) import math, random, sys from os import environ from time import sleep @@ -99,7 +102,7 @@ def start_ocdbt_server(imanager: IngestionManager, server: Any): imanager.redis.set("OCDBT_COORDINATOR_PORT", str(server.port)) ocdbt_host = environ.get("MY_POD_IP", "localhost") imanager.redis.set("OCDBT_COORDINATOR_HOST", ocdbt_host) - logging.info(f"OCDBT Coordinator address {ocdbt_host}:{server.port}") + logger.note(f"OCDBT Coordinator address 
{ocdbt_host}:{server.port}") def randomize_grid_points(X: int, Y: int, Z: int) -> Generator[int, int, int]: @@ -225,7 +228,7 @@ def queue_layer_helper( _coords = get_chunks_not_done(imanager, parent_layer, batch, splits=splits) # buffer for optimal use of redis memory while len(q) > max_queue_size: - logging.info( + logger.note( f"Queue has {len(q)} items (limit {max_queue_size}), waiting..." ) sleep(10) @@ -261,7 +264,7 @@ def queue_layer_helper( ) ) q.enqueue_many(job_datas) - logging.info(f"Queued {len(job_datas)} chunks.") + logger.note(f"Queued {len(job_datas)} chunks.") def job_type_guard(job_type: str):