seung-lab / kimimaro

Skeletonize densely labeled 3D image segmentations with TEASAR. (Medial Axis Transform)
GNU General Public License v3.0

skeletontricks.find_target scaling issue #64

Closed chinasaur closed 3 years ago

chinasaur commented 3 years ago

Thanks for this very useful package!

We ran into scaling issues with the raster scan over the whole segment bounding volume in skeletontricks.find_target when using 512x512x512 blocks containing astrocytes that span most of the block and require many skeleton paths. (We don't use any pre-marked targets, and we were using an invalidation scale of 2.)

This was fairly easy to fix by caching the targets in descending distance order and then, on each find_target call, scanning only the remaining cached target locations. I did this only for segments whose bounding volume exceeded 1e6 voxels, but it's probably fine to do it in all cases. I can send you a patch if that would help.

william-silversmith commented 3 years ago

This would be a great contribution! Can you say a little more about how you identify the descending targets? This is something that has bugged me for a while, but I never figured out a good way to handle it.

william-silversmith commented 3 years ago

It just occurred to me that it might be possible to extract the extreme points during the distance field computation by isolating voxels that have no neighbor larger than themselves and inserting them into a list. Your idea might be simpler to implement, though (and it has the benefit of already being shown to work).
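A minimal NumPy sketch of this idea, done as a separate pass over an already-computed DAF rather than inside the distance field computation itself. It assumes `daf` and `mask` are same-shaped 3D arrays and uses 26-connectivity; `extreme_points` is a hypothetical helper, not part of Kimimaro or dijkstra3d. Voxels that tie with a neighbor are kept, so a plateau contributes all of its voxels.

```python
import numpy as np

def extreme_points(daf, mask):
    """Hypothetical helper: coordinates of foreground voxels that have no
    strictly larger foreground neighbor in their 26-neighborhood."""
    mask = np.asarray(mask, dtype=bool)
    # Background gets -inf so it can never "beat" a foreground voxel.
    vals = np.where(mask, np.asarray(daf, dtype=np.float64), -np.inf)
    padded = np.pad(vals, 1, mode="constant", constant_values=-np.inf)
    neighbor_max = np.full(vals.shape, -np.inf)
    sx, sy, sz = vals.shape
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for dz in (-1, 0, 1):
                if dx == dy == dz == 0:
                    continue
                # Slice of the padded array aligned with the (dx, dy, dz) neighbor.
                shifted = padded[1 + dx:1 + dx + sx,
                                 1 + dy:1 + dy + sy,
                                 1 + dz:1 + dz + sz]
                np.maximum(neighbor_max, shifted, out=neighbor_max)
    # >= keeps ties, so every voxel of a flat maximum is reported.
    return np.argwhere(mask & (vals >= neighbor_max))
```

On a straight segment whose DAF increases away from the root, only the far tip is returned; with two equally distant tips, both are.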

chinasaur commented 3 years ago

Since it's hard to predict which points will get invalidated, I think the best we can do is to extract all the mask locations and sort them by descending DAF. Then on each find_target call go through the list of mask locations in order from current_index to the first still valid location (do this in Cython). That's the next target and the new current_index. In my tests, this was fast enough. It helps to use flat linearized indices for the mask locations.

I think this is strictly correct, although it doesn't always yield exactly the same result as the current implementation, because multiple points with the same DAF are no longer found in the same (i.e. raster) order.

Something like: https://gist.github.com/chinasaur/f9b547b614a52c1f5de16513577342a3
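A minimal NumPy sketch of the cached approach as described in this comment, assuming `labels` is a C-contiguous array that invalidation updates in place and `daf` has the same shape. `SortedTargetCache` is a hypothetical name; this is not the Gist's code, and the skip loop written in Python here is the part the comment suggests doing in Cython.

```python
import numpy as np

class SortedTargetCache:
    """Hypothetical cached target finder: sort the mask once, then scan forward."""

    def __init__(self, labels, daf):
        # Flat (linearized) indices of foreground voxels, in C raster order.
        mask_idx = np.flatnonzero(labels)
        values = daf.ravel()[mask_idx].astype(np.float64)
        # Stable sort by descending DAF, so tied voxels keep their raster order.
        self.daf_indices = mask_idx[np.argsort(-values, kind="stable")]
        self.current_index = 0
        # From here on the DAF volume itself is no longer needed.

    def find_target(self, labels):
        """Return the still-valid voxel with the largest DAF, or None."""
        flat = labels.ravel()  # a view (no copy) when labels is C-contiguous
        n = len(self.daf_indices)
        # Invalidation only ever removes voxels, so entries skipped here can
        # never become valid again and current_index only moves forward.
        while self.current_index < n and not flat[self.daf_indices[self.current_index]]:
            self.current_index += 1
        if self.current_index == n:
            return None
        return np.unravel_index(self.daf_indices[self.current_index], labels.shape)
```

After the one-time sort, each call only steps past entries invalidated since the previous call, instead of rescanning the whole bounding volume.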

william-silversmith commented 3 years ago

That works well! Thank you for providing some example code.

One of the design goals of Kimimaro is to keep memory usage as low as possible. Point clouds can become quite heavy, with a worst case of VOXELS * (4-6x) depending on the representation (here, an extra 536-805 MB). Memory is kept low primarily so that cheap preemptible/spot instances can be fully loaded at one process per core in the cloud, and secondarily so that grad students with bad laptops can use it.

The glia/astrocyte problem has bothered me for a long time, so given that we have two proposed solutions, something is going to make it in. Would it be alright if I give the distance field approach a try? The extreme points should pop out almost for free, with minimal overhead. In the worst case, the point cloud should be the size of the discretized surface of a sphere (all ties), so it would be significantly smaller than with the volume approach. The worst case (I think), assuming VOXELS is a sphere, would be about 4π(256)^2 * (4 to 6) = 3 to 5 MB.
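Both worst-case figures can be reproduced with a little arithmetic. This sketch assumes "4-6x" means 4 to 6 bytes per stored point and reports decimal megabytes.

```python
import math

BYTES_PER_POINT = (4, 6)  # assumed reading of "4-6x"

# Volume approach: one entry per voxel of a 512^3 block.
voxels = 512 ** 3
volume_mb = [voxels * b / 1e6 for b in BYTES_PER_POINT]

# Extreme-points worst case: about the surface of a sphere of radius 256 voxels.
surface = 4 * math.pi * 256 ** 2
surface_mb = [surface * b / 1e6 for b in BYTES_PER_POINT]

print([int(mb) for mb in volume_mb])        # [536, 805]
print([round(mb, 1) for mb in surface_mb])  # [3.3, 4.9]
```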

If it works out, we could have a Kimimaro that is faster across many different shapes including glia/astrocytes as the target selection time will be radically minimized. The target selection problem is one of the last remaining "quadratic" (really O(targets * voxels)) algorithms in here.

chinasaur commented 3 years ago

What I sent does take additional memory, but I'm not sure that's a big concern. Given that we already have an input block (e.g. 512, 512, 512), a DBF, and a mask, DAF, and PDRF that are potentially close to full block size, that's multiple volumes on the order of 1e8 voxels. Worst case, the targets cache adds another 1e8 entries, but for real segments the worst cases I've seen are 1e6-1e7 entries.

If I understand the extreme points proposal, I think it should generally produce good targets. But it seems there could be edge cases. For example if two processes of the same neuron touch tip to tip, can't one path invalidate the extreme points of the nearby tip, resulting in no extreme point remaining valid?

chinasaur commented 3 years ago

Follow-up on the memory use: I think once you've computed the daf_indices (descending DAF sorted mask indices) you can actually free the DAF volume?
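A small sketch of that idea, assuming the caller hands the DAF over inside a one-element list so that it gives up its own reference (the same list-and-pop pattern that appears later in this thread). `build_daf_indices` is a hypothetical helper. In CPython the array is released as soon as its last reference goes away; the weakref below only demonstrates that.

```python
import weakref
import numpy as np

def build_daf_indices(labels, daf_holder):
    """Hypothetical helper: foreground flat indices sorted by descending DAF.

    daf_holder is a one-element list; popping it removes the caller's
    reference, so the DAF volume can be freed once this function returns.
    """
    daf = daf_holder.pop()
    mask_idx = np.flatnonzero(labels)
    values = daf.ravel()[mask_idx]  # fancy indexing copies; no view of daf survives
    return mask_idx[np.argsort(-values, kind="stable")]

labels = np.ones((4, 4, 4), dtype=bool)
daf = np.arange(64, dtype=np.float32).reshape(4, 4, 4)
probe = weakref.ref(daf)
holder = [daf]
del daf  # the list now holds the only reference
daf_indices = build_daf_indices(labels, holder)
print(probe() is None)  # True in CPython: the DAF volume has been freed
```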

william-silversmith commented 3 years ago

> I think once you've computed the daf_indices (descending DAF sorted mask indices) you can actually free the DAF volume?

That is an excellent point.

In the example I was looking at, that would drop our peak memory usage to about 4 GB, which makes this proposal much more palatable. That obviates the memory concern, but I'd still like to entertain the extreme points proposal because I suspect it will work, and it would both retain the advantage you mentioned and obviate some per-object processing.

> But it seems there could be edge cases. For example if two processes of the same neuron touch tip to tip, can't one path invalidate the extreme points of the nearby tip, resulting in no extreme point remaining valid?

I think in this case there would be either a single winner or ties. Since the condition of having no larger foreground neighbor would add both tied voxels to the list, I think the resulting list is still valid. It would be possible to recapitulate the current behavior by sorting the resulting candidate targets in raster order.

What do you think?

chinasaur commented 3 years ago

Yes, if the initial mask indices could be obtained in raster order, and the argsort by -DAF were done as a stable sort, then I think the behavior should match the current implementation exactly, and not just in spirit.
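A quick check of this claim under stated assumptions: "raster order" is taken to be NumPy's default C order, and the raster scan is modeled as taking the first maximum among still-valid voxels (`np.argmax` returns the first occurrence). Kimimaro's actual scan order and tie-breaking may differ, so this only illustrates the principle. Targets are removed one at a time to mimic invalidation.

```python
import numpy as np

rng = np.random.default_rng(0)
labels = rng.random((6, 6, 6)) < 0.5
# Small integer-valued DAF so that there are many ties.
daf = rng.integers(0, 4, size=labels.shape).astype(np.float32)

def raster_scan_target(valid, daf):
    """Model of a raster scan: first maximum in C order among valid voxels."""
    return int(np.argmax(np.where(valid, daf, -np.inf)))

# Cached order: foreground indices (ascending = raster order), stable-sorted by -DAF.
mask_idx = np.flatnonzero(labels)
cached_order = mask_idx[np.argsort(-daf.ravel()[mask_idx], kind="stable")].tolist()

# Repeatedly take the raster-scan target and invalidate just that voxel.
valid = labels.copy()
scan_order = []
for _ in range(len(mask_idx)):
    target = raster_scan_target(valid, daf)
    scan_order.append(target)
    valid.flat[target] = False

print(scan_order == cached_order)  # True: same targets in the same order
```

With a non-stable sort (NumPy's default quicksort), tied voxels could come out in a different order, which is the discrepancy described earlier in the thread.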

I'm all for trying the extreme points, but I'm still concerned about edge cases. The tip-to-tip case isn't limited to voxels that actually touch; it also includes tips that come close enough to fall within each other's invalidation cube/ball, right? Or am I misremembering how invalidation works w.r.t. being able to jump across a small gap to a nearby tip?

william-silversmith commented 3 years ago

Oh I see what you are saying. I'm the one that forgot that aspect of invalidation. 😆 I think in that case you should submit a patch at your convenience.

I think the question then, as you implied in your first post, is how does the cached approach affect computation time for non-highly branched cells? If it slows things down, we could restrict it to glia, otherwise use it all the time.

william-silversmith commented 3 years ago

To quantify the problem, I ran Kimimaro on connectomics.py for ID 28336523 (presumed glia) with some reasonable parameters.

find_target is called 244 times and costs about 1.5 s per hit, or 72% of the total running time.

Total time: 510.19 s
File: /Users/wms/code/kimimaro/kimimaro/trace.py
Function: compute_paths at line 182

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   182                                           @profile
   183                                           def compute_paths(
   184                                               root, labels, DBF, DAF, 
   185                                               parents, scale, const, anisotropy, 
   186                                               soma_mode, soma_radius, fix_branching,
   187                                               manual_targets_before, manual_targets_after,
   188                                               max_paths, voxel_graph
   189                                             ):
   190                                             """
   191                                             Given the labels, DBF, DAF, dijkstra parents,
   192                                             and associated invalidation knobs, find the set of paths 
   193                                             that cover the object. Somas are given special treatment
   194                                             in that we attempt to cull vertices within a radius of the
   195                                             root vertex.
   196                                             """
   197        28         23.0      0.8      0.0    invalid_vertices = {}
   198        28         26.0      0.9      0.0    paths = []
   199        28       3992.0    142.6      0.0    valid_labels = np.count_nonzero(labels)
   200        28         37.0      1.3      0.0    root = tuple(root)
   201                                           
   202        28         20.0      0.7      0.0    if soma_mode:
   203                                               invalid_vertices[root] = True
   204                                           
   205        28         22.0      0.8      0.0    if max_paths is None:
   206        28         21.0      0.8      0.0      max_paths = valid_labels
   207                                           
   208        28         45.0      1.6      0.0    if len(manual_targets_before) + len(manual_targets_after) >= max_paths:
   209                                               return []
   210                                           
   211       834        860.0      1.0      0.0    while (valid_labels > 0 or manual_targets_before or manual_targets_after) \
   212       403        617.0      1.5      0.0      and len(paths) < max_paths:
   213                                           
   214       403        333.0      0.8      0.0      if manual_targets_before:
   215       159        267.0      1.7      0.0        target = manual_targets_before.pop()
   216       244        203.0      0.8      0.0      elif valid_labels == 0:
   217                                                 target = manual_targets_after.pop()
   218                                               else:
   219       244  368778687.0 1511388.1     72.3        target = kimimaro.skeletontricks.find_target(labels, DAF)
   220                                           
   221       403        577.0      1.4      0.0      if fix_branching:
   222                                                 # faster to trace from target to root than root to target
   223                                                 # because that way local exploration finds any zero weighted path
   224                                                 # and finishes vs exploring from the neighborhood of the entire zero
   225                                                 # weighted path
   226       806  114718918.0 142331.2     22.5        path = dijkstra3d.dijkstra(
   227       403        342.0      0.8      0.0          parents, target, root, 
   228       403        303.0      0.8      0.0          bidirectional=soma_mode, voxel_graph=voxel_graph
   229                                                 )
   230                                               else:
   231                                                 path = dijkstra3d.path_from_parents(parents, target)
   232                                               
   233       403        783.0      1.9      0.0      if soma_mode:
   234                                                 dist_to_soma_root = np.linalg.norm(anisotropy * (path - root), axis=1)
   235                                                 # remove all path points which are within soma_radius of root
   236                                                 path = np.concatenate(
   237                                                   (path[:1,:], path[dist_to_soma_root > soma_radius, :])
   238                                                 )
   239                                           
   240       403        574.0      1.4      0.0      if valid_labels > 0:
   241       806   24748296.0  30705.1      4.9        invalidated, labels = kimimaro.skeletontricks.roll_invalidation_cube(
   242       403        351.0      0.9      0.0          labels, DBF, path, scale, const, 
   243       403        353.0      0.9      0.0          anisotropy=anisotropy, invalid_vertices=invalid_vertices,
   244                                                 )
   245       403       1207.0      3.0      0.0        valid_labels -= invalidated
   246                                           
   247    380773     330046.0      0.9      0.1      for vertex in path:
   248    380370     609743.0      1.6      0.1        invalid_vertices[tuple(vertex)] = True
   249    380370     296231.0      0.8      0.1        if fix_branching:
   250    380370     695440.0      1.8      0.1          parents[tuple(vertex)] = 0.0
   251                                           
   252       403       1187.0      2.9      0.0      paths.append(path)
   253                                           
   254        28         23.0      0.8      0.0    return paths
william-silversmith commented 3 years ago

Here's the profile for all labels (including the glia):

Total time: 531.035 s
File: /Users/wms/code/kimimaro/kimimaro/trace.py
Function: compute_paths at line 182

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   182                                           @profile
   183                                           def compute_paths(
   184                                               root, labels, DBF, DAF, 
   185                                               parents, scale, const, anisotropy, 
   186                                               soma_mode, soma_radius, fix_branching,
   187                                               manual_targets_before, manual_targets_after,
   188                                               max_paths, voxel_graph
   189                                             ):
   190                                             """
   191                                             Given the labels, DBF, DAF, dijkstra parents,
   192                                             and associated invalidation knobs, find the set of paths 
   193                                             that cover the object. Somas are given special treatment
   194                                             in that we attempt to cull vertices within a radius of the
   195                                             root vertex.
   196                                             """
   197      2124       1732.0      0.8      0.0    invalid_vertices = {}
   198      2124       2082.0      1.0      0.0    paths = []
   199      2124     220700.0    103.9      0.0    valid_labels = np.count_nonzero(labels)
   200      2124       3211.0      1.5      0.0    root = tuple(root)
   201                                           
   202      2124       1643.0      0.8      0.0    if soma_mode:
   203         1          2.0      2.0      0.0      invalid_vertices[root] = True
   204                                           
   205      2124       1724.0      0.8      0.0    if max_paths is None:
   206      2124       1597.0      0.8      0.0      max_paths = valid_labels
   207                                           
   208      2124       4172.0      2.0      0.0    if len(manual_targets_before) + len(manual_targets_after) >= max_paths:
   209                                               return []
   210                                           
   211     11512      10904.0      0.9      0.0    while (valid_labels > 0 or manual_targets_before or manual_targets_after) \
   212      4694       5490.0      1.2      0.0      and len(paths) < max_paths:
   213                                           
   214      4694       3901.0      0.8      0.0      if manual_targets_before:
   215      2176       4720.0      2.2      0.0        target = manual_targets_before.pop()
   216      2518       2207.0      0.9      0.0      elif valid_labels == 0:
   217                                                 target = manual_targets_after.pop()
   218                                               else:
   219      2518  293030790.0 116374.4     55.2        target = kimimaro.skeletontricks.find_target(labels, DAF)
   220                                           
   221      4694       4624.0      1.0      0.0      if fix_branching:
   222                                                 # faster to trace from target to root than root to target
   223                                                 # because that way local exploration finds any zero weighted path
   224                                                 # and finishes vs exploring from the neighborhood of the entire zero
   225                                                 # weighted path
   226      9388  192275396.0  20481.0     36.2        path = dijkstra3d.dijkstra(
   227      4694       3881.0      0.8      0.0          parents, target, root, 
   228      4694       3522.0      0.8      0.0          bidirectional=soma_mode, voxel_graph=voxel_graph
   229                                                 )
   230                                               else:
   231                                                 path = dijkstra3d.path_from_parents(parents, target)
   232                                               
   233      4694       7379.0      1.6      0.0      if soma_mode:
   234        17       1594.0     93.8      0.0        dist_to_soma_root = np.linalg.norm(anisotropy * (path - root), axis=1)
   235                                                 # remove all path points which are within soma_radius of root
   236        34        175.0      5.1      0.0        path = np.concatenate(
   237        17        389.0     22.9      0.0          (path[:1,:], path[dist_to_soma_root > soma_radius, :])
   238                                                 )
   239                                           
   240      4694       5228.0      1.1      0.0      if valid_labels > 0:
   241      9092   38914142.0   4280.0      7.3        invalidated, labels = kimimaro.skeletontricks.roll_invalidation_cube(
   242      4546       4002.0      0.9      0.0          labels, DBF, path, scale, const, 
   243      4546       3826.0      0.8      0.0          anisotropy=anisotropy, invalid_vertices=invalid_vertices,
   244                                                 )
   245      4546       9298.0      2.0      0.0        valid_labels -= invalidated
   246                                           
   247   1299525    1138272.0      0.9      0.2      for vertex in path:
   248   1294831    2069974.0      1.6      0.4        invalid_vertices[tuple(vertex)] = True
   249   1294831     996810.0      0.8      0.2        if fix_branching:
   250   1294831    2288569.0      1.8      0.4          parents[tuple(vertex)] = 0.0
   251                                           
   252      4694      10724.0      2.3      0.0      paths.append(path)
   253                                           
   254      2124       1905.0      0.9      0.0    return paths
william-silversmith commented 3 years ago

With the glia excluded:

Total time: 133.272 s
File: /Users/wms/code/kimimaro/kimimaro/trace.py
Function: compute_paths at line 182

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   182                                           @profile
   183                                           def compute_paths(
   184                                               root, labels, DBF, DAF, 
   185                                               parents, scale, const, anisotropy, 
   186                                               soma_mode, soma_radius, fix_branching,
   187                                               manual_targets_before, manual_targets_after,
   188                                               max_paths, voxel_graph
   189                                             ):
   190                                             """
   191                                             Given the labels, DBF, DAF, dijkstra parents,
   192                                             and associated invalidation knobs, find the set of paths 
   193                                             that cover the object. Somas are given special treatment
   194                                             in that we attempt to cull vertices within a radius of the
   195                                             root vertex.
   196                                             """
   197      2096       1701.0      0.8      0.0    invalid_vertices = {}
   198      2096       2018.0      1.0      0.0    paths = []
   199      2096     221574.0    105.7      0.2    valid_labels = np.count_nonzero(labels)
   200      2096       3108.0      1.5      0.0    root = tuple(root)
   201                                           
   202      2096       1630.0      0.8      0.0    if soma_mode:
   203         1          2.0      2.0      0.0      invalid_vertices[root] = True
   204                                           
   205      2096       1713.0      0.8      0.0    if max_paths is None:
   206      2096       1581.0      0.8      0.0      max_paths = valid_labels
   207                                           
   208      2096       4219.0      2.0      0.0    if len(manual_targets_before) + len(manual_targets_after) >= max_paths:
   209                                               return []
   210                                           
   211     10678      10288.0      1.0      0.0    while (valid_labels > 0 or manual_targets_before or manual_targets_after) \
   212      4291       4803.0      1.1      0.0      and len(paths) < max_paths:
   213                                           
   214      4291       3493.0      0.8      0.0      if manual_targets_before:
   215      2017       4399.0      2.2      0.0        target = manual_targets_before.pop()
   216      2274       1936.0      0.9      0.0      elif valid_labels == 0:
   217                                                 target = manual_targets_after.pop()
   218                                               else:
   219      2274   39832241.0  17516.4     29.9        target = kimimaro.skeletontricks.find_target(labels, DAF)
   220                                           
   221      4291       3936.0      0.9      0.0      if fix_branching:
   222                                                 # faster to trace from target to root than root to target
   223                                                 # because that way local exploration finds any zero weighted path
   224                                                 # and finishes vs exploring from the neighborhood of the entire zero
   225                                                 # weighted path
   226      8582   74985493.0   8737.5     56.3        path = dijkstra3d.dijkstra(
   227      4291       3435.0      0.8      0.0          parents, target, root, 
   228      4291       3237.0      0.8      0.0          bidirectional=soma_mode, voxel_graph=voxel_graph
   229                                                 )
   230                                               else:
   231                                                 path = dijkstra3d.path_from_parents(parents, target)
   232                                               
   233      4291       6510.0      1.5      0.0      if soma_mode:
   234        17       1776.0    104.5      0.0        dist_to_soma_root = np.linalg.norm(anisotropy * (path - root), axis=1)
   235                                                 # remove all path points which are within soma_radius of root
   236        34        209.0      6.1      0.0        path = np.concatenate(
   237        17        439.0     25.8      0.0          (path[:1,:], path[dist_to_soma_root > soma_radius, :])
   238                                                 )
   239                                           
   240      4291       4721.0      1.1      0.0      if valid_labels > 0:
   241      8286   13666231.0   1649.3     10.3        invalidated, labels = kimimaro.skeletontricks.roll_invalidation_cube(
   242      4143       3641.0      0.9      0.0          labels, DBF, path, scale, const, 
   243      4143       3482.0      0.8      0.0          anisotropy=anisotropy, invalid_vertices=invalid_vertices,
   244                                                 )
   245      4143       7617.0      1.8      0.0        valid_labels -= invalidated
   246                                           
   247    918749     800209.0      0.9      0.6      for vertex in path:
   248    914458    1421725.0      1.6      1.1        invalid_vertices[tuple(vertex)] = True
   249    914458     705261.0      0.8      0.5        if fix_branching:
   250    914458    1548042.0      1.7      1.2          parents[tuple(vertex)] = 0.0
   251                                           
   252      4291       9295.0      2.2      0.0      paths.append(path)
   253                                           
   254      2096       1917.0      0.9      0.0    return paths
william-silversmith commented 3 years ago

This doesn't help much with glia, but I just realized there is one situation in which the extreme points are always valid: at the very start, before any invalidation occurs. Most shapes in a given task require only one or a few targets. I suspect the max point can be acquired cheaply from dijkstra3d.euclidean_distance_field. If so, we can eliminate most calls to find_target and save substantially overall.

These savings would apply to both find_root and find_target whenever no root or no prior target exists. No prior root would exist if fix_borders=False.

I think if we combine our approaches, we can crush down the running time of Kimimaro substantially.
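A toy sketch of picking up the first target during the distance field pass itself: a pure-Python Dijkstra over 26-connected foreground voxels with isotropic Euclidean step lengths, recording the farthest voxel as each voxel is finalized. `distance_field_with_max` is a hypothetical helper and not dijkstra3d's implementation; it only illustrates why tracking the maximum adds O(1) work per voxel.

```python
import heapq
import math
import numpy as np

# 26-connected neighborhood offsets (every neighbor except the voxel itself).
OFFSETS = [(dx, dy, dz)
           for dx in (-1, 0, 1) for dy in (-1, 0, 1) for dz in (-1, 0, 1)
           if (dx, dy, dz) != (0, 0, 0)]

def distance_field_with_max(labels, source):
    """Hypothetical helper: Dijkstra distance field over foreground voxels
    plus the location of its maximum (the first target)."""
    field = np.full(labels.shape, np.inf)
    field[source] = 0.0
    heap = [(0.0, source)]
    max_dist, max_loc = 0.0, source
    while heap:
        dist, (x, y, z) = heapq.heappop(heap)
        if dist > field[x, y, z]:
            continue  # stale entry; this voxel was already finalized
        # Each voxel is finalized once, so tracking the maximum costs O(1) here.
        if dist > max_dist:
            max_dist, max_loc = dist, (x, y, z)
        for dx, dy, dz in OFFSETS:
            nx, ny, nz = x + dx, y + dy, z + dz
            if not (0 <= nx < labels.shape[0] and 0 <= ny < labels.shape[1]
                    and 0 <= nz < labels.shape[2]):
                continue
            if not labels[nx, ny, nz]:
                continue  # only traverse foreground voxels
            nd = dist + math.sqrt(dx * dx + dy * dy + dz * dz)
            if nd < field[nx, ny, nz]:
                field[nx, ny, nz] = nd
                heapq.heappush(heap, (nd, (nx, ny, nz)))
    return field, max_loc
```

Because no invalidation has happened yet, `max_loc` is exactly what a first find_target call would have searched the whole volume for.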

william-silversmith commented 3 years ago

I tried it out, but the results weren't anywhere near as good as I'd hoped:

   217      4291       3491.0      0.8      0.0      if manual_targets_before:
   218      2959       5866.0      2.0      0.0        target = manual_targets_before.pop()
   219      1332       1248.0      0.9      0.0      elif valid_labels == 0:
   220                                                 target = manual_targets_after.pop()
   221                                               else:
   222      1332   38968320.0  29255.5     29.0        target = kimimaro.skeletontricks.find_target(labels, DAF)

There's a sharp reduction in the number of calls to find_target (from 2274 to 1332), but the total find_target time barely moves (about 39.8 s to 39.0 s), so it seems to be dominated by the more complex cases.

chinasaur commented 3 years ago

I'm not sure I totally followed how you were using the EDT for this?

If you want to try the cached target finding, please feel free to build on the Gist; that's very close to what I've been using. (Just add the freeing of the DAF.) I found that the cases where the existing find_target bogged down were mostly those where the segment mask's bounding volume is over 1e6 voxels. Below that, the raster scan is faster than caching, or at least comparable. For those smaller segments, the memory benefit of freeing the DAF is also smaller. You can probably pick a threshold around 1e6, but your optimal point may be a bit different depending on your memory-use target.
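A sketch of the size-based switch described here, assuming the decision is made from the foreground's axis-aligned bounding box. `bounding_volume`, `use_cached_finder`, and the default threshold of 1e6 voxels are illustrative names and values taken from this comment, not Kimimaro parameters.

```python
import math
import numpy as np

def bounding_volume(labels):
    """Number of voxels in the axis-aligned bounding box of the foreground."""
    extent = []
    for axis in range(labels.ndim):
        others = tuple(a for a in range(labels.ndim) if a != axis)
        # Positions along this axis where any foreground voxel exists.
        present = np.flatnonzero(np.any(labels, axis=others))
        if present.size == 0:
            return 0  # empty mask
        extent.append(int(present[-1] - present[0] + 1))
    return math.prod(extent)

def use_cached_finder(labels, threshold=1_000_000):
    """Cache sorted targets only for large bounding volumes; below the
    threshold a plain raster scan was reported to be as fast or faster."""
    return bounding_volume(labels) > threshold
```

Projecting onto each axis with `np.any` avoids materializing every foreground coordinate, which matters for the large blocks this thread is about.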

If you prefer a pull request, I can send one, but it will take a little reworking since my use of the package is currently wrapped up in our Apache Beam skeletonization pipeline implementation.

william-silversmith commented 3 years ago

> I'm not sure I totally followed how you were using the EDT for this?

The DAF comes from the dijkstra3d.euclidean_distance_field calculation. I locally edited my copy of dijkstra3d to compute (one of) the maximum locations and verified that it didn't affect the running time significantly. That should return the first target.

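import dijkstra3d
import kimimaro

# Locally patched dijkstra3d: return the DAF together with (one of) its
# maximum locations, which serves as the first target before any invalidation.
DAF, target = dijkstra3d.euclidean_distance_field(
    labels, root,
    anisotropy=anisotropy,
    free_space_radius=free_space_radius,
    voxel_graph=voxel_graph,
    return_max_location=True,
)
# versus the existing approach: a separate raster scan over labels and DAF
target = kimimaro.skeletontricks.find_target(labels, DAF)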

I experimented with doing it for both find_root and for avoiding calling find_target the first time. The former was worth very little and the latter worth about 0.5% - 1% of the running time. With fix_borders=False, the former was worth a few percent, so it might be worth integrating. I'll take a few percent performance improvement if it doesn't distort the code too much. Kimimaro is all about the gainz.

> If you want to try the cached target finding, please feel free to build on the Gist; that's very close to what I've been using.

Okay! I'll give that a shot. Thanks for the code and discussion, this is a very helpful improvement. I'll let you know how it goes.

william-silversmith commented 3 years ago

This is very cool. Here is the effect of your cached approach on the glia (ID 28336523):

510 sec / 145 sec = 3.5x speedup.
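The headline figures can be re-derived from the two glia profiles. This assumes the profiler's timer unit is microseconds (line_profiler's usual default, and consistent with the 1.5 s per hit quoted earlier).

```python
# Totals and find_target costs from the before/after profiles of ID 28336523.
total_before_s, total_after_s = 510.19, 144.511
find_target_before_us, find_target_after_us = 368_778_687.0, 115_673.0
calls = 244  # find_target hits in both runs

speedup = total_before_s / total_after_s
per_call_before_ms = find_target_before_us / calls / 1e3
per_call_after_ms = find_target_after_us / calls / 1e3

print(round(speedup, 1))                                       # 3.5
print(round(per_call_before_ms), round(per_call_after_ms, 2))  # 1511 0.47
```

So the roughly 3.5x overall gain comes almost entirely from find_target dropping from about 1.5 s to under 0.5 ms per call; dijkstra3d.dijkstra now dominates the profile at about 81% of the time.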

Total time: 144.511 s
File: /Users/wms/code/kimimaro/kimimaro/trace.py
Function: compute_paths at line 186

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   186                                           @profile
   187                                           def compute_paths(
   188                                               root, labels, DBF, DAF, 
   189                                               parents, scale, const, anisotropy, 
   190                                               soma_mode, soma_radius, fix_branching,
   191                                               manual_targets_before, manual_targets_after,
   192                                               max_paths, voxel_graph
   193                                             ):
   194                                             """
   195                                             Given the labels, DBF, DAF, dijkstra parents,
   196                                             and associated invalidation knobs, find the set of paths 
   197                                             that cover the object. Somas are given special treatment
   198                                             in that we attempt to cull vertices within a radius of the
   199                                             root vertex.
   200                                             """
   201        28         54.0      1.9      0.0    invalid_vertices = {}
   202        28         30.0      1.1      0.0    paths = []
   203        28       4895.0    174.8      0.0    valid_labels = np.count_nonzero(labels)
   204        28         48.0      1.7      0.0    root = tuple(root)
   205                                           
   206        28         27.0      1.0      0.0    if soma_mode:
   207                                               invalid_vertices[root] = True
   208                                           
   209        28         27.0      1.0      0.0    if max_paths is None:
   210        28         27.0      1.0      0.0      max_paths = valid_labels
   211                                           
   212        28         68.0      2.4      0.0    if len(manual_targets_before) + len(manual_targets_after) >= max_paths:
   213                                               return []
   214                                           
   215        28         28.0      1.0      0.0    target_finder = None
   216        28         39.0      1.4      0.0    def find_target():
   217                                               nonlocal target_finder
   218                                               if target_finder is None:
   219                                                 target_finder = kimimaro.skeletontricks.CachedTargetFinder(labels, DAF[0])
   220                                                 DAF.pop(0)
   221                                               return target_finder.find_target(labels)
   222                                           
   223       834        975.0      1.2      0.0    while (valid_labels > 0 or manual_targets_before or manual_targets_after) \
   224       403        815.0      2.0      0.0      and len(paths) < max_paths:
   225                                           
   226       403        414.0      1.0      0.0      if manual_targets_before:
   227       159        382.0      2.4      0.0        target = manual_targets_before.pop()
   228       244        242.0      1.0      0.0      elif valid_labels == 0:
   229                                                 target = manual_targets_after.pop()
   230                                               else:
   231       244     115673.0    474.1      0.1        target = find_target()
   232                                           
   233       403        440.0      1.1      0.0      if fix_branching:
   234                                                 # faster to trace from target to root than root to target
   235                                                 # because that way local exploration finds any zero weighted path
   236                                                 # and finishes vs exploring from the neighborhood of the entire zero
   237                                                 # weighted path
   238       806  116915677.0 145056.7     80.9        path = dijkstra3d.dijkstra(
   239       403        409.0      1.0      0.0          parents, target, root, 
   240       403        379.0      0.9      0.0          bidirectional=soma_mode, voxel_graph=voxel_graph
   241                                                 )
   242                                               else:
   243                                                 path = dijkstra3d.path_from_parents(parents, target)
   244                                               
   245       403       1027.0      2.5      0.0      if soma_mode:
   246                                                 dist_to_soma_root = np.linalg.norm(anisotropy * (path - root), axis=1)
   247                                                 # remove all path points which are within soma_radius of root
   248                                                 path = np.concatenate(
   249                                                   (path[:1,:], path[dist_to_soma_root > soma_radius, :])
   250                                                 )
   251                                           
   252       403        628.0      1.6      0.0      if valid_labels > 0:
   253       806   25242025.0  31317.6     17.5        invalidated, labels = kimimaro.skeletontricks.roll_invalidation_cube(
   254       403        533.0      1.3      0.0          labels, DBF, path, scale, const, 
   255       403        436.0      1.1      0.0          anisotropy=anisotropy, invalid_vertices=invalid_vertices,
   256                                                 )
   257       403       1592.0      4.0      0.0        valid_labels -= invalidated
   258                                           
   259    380773     405630.0      1.1      0.3      for vertex in path:
   260    380370     690428.0      1.8      0.5        invalid_vertices[tuple(vertex)] = True
   261    380370     367438.0      1.0      0.3        if fix_branching:
   262    380370     759324.0      2.0      0.5          parents[tuple(vertex)] = 0.0
   263                                           
   264       403       1613.0      4.0      0.0      paths.append(path)
   265                                           
   266        28         29.0      1.0      0.0    return paths
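The `find_target` closure in the profile above builds `kimimaro.skeletontricks.CachedTargetFinder` lazily on its first call. Following chinasaur's description earlier in the thread (sort the foreground locations by descending DAF once, then advance a cursor past invalidated voxels on each call), a pure-NumPy sketch of that idea might look like this. The class name `CachedTargetFinderSketch` is hypothetical, and this is not kimimaro's Cython implementation. The sketch assumes that invalidation only ever zeroes voxels in `labels` and never restores them.

```python
import numpy as np


class CachedTargetFinderSketch:
    """Hypothetical stand-in for a cached find_target (not kimimaro's Cython class)."""

    def __init__(self, labels, daf):
        # Flat (linearized) indices of every foreground voxel.
        candidates = np.flatnonzero(labels)
        # Ascending sort on DAF, reversed, gives descending DAF order.
        order = np.argsort(daf.ravel()[candidates], kind="stable")[::-1]
        self.sorted_indices = candidates[order]
        self.shape = labels.shape
        self.cursor = 0  # every cached voxel before the cursor is known to be invalid

    def find_target(self, labels):
        """Return the highest-DAF voxel that is still nonzero in labels, or None."""
        flat_labels = labels.ravel()
        n = len(self.sorted_indices)
        # Advance past voxels invalidated since the previous call. The cursor never
        # moves backward, so across all calls each cached voxel is skipped at most once.
        while self.cursor < n and flat_labels[self.sorted_indices[self.cursor]] == 0:
            self.cursor += 1
        if self.cursor == n:
            return None
        # The target is not consumed here; the caller is expected to invalidate it.
        idx = self.sorted_indices[self.cursor]
        return tuple(int(c) for c in np.unravel_index(idx, self.shape))


labels = np.ones((1, 1, 4), dtype=np.uint8)
daf = np.array([[[0.0, 3.0, 1.0, 2.0]]])
finder = CachedTargetFinderSketch(labels, daf)
print(finder.find_target(labels))  # (0, 0, 1): highest DAF
labels[0, 0, 1] = 0                # simulate invalidation by a traced path
print(finder.find_target(labels))  # (0, 0, 3)
```

The one-time sort costs O(F log F) for F foreground voxels. After that, each call only inspects entries that were invalidated since the previous call, instead of rescanning the whole bounding volume.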
william-silversmith commented 3 years ago

This run has the glia excluded. Looks like there's an overall speedup!

133 / 97 = 1.37x

Total time: 97.0796 s
File: /Users/wms/code/kimimaro/kimimaro/trace.py
Function: compute_paths at line 186

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   186                                           @profile
   187                                           def compute_paths(
   188                                               root, labels, DBF, DAF, 
   189                                               parents, scale, const, anisotropy, 
   190                                               soma_mode, soma_radius, fix_branching,
   191                                               manual_targets_before, manual_targets_after,
   192                                               max_paths, voxel_graph
   193                                             ):
   194                                             """
   195                                             Given the labels, DBF, DAF, dijkstra parents,
   196                                             and associated invalidation knobs, find the set of paths 
   197                                             that cover the object. Somas are given special treatment
   198                                             in that we attempt to cull vertices within a radius of the
   199                                             root vertex.
   200                                             """
   201      2096       3840.0      1.8      0.0    invalid_vertices = {}
   202      2096       2394.0      1.1      0.0    paths = []
   203      2096     225886.0    107.8      0.2    valid_labels = np.count_nonzero(labels)
   204      2096       3683.0      1.8      0.0    root = tuple(root)
   205                                           
   206      2096       2069.0      1.0      0.0    if soma_mode:
   207         1          3.0      3.0      0.0      invalid_vertices[root] = True
   208                                           
   209      2096       2093.0      1.0      0.0    if max_paths is None:
   210      2096       1977.0      0.9      0.0      max_paths = valid_labels
   211                                           
   212      2096       4788.0      2.3      0.0    if len(manual_targets_before) + len(manual_targets_after) >= max_paths:
   213                                               return []
   214                                           
   215      2096       2232.0      1.1      0.0    target_finder = None
   216      2096       3077.0      1.5      0.0    def find_target():
   217                                               nonlocal target_finder
   218                                               if target_finder is None:
   219                                                 target_finder = kimimaro.skeletontricks.CachedTargetFinder(labels, DAF[0])
   220                                                 DAF.pop(0)
   221                                               return target_finder.find_target(labels)
   222                                           
   223     10678      12420.0      1.2      0.0    while (valid_labels > 0 or manual_targets_before or manual_targets_after) \
   224      4291       5754.0      1.3      0.0      and len(paths) < max_paths:
   225                                           
   226      4291       4381.0      1.0      0.0      if manual_targets_before:
   227      2017       4987.0      2.5      0.0        target = manual_targets_before.pop()
   228      2274       2325.0      1.0      0.0      elif valid_labels == 0:
   229                                                 target = manual_targets_after.pop()
   230                                               else:
   231      2274    1919962.0    844.3      2.0        target = find_target()
   232                                           
   233      4291       4862.0      1.1      0.0      if fix_branching:
   234                                                 # faster to trace from target to root than root to target
   235                                                 # because that way local exploration finds any zero weighted path
   236                                                 # and finishes vs exploring from the neighborhood of the entire zero
   237                                                 # weighted path
   238      8582   75624651.0   8812.0     77.9        path = dijkstra3d.dijkstra(
   239      4291       4277.0      1.0      0.0          parents, target, root, 
   240      4291       4058.0      0.9      0.0          bidirectional=soma_mode, voxel_graph=voxel_graph
   241                                                 )
   242                                               else:
   243                                                 path = dijkstra3d.path_from_parents(parents, target)
   244                                               
   245      4291       8446.0      2.0      0.0      if soma_mode:
   246        17       1697.0     99.8      0.0        dist_to_soma_root = np.linalg.norm(anisotropy * (path - root), axis=1)
   247                                                 # remove all path points which are within soma_radius of root
   248        34        222.0      6.5      0.0        path = np.concatenate(
   249        17        436.0     25.6      0.0          (path[:1,:], path[dist_to_soma_root > soma_radius, :])
   250                                                 )
   251                                           
   252      4291       6056.0      1.4      0.0      if valid_labels > 0:
   253      8286   13885204.0   1675.7     14.3        invalidated, labels = kimimaro.skeletontricks.roll_invalidation_cube(
   254      4143       4832.0      1.2      0.0          labels, DBF, path, scale, const, 
   255      4143       4498.0      1.1      0.0          anisotropy=anisotropy, invalid_vertices=invalid_vertices,
   256                                                 )
   257      4143       9538.0      2.3      0.0        valid_labels -= invalidated
   258                                           
   259    918747     998814.0      1.1      1.0      for vertex in path:
   260    914456    1642724.0      1.8      1.7        invalid_vertices[tuple(vertex)] = True
   261    914456     897521.0      1.0      0.9        if fix_branching:
   262    914456    1766279.0      1.9      1.8          parents[tuple(vertex)] = 0.0
   263                                           
   264      4291      11300.0      2.6      0.0      paths.append(path)
   265                                           
   266      2096       2312.0      1.1      0.0    return paths
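The ratios quoted in this thread can be rechecked against the second profile. The following is a small sketch that assumes the profiler's timer unit is 1 µs. The `Timer unit` line is not shown here, but that unit is consistent with the `% Time` column: 75,624,651 units is 77.9% of 97.08 s.

```python
# Figures copied from the second line_profiler output above.
# Assumption: the profiler's timer unit is 1 microsecond.
TIMER_UNIT_S = 1e-6

total_s = 97.0796              # "Total time" for compute_paths
find_target_units = 1919962.0  # line 231, Time column
find_target_hits = 2274        # line 231, Hits column
dijkstra_units = 75624651.0    # line 238, Time column

per_hit_us = find_target_units / find_target_hits
find_target_share = find_target_units * TIMER_UNIT_S / total_s
dijkstra_share = dijkstra_units * TIMER_UNIT_S / total_s
speedup = 133 / 97  # whole-run totals quoted in the comment

print(f"find_target: {per_hit_us:.1f} us per hit, {find_target_share:.1%} of total")
# find_target: 844.3 us per hit, 2.0% of total
print(f"dijkstra: {dijkstra_share:.1%} of total; overall speedup {speedup:.2f}x")
# dijkstra: 77.9% of total; overall speedup 1.37x
```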
william-silversmith commented 3 years ago

Plot legend: black is this change, blue is the current master branch.

I probably have to do some more work to actually free the DAF, as you suggested.
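The `DAF.pop(0)` call in the `find_target` closure relies on DAF being passed wrapped in a one-element list. After the finder has built its cache, popping the array drops what may be its last reference, so CPython can free it immediately. This only works if the finder itself keeps no reference to the array. The following is a minimal sketch of the pattern. It uses a hypothetical `build_cache` function and a `weakref` that only observes when the array is released. It assumes the caller keeps no other reference to the array. Other Python implementations may delay the release until garbage collection runs.

```python
import weakref

import numpy as np


def build_cache(holder):
    """Hypothetical consumer: derive a compact cache, then release the big array."""
    daf = holder[0]
    # Keep only what the cache needs (here, flat indices in descending DAF order).
    order = np.argsort(daf.ravel())[::-1]
    holder.pop(0)  # remove the list's reference...
    del daf        # ...and the local one, so nothing keeps the array alive
    return order


daf = np.random.default_rng(0).random((64, 64, 64), dtype=np.float32)
probe = weakref.ref(daf)  # observes the array without keeping it alive
holder = [daf]
del daf                   # the one-element list is now the only owner

order = build_cache(holder)
print(probe() is None)    # True in CPython: the array was freed inside build_cache
```

Passing the list rather than the bare array matters because a plain argument would keep the caller's reference alive for the whole duration of the call.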

chinasaur commented 3 years ago

Looks nice; interesting that it's faster overall without switching strategies based on the bounding volume size. Is this a 512^3 test volume? I suspect there's still some advantage to not caching for objects with small bounding volumes. At least in my tests, a single raster-scan find_target was much faster than the caching computations for those small objects. I didn't actually benchmark the full skeletonization for them, though.
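The uncached alternative that chinasaur compares against rescans the volume on every call. The following sketch shows the idea using a hypothetical `raster_find_target`, written in NumPy rather than as kimimaro's own code. For k calls on a bounding volume of V voxels, this approach costs roughly k·V. A cache instead pays an O(F log F) sort of the F foreground voxels up front, plus O(F) total cursor movement. Scanning therefore wins when an object needs only a few paths, and caching wins when it needs many.

```python
import numpy as np


def raster_find_target(labels, daf):
    """Hypothetical uncached finder: one full pass over the volume per call."""
    # Hide background and invalidated voxels behind -inf, then take the maximum.
    masked = np.where(labels != 0, daf, -np.inf)
    idx = int(np.argmax(masked))
    if masked.flat[idx] == -np.inf:
        return None  # nothing left to cover
    return tuple(int(c) for c in np.unravel_index(idx, labels.shape))


labels = np.ones((2, 2, 2), dtype=np.uint8)
daf = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
print(raster_find_target(labels, daf))  # (1, 1, 1)
```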

william-silversmith commented 3 years ago

Yeah, it's the connectomics.npy volume in the repository. It's been my standard test volume for years because it has lots of different kinds of objects in it, including a glial cell and a partial soma.

For very small volumes, the extreme points optimization avoids computing the cache if only one target is needed, so some of that work is already done. I've not seen a ton of improvement from deleting DAF (I'll have to look more carefully), so if we just keep it around, we could pick the regular or cached find_target based on the number of initial foreground voxels. However, I am very encouraged that the cached find_target averages only 844.3 µs per hit. It's gonna be hard to beat that. In the second profile, find_target took only 1.9 seconds of the 97.1 second total (about 2%).
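The idea of keeping DAF around and choosing between the regular and cached `find_target` based on the initial foreground voxel count could be sketched as follows. Everything here is hypothetical:

- The names `make_find_target`, `_raster_scan`, and `_CachedScan` are illustrative only.
- The `FOREGROUND_THRESHOLD` value of 1e6 borrows the number chinasaur mentioned earlier, although he applied it to bounding-volume size rather than foreground count.

The two strategies are simplified NumPy stand-ins, not kimimaro's implementation. The sketch also assumes the caller invalidates voxels by zeroing `labels` in place.

```python
import numpy as np

FOREGROUND_THRESHOLD = 1_000_000  # hypothetical cutoff; tune by benchmarking


def _coord(idx, shape):
    return tuple(int(c) for c in np.unravel_index(idx, shape))


def _raster_scan(labels, daf):
    # Uncached: rescan the whole volume on every call.
    masked = np.where(labels != 0, daf, -np.inf)
    idx = int(np.argmax(masked))
    return None if masked.flat[idx] == -np.inf else _coord(idx, labels.shape)


class _CachedScan:
    # Cached: sort foreground voxels by descending DAF once, then move a cursor.
    def __init__(self, labels, daf):
        fg = np.flatnonzero(labels)
        self.order = fg[np.argsort(daf.ravel()[fg], kind="stable")[::-1]]
        self.cursor = 0

    def __call__(self, labels, daf):
        flat = labels.ravel()
        while self.cursor < len(self.order) and flat[self.order[self.cursor]] == 0:
            self.cursor += 1
        if self.cursor == len(self.order):
            return None
        return _coord(self.order[self.cursor], labels.shape)


def make_find_target(labels, daf, threshold=FOREGROUND_THRESHOLD):
    """Choose a strategy once, from the initial foreground voxel count.

    Assumes the caller invalidates voxels by zeroing `labels` in place.
    """
    if np.count_nonzero(labels) > threshold:
        finder = _CachedScan(labels, daf)
    else:
        finder = _raster_scan
    return lambda: finder(labels, daf)
```

Because the choice is made once per object, small objects never pay for the sort, while large ones get the amortized cursor. The actual cutoff would have to come from benchmarks like the profiles above.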

This might be good enough.