Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 78 additions & 40 deletions meshparty/skeleton_quality/multicut.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import networkx as nx
from meshparty.meshwork import Meshwork
from meshparty.trimesh_io import Mesh
import pandas as pd
import numpy as np
import pandas as pd
from scipy import sparse

from meshparty.meshwork import Meshwork
from meshparty.trimesh_io import Mesh


def _build_multicut_graph(nrn):
G = nx.from_scipy_sparse_matrix(nrn.mesh.csgraph)
if nx.__version__ < 3:
G = nx.from_scipy_sparse_matrix(nrn.mesh.csgraph)
else:
G = nx.from_scipy_sparse_array(nrn.mesh.csgraph)

G.add_node('source')
G.add_node('target')
G.add_node("source")
G.add_node("target")

source_edges = [('source', ii, np.inf)
for ii in nrn.anno['st_df'].df.query('type == "s"').mesh_index.values]
target_edges = [('target', ii, np.inf)
for ii in nrn.anno['st_df'].df.query('type == "t"').mesh_index.values]
source_edges = [
("source", ii, np.inf)
for ii in nrn.anno["st_df"].df.query('type == "s"').mesh_index.values
]
target_edges = [
("target", ii, np.inf)
for ii in nrn.anno["st_df"].df.query('type == "t"').mesh_index.values
]

G.add_weighted_edges_from(source_edges)
G.add_weighted_edges_from(target_edges)
Expand All @@ -24,12 +32,14 @@ def _build_multicut_graph(nrn):


def _multicut_partitions(G, nrn):
_, partition = nx.minimum_cut(G, 'source', 'target', capacity='weight')
_, partition = nx.minimum_cut(G, "source", "target", capacity="weight")

part0 = list(partition[0].difference({'source', 'target'}))
part1 = list(partition[1].difference({'source', 'target'}))
part0 = list(partition[0].difference({"source", "target"}))
part1 = list(partition[1].difference({"source", "target"}))

return nrn.MeshIndex(part0).to_mesh_mask_base, nrn.MeshIndex(part1).to_mesh_mask_base
return nrn.MeshIndex(part0).to_mesh_mask_base, nrn.MeshIndex(
part1
).to_mesh_mask_base


def _build_nrn_with_st_annos(mesh, source_points, target_points):
Expand All @@ -38,30 +48,39 @@ def _build_nrn_with_st_annos(mesh, source_points, target_points):
if isinstance(target_points, np.ndarray):
target_points = target_points.tolist()
nrn = Meshwork(mesh, voxel_resolution=[1, 1, 1])
source_df = pd.DataFrame(data={'pt_position': source_points})
source_df['type'] = 's'
target_df = pd.DataFrame(data={'pt_position': target_points})
target_df['type'] = 't'
source_df = pd.DataFrame(data={"pt_position": source_points})
source_df["type"] = "s"
target_df = pd.DataFrame(data={"pt_position": target_points})
target_df["type"] = "t"
st_df = source_df.append(target_df, ignore_index=True)
st_df['pt_position'] = np.vstack(st_df['pt_position'].values).tolist()
st_df["pt_position"] = np.vstack(st_df["pt_position"].values).tolist()

nrn.add_annotations('st_df', st_df, point_column='pt_position', anchored=True, overwrite=True)
nrn.add_annotations(
"st_df", st_df, point_column="pt_position", anchored=True, overwrite=True
)
return nrn


def _build_local_mask(nrn, initial_window):
ds = sparse.csgraph.dijkstra(
nrn.mesh.csgraph, indices=nrn.anno['st_df'].mesh_index, limit=initial_window)
d_sq = ds[:, nrn.anno['st_df'].mesh_index]
nrn.mesh.csgraph, indices=nrn.anno["st_df"].mesh_index, limit=initial_window
)
d_sq = ds[:, nrn.anno["st_df"].mesh_index]
if np.any(np.isinf(d_sq.ravel())):
raise ValueError(
"Initial window is too low (default: 10000) or points are in different components")
"Initial window is too low (default: 10000) or points are in different components"
)

# Centers mask on the point with the lowest mean distance to other points
ctr_ind = np.argmin(np.mean(d_sq, axis=0))
ctr_pt = nrn.anno['st_df'].mesh_index[ctr_ind]
local_mask = np.invert(np.isinf(sparse.csgraph.dijkstra(
nrn.mesh.csgraph, indices=ctr_pt, limit=np.max(d_sq[ctr_ind])+1)))
ctr_pt = nrn.anno["st_df"].mesh_index[ctr_ind]
local_mask = np.invert(
np.isinf(
sparse.csgraph.dijkstra(
nrn.mesh.csgraph, indices=ctr_pt, limit=np.max(d_sq[ctr_ind]) + 1
)
)
)
return local_mask


