sylvainprigent / stracking

python library for particles tracking in 2D+t and 3D+t scientific images

Accelerating splink using graph-tool for Bellman-Ford search #6

Open orena1 opened 1 year ago


The graph-tool library appears to be much faster than SciPy's csgraph.bellman_ford at running the Bellman-Ford algorithm. However, graph-tool can be difficult to install and is not well supported on Windows, so I have added it as an optional dependency rather than a required one. The benchmark script below shows that, on this 2D+t example, graph-tool is approximately 5 times faster than the default implementation. In 3D, I have found the speed difference to be even greater (see here).
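For context on what is being accelerated: Bellman-Ford computes single-source shortest paths in a weighted directed graph, allows negative edge weights, and detects negative cycles. SciPy's csgraph.bellman_ford and graph-tool's graph_tool.search.bellman_ford_search both implement it in compiled code. The following is only a plain-Python sketch for illustration; it is not the stracking, SciPy or graph-tool implementation, and the function name and (source, target, weight) edge-list format are my own assumptions.

```python
import math
from typing import List, Tuple

# Hypothetical edge format for this sketch: (source node, target node, weight).
Edge = Tuple[int, int, float]


def bellman_ford(n_nodes: int, edges: List[Edge], source: int) -> Tuple[List[float], List[int]]:
    """Single-source shortest paths on a directed graph; weights may be negative.

    Returns (dist, pred): dist[v] is the shortest distance from `source`
    (math.inf if v is unreachable) and pred[v] is the previous node on that
    path (-1 if there is none). Raises ValueError if a negative cycle is
    reachable from `source`.
    """
    dist = [math.inf] * n_nodes
    pred = [-1] * n_nodes
    dist[source] = 0.0

    # A shortest path has at most n_nodes - 1 edges, so n_nodes - 1 rounds of
    # relaxing every edge suffice; stop early once a round changes nothing.
    for _ in range(n_nodes - 1):
        changed = False
        for u, v, w in edges:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                pred[v] = u
                changed = True
        if not changed:
            break

    # If an edge can still be relaxed, a negative cycle is reachable.
    for u, v, w in edges:
        if dist[u] + w < dist[v]:
            raise ValueError("negative cycle reachable from source")
    return dist, pred
```

For example, with edges (0, 1, 4), (0, 2, 2), (2, 1, -1) and (1, 3, 1), the path 0 → 2 → 1 (cost 1) beats the direct edge 0 → 1 (cost 4). The worst-case cost is O(V·E) edge relaxations, which is why a compiled backend matters for long movies.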

Speed comparison (the script prints elapsed wall-clock time in seconds):

from stracking.linkers import SPLinker, EuclideanCost
from stracking.containers import SParticles
import numpy as np
import time

# Detections as rows of [frame, y, x]: three particles per frame.
detections = np.array([[0., 53., 12.],
                       [0., 93., 11.],
                       [0., 13., 10.],
                       [1., 53., 26.],
                       [1., 93., 26.],
                       [1., 13., 26.],
                       [2., 13., 41.],
                       [2., 93., 41.],
                       [2., 53., 41.],
                       [3., 93., 56.],
                       [3., 13., 55.],
                       [3., 54., 56.],
                       [4., 53., 71.],
                       [4., 94., 71.],
                       [4., 13., 71.]])

# Build a long synthetic movie by cycling the detections of frames 0, 1 and 2
# over 2950 frames (3 particles per frame, 8850 detections in total).
tm = []
for i in range(2950):
    cp = detections[detections[:, 0] == i % 3]  # boolean indexing returns a copy
    cp[:, 0] = i                                # relabel the copy as frame i
    tm.append(cp)
detections = np.vstack(tm)

# Default run (SciPy backend)
particles_all = SParticles(data=detections)
euclidean_cost = EuclideanCost(max_cost=150)
my_tracker = SPLinker(cost=euclidean_cost, gap=3)

start_time = time.time()
tracks = my_tracker.run(particles_all)
print(f'Default run: {time.time() - start_time}')

# graph-tool run (graph_tool=True is the optional flag added in this issue)
particles_all = SParticles(data=detections)
euclidean_cost = EuclideanCost(max_cost=150)
my_tracker = SPLinker(cost=euclidean_cost, gap=3)

start_time = time.time()
tracks_graph_tool = my_tracker.run(particles_all, graph_tool=True)
print(f'graph-tool: {time.time() - start_time}')

# Both backends must produce identical tracks (array_equal also checks shapes)
assert np.array_equal(tracks_graph_tool.data, tracks.data)

Default run: 16.17208695411682
graph-tool: 3.133636713027954
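A single time.time() measurement can be noisy (other processes, caches, first-call overhead). A minimal sketch of a somewhat more robust measurement, using only the standard library: time.perf_counter() and the fastest of several runs. best_of and speedup are hypothetical helper names for this sketch, not part of stracking.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def best_of(fn: Callable[[], T], repeat: int = 3) -> Tuple[float, Optional[T]]:
    """Call fn() `repeat` times; return (fastest elapsed seconds, last result)."""
    if repeat < 1:
        raise ValueError("repeat must be >= 1")
    best = float("inf")
    result: Optional[T] = None
    for _ in range(repeat):
        start = time.perf_counter()  # monotonic, high-resolution clock
        result = fn()
        best = min(best, time.perf_counter() - start)
    return best, result


def speedup(baseline_s: float, candidate_s: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_s / candidate_s
```

With the objects from the script above, usage would look like `t_default, tracks = best_of(lambda: my_tracker.run(particles_all))` and `t_gt, tracks_graph_tool = best_of(lambda: my_tracker.run(particles_all, graph_tool=True))`. The original script rebuilds SParticles before each run; if SPLinker.run might modify its input, build the particles inside the lambda instead. Applied to the numbers reported above, speedup(16.172, 3.134) is about 5.16.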

I've also added a test file. I hope it helps.
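Because graph-tool is optional, a test of the graph-tool path should be skipped, not failed, when the package is missing. A minimal sketch with the standard-library unittest, assuming the `graph_tool=True` flag proposed in this issue; the class and method names are hypothetical, and this is not the test file attached to the issue.

```python
import importlib.util
import unittest

# find_spec only locates a package; it does not import it.
HAS_GRAPH_TOOL = importlib.util.find_spec("graph_tool") is not None
HAS_STRACKING = importlib.util.find_spec("stracking") is not None


class TestGraphToolBackend(unittest.TestCase):
    @unittest.skipUnless(HAS_GRAPH_TOOL and HAS_STRACKING,
                         "graph-tool and stracking are required")
    def test_graph_tool_matches_default(self):
        # Imported lazily so that collecting this test never needs them.
        import numpy as np
        from stracking.containers import SParticles
        from stracking.linkers import EuclideanCost, SPLinker

        # Frames 0 to 2 of the detections used in the benchmark above.
        detections = np.array([[0., 53., 12.], [0., 93., 11.], [0., 13., 10.],
                               [1., 53., 26.], [1., 93., 26.], [1., 13., 26.],
                               [2., 13., 41.], [2., 93., 41.], [2., 53., 41.]])
        tracker = SPLinker(cost=EuclideanCost(max_cost=150), gap=3)
        default = tracker.run(SParticles(data=detections))
        # graph_tool=True is the flag proposed in this issue (assumed API).
        fast = tracker.run(SParticles(data=detections), graph_tool=True)
        self.assertTrue(np.array_equal(fast.data, default.data))
```

Run with `python -m unittest` or pytest; on machines without graph-tool (for example most Windows setups) the test is reported as skipped.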