Janelia-Trackathon-2023 / traccuracy

Utilities for computing common accuracy metrics on cell tracking challenge solutions with ground truth
https://traccuracy.readthedocs.io
Other

Swap out networkx for rustworkx #147

Open bentaculum opened 5 months ago

bentaculum commented 5 months ago

Description

Hi all,

I would like to get started on swapping the graph backend from networkx to rustworkx (https://www.rustworkx.org/index.html), which we discussed a while ago. The reason: even for basic 2D datasets from the Cell Tracking Challenge (e.g. PhC-C2DL-PSC, ~70k nodes), calculating CTCMetrics currently takes more than 1 minute.

Among other things, this seems to be caused by basic node-attribute lookups on the graph while matching GT and predicted nodes.

I anticipate the following challenges:

Has anyone else already given this some more thoughts and would like to pitch in, with either comments or coding together?

Cheers

Ben


DragaDoncila commented 5 months ago

Thanks for opening the conversation @bentaculum!

> rustworkx is NOT a drop-in replacement for networkx, and not even a subset of networkx. We currently use networkx-only convenience functions like from_pandas_edgelist() in graph construction, so switching requires adapting to a different API.

I think this is OK. The convenience of networkx is nice, but for this package I'm more concerned about speed than about the convenience of the internals. I do think that, from the user's perspective, we should take networkx objects and return networkx objects, converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.

> It is not certain that there will be speedups for traccuracy. The rustworkx benchmarks I found are promising but limited to certain high-level algorithms.

I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.

> rustworkx requires node IDs to be ints. We currently use strings of the format segmentation-ID_time.

I have no strong preference for keeping the string IDs (they kinda annoy me, actually, because all my other graphs have int IDs), but I understand the original motivation of wanting them to be meaningful to a user.
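If the string IDs were dropped, networkx can already relabel nodes to consecutive integers while keeping the old IDs recoverable as a node attribute, via `nx.convert_node_labels_to_integers(..., label_attribute=...)`. A minimal sketch, assuming IDs of the `segmentation-ID_time` form described above; the helper name `relabel_to_int` and the attribute name `old_id` are illustrative choices:

```python
import networkx as nx


def relabel_to_int(G: nx.DiGraph) -> tuple[nx.DiGraph, dict]:
    """Return a copy of G with integer node IDs plus an int -> old ID lookup.

    Hypothetical helper; the original string ID is kept on each node
    under the "old_id" attribute.
    """
    # ordering="sorted" makes the integer assignment deterministic
    H = nx.convert_node_labels_to_integers(G, ordering="sorted", label_attribute="old_id")
    int_to_old = nx.get_node_attributes(H, "old_id")
    return H, int_to_old
```

This keeps the meaningful string available for users (e.g. in error messages) without requiring it to be the node key.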

bentaculum commented 5 months ago

> I suppose we will need to check! Have you profiled the metrics computation by the way? I'd be interested to see a profile - maybe there's still simple things we can do on our end.

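Fresh profiling with IPython's %lprun (line_profiler). Somehow it's faster than yesterday... Most of the time is spent building the label_to_id mappings from the graphs: the two nx.get_node_attributes calls (profile lines 72 and 80) alone take ~69% of the time in _compute_mapping.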

```python
# IPython session: %lprun needs line_profiler installed and its extension loaded
%load_ext line_profiler

from traccuracy import run_metrics
from traccuracy.loaders import load_ctc_data
from traccuracy.matchers import CTCMatcher
from traccuracy.metrics import CTCMetrics

# The same GT annotations are loaded as both GT and prediction, so every node should match
gt_data = load_ctc_data(
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)
pred_data = load_ctc_data(
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA',
    '/Users/gallusse/data/celltracking/ctc/PhC-C2DL-PSC/train/01_GT/TRA/man_track.txt'
)

# Wall-clock time of the full metrics computation
%time ctc_results = run_metrics(gt_data=gt_data, pred_data=pred_data, matcher=CTCMatcher(), metrics=[CTCMetrics()])

# Line-by-line profile of the CTC matcher, also written to profile.txt
%lprun -f CTCMatcher._compute_mapping -T profile.txt ctc_results = run_metrics(gt_data=gt_data, pred_data=pred_data, matcher=CTCMatcher(), metrics=[CTCMetrics()])
```
Matching frames: 100%|█████████████████████████████████████████████████| 300/300 [00:14<00:00, 20.52it/s]
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 ground truth nodes.
INFO:traccuracy.matchers._base:Matched 71403 out of 71403 predicted nodes.
Evaluating nodes: 100%|████████████████████████████████████████| 71403/71403 [00:00<00:00, 594404.37it/s]
Evaluating FP edges: 100%|█████████████████████████████████████| 71201/71201 [00:00<00:00, 919657.80it/s]
Evaluating FN edges: 100%|████████████████████████████████████| 71201/71201 [00:00<00:00, 1002065.74it/s]
CPU times: user 16.5 s, sys: 895 ms, total: 17.4 s
Wall time: 17.6 s
Timer unit: 1e-09 s

Total time: 31.7674 s
File: /Users/gallusse/code/traccuracy/src/traccuracy/matchers/_ctc.py
Function: _compute_mapping at line 29

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    29                                               def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
    30                                                   """Run ctc matching
    31                                           
    32                                                   Args:
    33                                                       gt_graph (TrackingGraph): Tracking graph object for the gt
    34                                                       pred_graph (TrackingGraph): Tracking graph object for the pred
    35                                           
    36                                                   Returns:
    37                                                       traccuracy.matchers.Matched: Matched data object containing the CTC mapping
    38                                           
    39                                                   Raises:
    40                                                       ValueError: if GT and pred segmentations are None or are not the same shape
    41                                                   """
    42         1       1000.0   1000.0      0.0          gt = gt_graph
    43         1          0.0      0.0      0.0          pred = pred_graph
    44         1       1000.0   1000.0      0.0          gt_label_key = gt_graph.label_key
    45         1          0.0      0.0      0.0          pred_label_key = pred_graph.label_key
    46         1       1000.0   1000.0      0.0          G_gt, mask_gt = gt, gt.segmentation
    47         1          0.0      0.0      0.0          G_pred, mask_pred = pred, pred.segmentation
    48                                           
    49         1          0.0      0.0      0.0          if mask_gt is None or mask_pred is None:
    50                                                       raise ValueError("Segmentation is None, cannot perform matching")
    51                                           
    52         1       4000.0   4000.0      0.0          if mask_gt.shape != mask_pred.shape:
    53                                                       raise ValueError("Segmentation shapes must match between gt and pred")
    54                                           
    55         1          0.0      0.0      0.0          mapping = []
    56                                                   # Get overlaps for each frame
    57       302   41774000.0 138324.5      0.1          for i, t in enumerate(
    58         2     291000.0 145500.0      0.0              tqdm(
    59         1       2000.0   2000.0      0.0                  range(gt.start_frame, gt.end_frame),
    60         1          0.0      0.0      0.0                  desc="Matching frames",
    61                                                       )
    62                                                   ):
    63       300     250000.0    833.3      0.0              gt_frame = mask_gt[i]
    64       300      91000.0    303.3      0.0              pred_frame = mask_pred[i]
    65       300     327000.0   1090.0      0.0              gt_frame_nodes = gt.nodes_by_frame[t]
    66       300     244000.0    813.3      0.0              pred_frame_nodes = pred.nodes_by_frame[t]
    67                                           
    68                                                       # get the labels for this frame
    69       600 2552632000.0    4e+06      8.0              gt_labels = dict(
    70       600     424000.0    706.7      0.0                  filter(
    71       300     147000.0    490.0      0.0                      lambda item: item[0] in gt_frame_nodes,
    72       300        1e+10    4e+07     34.3                      nx.get_node_attributes(G_gt.graph, gt_label_key).items(),
    73                                                           )
    74                                                       )
    75       300    9992000.0  33306.7      0.0              gt_label_to_id = {v: k for k, v in gt_labels.items()}
    76                                           
    77       600 2465360000.0    4e+06      7.8              pred_labels = dict(
    78       600     470000.0    783.3      0.0                  filter(
    79       300     182000.0    606.7      0.0                      lambda item: item[0] in pred_frame_nodes,
    80       300        1e+10    4e+07     34.3                      nx.get_node_attributes(G_pred.graph, pred_label_key).items(),
    81                                                           )
    82                                                       )
    83       300    9263000.0  30876.7      0.0              pred_label_to_id = {v: k for k, v in pred_labels.items()}
    84                                           
    85       300      48000.0    160.0      0.0              (
    86       300      84000.0    280.0      0.0                  overlapping_gt_labels,
    87       300      65000.0    216.7      0.0                  overlapping_pred_labels,
    88       300     586000.0   1953.3      0.0                  intersection,
    89       300 4715872000.0    2e+07     14.8              ) = get_labels_with_overlap(gt_frame, pred_frame)
    90                                           
    91     72093    6224000.0     86.3      0.0              for i in range(len(overlapping_gt_labels)):
    92     71793    9997000.0    139.2      0.0                  gt_label = overlapping_gt_labels[i]
    93     71793    8246000.0    114.9      0.0                  pred_label = overlapping_pred_labels[i]
    94                                                           # CTC metrics only match comp IDs to a single GT ID if there is majority overlap
    95     71793    9167000.0    127.7      0.0                  if intersection[i] > 0.5:
    96    142806   16758000.0    117.3      0.1                      mapping.append(
    97     71403  116304000.0   1628.8      0.4                          (gt_label_to_id[gt_label], pred_label_to_id[pred_label])
    98                                                               )
    99                                           
   100         1       2000.0   2000.0      0.0          return Matched(gt_graph, pred_graph, mapping)
bentaculum commented 5 months ago

For now, I coded up a simple speed-up that doesn't swap the backend: #148.
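The profile shows nx.get_node_attributes scanning every node of the whole graph once per frame, i.e. roughly (frames × nodes) work. A minimal sketch of one way to avoid that without changing the backend: build a `{frame: {label: node ID}}` lookup in a single pass before the frame loop and index into it per frame. This is not necessarily what #148 does; the helper `build_label_to_id` is hypothetical, and it assumes each node stores its frame under an attribute (called `t` here) that agrees with `nodes_by_frame`:

```python
from collections import defaultdict

import networkx as nx


def build_label_to_id(G: nx.DiGraph, frame_key: str, label_key: str) -> dict:
    """Return {frame: {segmentation label: node ID}} from one pass over G.

    Hypothetical helper replacing a per-frame nx.get_node_attributes call,
    which scans every node of the graph for every frame.
    """
    lookup = defaultdict(dict)
    for node_id, attrs in G.nodes(data=True):
        # Skip unlabeled nodes, as nx.get_node_attributes would
        if label_key not in attrs:
            continue
        lookup[attrs[frame_key]][attrs[label_key]] = node_id
    return dict(lookup)


# Assumed usage in _compute_mapping (the frame key "t" is an assumption):
#   gt_lookup = build_label_to_id(G_gt.graph, "t", gt_label_key)  # once
#   ...
#   gt_label_to_id = gt_lookup.get(t, {})  # inside the frame loop
```

Each frame then costs a dictionary lookup instead of a full-graph scan.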

bentaculum commented 5 months ago

> I do think from the perspective of the user we should be taking networkx objects and returning networkx objects, casting/converting to rustworkx internally only (once on load, once on return). And of course, for use with CTC data, it should be completely transparent to the user via our loader.

I like this, agreed.