Expand All @@ -78,51 +97,66 @@ def _faces_to_keep(p1mask, p2mask, nrn):
return good_faces


def _add_expected_edges(G, new_mesh, p1mask, p2mask, local_network_mask, test_split=True):
def _add_expected_edges(
G, new_mesh, p1mask, p2mask, local_network_mask, test_split=True
):
"Adds edges that were not included in the faces graph"
G.remove_node('source')
G.remove_node('target')
G.remove_node("source")
G.remove_node("target")

new_mesh_filt = new_mesh.apply_mask(local_network_mask)
p1s = new_mesh_filt.filter_unmasked_boolean(p1mask)
p2s = new_mesh_filt.filter_unmasked_boolean(p2mask)

# Make matrix without cross-partition edges
Gorig = nx.to_scipy_sparse_matrix(G)
if nx.__version__ < 3:
Gorig = nx.to_scipy_sparse_matrix(G)
else:
Gorig = nx.to_scipy_sparse_array(G)
ii, jj, dd = sparse.find(Gorig)
keep11 = p1s[ii] & p1s[jj]
keep22 = p2s[ii] & p2s[jj]
keep_all = keep11 | keep22

GsplitB = sparse.csr_matrix((dd[keep_all], (ii[keep_all], jj[keep_all]))).toarray() > 0
GsplitB = (
sparse.csr_matrix((dd[keep_all], (ii[keep_all], jj[keep_all]))).toarray() > 0
)

Gnew = new_mesh_filt.csgraph.toarray()
GnewB = Gnew > 0

# Places where edge in expected Gmat but not in new mesh
link_edges_to_add_rough = np.vstack(np.where(np.logical_and(GsplitB == True, GnewB == False))).T
link_edges_to_add_rough = np.vstack(
np.where(np.logical_and(GsplitB == True, GnewB == False))
).T
if len(link_edges_to_add_rough) > 0:
link_edges_to_add = np.unique(
[tuple(x) for x in np.sort(link_edges_to_add_rough, axis=1)], axis=0)
[tuple(x) for x in np.sort(link_edges_to_add_rough, axis=1)], axis=0
)

link_edges_unmasked = new_mesh_filt.map_indices_to_unmasked(link_edges_to_add)
new_mesh.link_edges = np.vstack(
(new_mesh.link_edges, new_mesh.filter_unmasked_indices(link_edges_unmasked)))
(new_mesh.link_edges, new_mesh.filter_unmasked_indices(link_edges_unmasked))
)

if test_split:
if len(link_edges_to_add_rough) > 0:
new_mesh_filt.link_edges = np.vstack((new_mesh_filt.link_edges, link_edges_to_add))
new_mesh_filt.link_edges = np.vstack(
(new_mesh_filt.link_edges, link_edges_to_add)
)

ncomp = sparse.csgraph.connected_components(new_mesh_filt.csgraph)[0]
if ncomp > 2:
print('Warning: more than 2 local components after split')
print("Warning: more than 2 local components after split")
if ncomp == 1:
print('Warning: Only 1 local component after split')
print("Warning: Only 1 local component after split")

return new_mesh


def mesh_multicut(mesh, source_points, target_points, initial_window=10000, return_masks=False):
def mesh_multicut(
mesh, source_points, target_points, initial_window=10000, return_masks=False
):
"""Use multi-point source/target split to cut a minimal set of faces from a mesh.
Warns if the split produces more than 2 graph components in a local cutout, although
the end result may still be suitable.
Expand Down Expand Up @@ -166,8 +200,12 @@ def mesh_multicut(mesh, source_points, target_points, initial_window=10000, retu

keep_faces = _faces_to_keep(p1mask, p2mask, nrn)

new_mesh = Mesh(vertices=nrn.mesh.vertices,
faces=nrn.mesh.faces[keep_faces], node_mask=nrn.mesh_mask, link_edges=nrn.mesh.link_edges)
new_mesh = Mesh(
vertices=nrn.mesh.vertices,
faces=nrn.mesh.faces[keep_faces],
node_mask=nrn.mesh_mask,
link_edges=nrn.mesh.link_edges,
)

new_mesh = _add_expected_edges(G, new_mesh, p1mask, p2mask, local_network_mask)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ h5py
numpy
scipy>=1.3.0
scikit-learn
networkx<3
networkx
multiwrapper
cloud-volume>=1.16.0
trimesh>=3.0.14
Expand Down