Open bentaculum opened 5 months ago
Thanks for opening the conversation @bentaculum!
rustworkx is NOT a drop-in replacement for networkx, and not even a subset of networkx. We currently use networkx-only convenience functions like from_pandas_edgelist() in graph construction. This therefore requires dealing with the different API.
I think this is ok. The convenience of networkx is fine, but for this package I'm more concerned about speed than I am about the convenience of the internals. I do think from the perspective of the user we should be taking networkx objects and returning networkx objects, casting/converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.
It is not certain that there will be speedups for traccuracy. The rustworkx benchmarks I find are promising but limited to certain high-level algorithms.
I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.
rustworkx requires node ids to be int. We currently use strings of format segmentation-ID_time.
I have no strong preference for keeping the string IDs (they kinda annoy me actually because all my other graphs have int IDs), but I understand the initial driver of wanting them to be meaningful for a user.
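To make the int-ID point concrete, here is a minimal sketch of how string IDs could be mapped to dense integers on load and restored on return. It assumes the string format is `<segmentation id>_<time>`; `build_id_maps` and `parse_node_id` are hypothetical helpers for illustration, not part of traccuracy or rustworkx.

```python
def build_id_maps(node_ids):
    """Assign a dense integer index to each string node ID (hypothetical helper)."""
    int_to_str = list(node_ids)
    str_to_int = {node_id: idx for idx, node_id in enumerate(int_to_str)}
    return str_to_int, int_to_str


def parse_node_id(node_id):
    """Split an assumed '<segmentation id>_<time>' string into two ints."""
    seg_id, time = node_id.rsplit("_", 1)
    return int(seg_id), int(time)


# Toy graph: two cells in frame 0, one of them continues into frame 1.
string_ids = ["1_0", "2_0", "1_1"]
string_edges = [("1_0", "1_1")]

str_to_int, int_to_str = build_id_maps(string_ids)

# Convert once on load ...
int_edges = [(str_to_int[u], str_to_int[v]) for u, v in string_edges]

# ... and convert back once on return, so the user still sees the string IDs.
restored_edges = [(int_to_str[u], int_to_str[v]) for u, v in int_edges]
```

Keeping `int_to_str` as a list makes the reverse lookup a plain index operation, so the meaningful string IDs only need to exist at the boundary.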
I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.
Fresh profiling with ipython's %lprun. Somehow faster than yesterday ...
Most of the time is spent on getting the label_to_id mappings from the graphs.
from traccuracy import run_metrics
from traccuracy.loaders import load_ctc_data
from traccuracy.metrics import CTCMetrics
from traccuracy.matchers import CTCMatcher
gt_data = load_ctc_data(
'/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
'/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)
pred_data = load_ctc_data(
'/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
'/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)
%time ctc_results = run_metrics(gt_data=gt_data,pred_data=pred_data,matcher=CTCMatcher(),metrics=[CTCMetrics()],)
%lprun -f CTCMatcher._compute_mapping -T profile.txt ctc_results = run_metrics(gt_data=gt_data,pred_data=pred_data,matcher=CTCMatcher(),metrics=[CTCMetrics()],)
Matching frames: 100%|█████████████████████████████████████████████████| 300/300 [00:14<00:00, 20.52it/s]
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 ground truth nodes.
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 predicted nodes.
Evaluating nodes: 100%|████████████████████████████████████████| 71403/71403 [00:00<00:00, 594404.37it/s]
Evaluating FP edges: 100%|█████████████████████████████████████| 71201/71201 [00:00<00:00, 919657.80it/s]
Evaluating FN edges: 100%|████████████████████████████████████| 71201/71201 [00:00<00:00, 1002065.74it/s]
CPU times: user 16.5 s, sys: 895 ms, total: 17.4 s
Wall time: 17.6 s
Timer unit: 1e-09 s
Total time: 31.7674 s
File: /Users/gallusse/code/traccuracy/src/traccuracy/matchers/_ctc.py
Function: _compute_mapping at line 29
Line # Hits Time Per Hit % Time Line Contents
==============================================================
29 def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
30 """Run ctc matching
31
32 Args:
33 gt_graph (TrackingGraph): Tracking graph object for the gt
34 pred_graph (TrackingGraph): Tracking graph object for the pred
35
36 Returns:
37 traccuracy.matchers.Matched: Matched data object containing the CTC mapping
38
39 Raises:
40 ValueError: if GT and pred segmentations are None or are not the same shape
41 """
42 1 1000.0 1000.0 0.0 gt = gt_graph
43 1 0.0 0.0 0.0 pred = pred_graph
44 1 1000.0 1000.0 0.0 gt_label_key = gt_graph.label_key
45 1 0.0 0.0 0.0 pred_label_key = pred_graph.label_key
46 1 1000.0 1000.0 0.0 G_gt, mask_gt = gt, gt.segmentation
47 1 0.0 0.0 0.0 G_pred, mask_pred = pred, pred.segmentation
48
49 1 0.0 0.0 0.0 if mask_gt is None or mask_pred is None:
50 raise ValueError("Segmentation is None, cannot perform matching")
51
52 1 4000.0 4000.0 0.0 if mask_gt.shape != mask_pred.shape:
53 raise ValueError("Segmentation shapes must match between gt and pred")
54
55 1 0.0 0.0 0.0 mapping = []
56 # Get overlaps for each frame
57 302 41774000.0 138324.5 0.1 for i, t in enumerate(
58 2 291000.0 145500.0 0.0 tqdm(
59 1 2000.0 2000.0 0.0 range(gt.start_frame, gt.end_frame),
60 1 0.0 0.0 0.0 desc="Matching frames",
61 )
62 ):
63 300 250000.0 833.3 0.0 gt_frame = mask_gt[i]
64 300 91000.0 303.3 0.0 pred_frame = mask_pred[i]
65 300 327000.0 1090.0 0.0 gt_frame_nodes = gt.nodes_by_frame[t]
66 300 244000.0 813.3 0.0 pred_frame_nodes = pred.nodes_by_frame[t]
67
68 # get the labels for this frame
69 600 2552632000.0 4e+06 8.0 gt_labels = dict(
70 600 424000.0 706.7 0.0 filter(
71 300 147000.0 490.0 0.0 lambda item: item[0] in gt_frame_nodes,
72 300 1e+10 4e+07 34.3 nx.get_node_attributes(G_gt.graph, gt_label_key).items(),
73 )
74 )
75 300 9992000.0 33306.7 0.0 gt_label_to_id = {v: k for k, v in gt_labels.items()}
76
77 600 2465360000.0 4e+06 7.8 pred_labels = dict(
78 600 470000.0 783.3 0.0 filter(
79 300 182000.0 606.7 0.0 lambda item: item[0] in pred_frame_nodes,
80 300 1e+10 4e+07 34.3 nx.get_node_attributes(G_pred.graph, pred_label_key).items(),
81 )
82 )
83 300 9263000.0 30876.7 0.0 pred_label_to_id = {v: k for k, v in pred_labels.items()}
84
85 300 48000.0 160.0 0.0 (
86 300 84000.0 280.0 0.0 overlapping_gt_labels,
87 300 65000.0 216.7 0.0 overlapping_pred_labels,
88 300 586000.0 1953.3 0.0 intersection,
89 300 4715872000.0 2e+07 14.8 ) = get_labels_with_overlap(gt_frame, pred_frame)
90
91 72093 6224000.0 86.3 0.0 for i in range(len(overlapping_gt_labels)):
92 71793 9997000.0 139.2 0.0 gt_label = overlapping_gt_labels[i]
93 71793 8246000.0 114.9 0.0 pred_label = overlapping_pred_labels[i]
94 # CTC metrics only match comp IDs to a single GT ID if there is majority overlap
95 71793 9167000.0 127.7 0.0 if intersection[i] > 0.5:
96 142806 16758000.0 117.3 0.1 mapping.append(
97 71403 116304000.0 1628.8 0.4 (gt_label_to_id[gt_label], pred_label_to_id[pred_label])
98 )
99
100 1 2000.0 2000.0 0.0 return Matched(gt_graph, pred_graph, mapping)
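In the profile above, lines 72 and 80 each account for about a third of the run time: every frame fetches the attributes of all nodes and then filters them down to that frame. A minimal sketch of one way to avoid this, looking up labels only for the nodes of the current frame, is shown below. Plain dicts stand in for the result of `nx.get_node_attributes` and for `nodes_by_frame`; both function names are hypothetical, and this is not necessarily how any actual fix is implemented.

```python
def label_to_id_filtering(node_labels, frame_nodes):
    """Mirror of the profiled pattern: scan all nodes, keep those in the frame."""
    labels = dict(filter(lambda item: item[0] in frame_nodes, node_labels.items()))
    return {label: node for node, label in labels.items()}


def label_to_id_direct(node_labels, frame_nodes):
    """Sketch of the cheaper variant: touch only the nodes of this frame."""
    return {node_labels[node]: node for node in frame_nodes}


# Stand-in for nx.get_node_attributes(graph, label_key): node ID -> label.
node_labels = {"1_0": 1, "2_0": 2, "1_1": 1, "3_1": 3}
# Stand-in for nodes_by_frame: frame -> node IDs in that frame.
nodes_by_frame = {0: {"1_0", "2_0"}, 1: {"1_1", "3_1"}}

slow = {t: label_to_id_filtering(node_labels, nodes) for t, nodes in nodes_by_frame.items()}
fast = {t: label_to_id_direct(node_labels, nodes) for t, nodes in nodes_by_frame.items()}
```

With N nodes spread over F frames, the filtering variant does work proportional to F × N, while the direct variant does work proportional to N in total, because each node is visited once in its own frame.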
I coded up a simple speed-up in #148, without swapping the backend for now.
I do think from the perspective of the user we should be taking networkx objects and returning networkx objects, casting/converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.
I like this, agreed.
Description
Hi all,
I would like to get started with the once-upon-a-time discussed swapping of the backend from networkx to rustworkx (https://www.rustworkx.org/index.html). The reason is that, for example, for basic 2d datasets from the cell tracking challenge (PhC-C2DL-PSC, ~70k nodes), calculating CTCMetrics takes > 1 minute right now. This seems to be, amongst other things, due to some basic attribute getting from the graph in the matching of GT and predicted nodes.
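I anticipate the following challenges:
- rustworkx is NOT a drop-in replacement for networkx, and not even a subset of networkx. We currently use networkx-only convenience functions like from_pandas_edgelist() in graph construction. This therefore requires dealing with the different API.
- It is not certain that there will be speedups for traccuracy. The rustworkx benchmarks I find are promising but limited to certain high-level algorithms.
- rustworkx requires node ids to be int. We currently use strings of format segmentation-ID_time.
Has anyone else already given this some more thought and would like to pitch in, with either comments or coding together?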
Cheers
Ben