Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF/WIP: Parallel subnet solving #507

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions trackpy/linking/linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import warnings
import logging
import itertools, functools
from concurrent import futures
from multiprocessing import cpu_count

import numpy as np

Expand Down Expand Up @@ -499,12 +501,19 @@ def next_level(self, coords, t, extra_data=None):

def assign_links(self):
spl, dpl = [], []
for source_set, dest_set in self.subnets:

def link_subnet(subnet_sets):
source_set, dest_set = subnet_sets
for sp in source_set:
sp.forward_cands.sort(key=lambda x: x[1])

sn_spl, sn_dpl = self.subnet_linker(source_set, dest_set,
self.search_range)
return self.subnet_linker(source_set, dest_set,
self.search_range)

workers = futures.ThreadPoolExecutor(max_workers=cpu_count())
futs = [workers.submit(link_subnet, sub) for sub in self.subnets]
for fut in futures.as_completed(futs):
sn_spl, sn_dpl = fut.result()
spl.extend(sn_spl)
dpl.extend(sn_dpl)

Expand Down
2 changes: 1 addition & 1 deletion trackpy/linking/subnetlinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def numba_link(s_sn, dest_size, search_range, max_size=30, diag=False):
dest_results = [dcands[i] if i >= 0 else None for i in best_assignments]
return source_results, dest_results

@try_numba_jit(nopython=True)
@try_numba_jit(nopython=True, nogil=True)
def _numba_subnet_norecur(ncands, candsarray, dists2array, cur_assignments,
cur_sums, tmp_assignments, best_assignments):
"""Find the optimal track assignments for a subnetwork, without recursion.
Expand Down