brainglobe / cellfinder-core

Standalone cellfinder cell detection algorithm
https://brainglobe.info/documentation/cellfinder/index.html
BSD 3-Clause "New" or "Revised" License

Convert `multiprocessing` steps to `dask` #118

Closed dstansby closed 1 year ago

dstansby commented 1 year ago

The goal here is to remove custom multiprocessing code and replace it with a series of dask delayed calls that are built into a task graph and then executed using dask. See the comment below for a mock-up of how this would work.

The current blocker is that everything consumed by a dask delayed object must be picklable, but numba jitclasses are not. This prevents us from using dask with the structure detection code at the moment.
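A quick way to see which objects would hit this limit is to try a pickle round trip on them. The helper below is a hypothetical sketch, not part of cellfinder-core. A `threading.Lock` stands in for a numba jitclass instance so the example runs without numba. As far as I know the pickling requirement applies to dask's process-based and distributed schedulers (which serialise task arguments with pickle or cloudpickle); the default threaded scheduler for `dask.delayed` runs tasks in the same process.

```python
import pickle
import threading


def is_picklable(obj):
    """Return True if `obj` survives a pickle round trip.

    Hypothetical helper: anything that fails here cannot be shipped to
    process-based or distributed dask workers.
    """
    try:
        pickle.loads(pickle.dumps(obj))
    except Exception:
        return False
    return True


# Plain Python data structures pickle fine...
print(is_picklable({"planes": [1, 2, 3]}))  # True
# ...but objects wrapping native resources do not. The lock is a
# stand-in for an unpicklable object such as a numba jitclass instance.
print(is_picklable(threading.Lock()))  # False
```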

dstansby commented 1 year ago

Here's an example task graph for 10 planes and a ball filter depth of 3:

[Figure: dask_graph, the task graph rendered by the code below]

import dask.array as da
from dask import delayed

# Number of planes
nz = 10
# Number of planes processed by ball filter at one time
ball_filter_depth = 3

image = da.ones((nz, 200, 300))

# cellfinder_core.detect.filters.plane.get_tile_mask()
@delayed
def get_tile_mask(plane):
    mask = plane.copy()
    return plane, mask

# cellfinder_core.detect.filters.volume.BallFilter.walk()
@delayed
def ball_filter_walk(planes):
    # The ball filter runs and returns the middle plane
    return planes[1]

# cellfinder_core.detect.filters.volume.
@delayed
def detect_structures(plane, previous_plane, structures):
    return plane

tile_masks = []
for plane in image:
    tile_masks.append(get_tile_mask(plane))

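ball_filtered = []
# A window of ball_filter_depth planes fits nz - ball_filter_depth + 1 times
for i in range(nz - ball_filter_depth + 1):
    ball_filtered.append(ball_filter_walk(tile_masks[i:i+ball_filter_depth]))

previous_plane = None
structures = []
# Each detection step depends on the previous one, so this is a serial chain
for i in range(len(ball_filtered)):
    previous_plane = detect_structures(ball_filtered[i], previous_plane, structures)

# previous_plane is already a Delayed object; render its task graph
# (rendering requires graphviz)
previous_plane.visualize(filename='dask_graph.png', collapse_outputs=True)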
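The dependency structure of this mock-up can also be checked without dask or graphviz. The sketch below is a pure-Python illustration using the standard library's `graphlib` (Python 3.9+) instead of dask; the task names (`tile-*`, `ball-*`, `detect-*`) are hypothetical labels. It assumes one tile-mask task per plane, one ball-filter task per window of `ball_filter_depth` consecutive planes (`nz - ball_filter_depth + 1` windows), and one structure-detection task per filtered plane, each depending on the previous detection.

```python
from graphlib import TopologicalSorter  # Python 3.9+


def build_dependencies(nz, ball_filter_depth):
    """Map each (hypothetical) task name to the set of tasks it needs."""
    deps = {f"tile-{z}": set() for z in range(nz)}

    n_windows = nz - ball_filter_depth + 1
    for i in range(n_windows):
        # Each ball-filter task reads ball_filter_depth consecutive tile masks
        deps[f"ball-{i}"] = {f"tile-{z}" for z in range(i, i + ball_filter_depth)}

    for i in range(n_windows):
        inputs = {f"ball-{i}"}
        if i > 0:
            # Structure detection needs the previous plane's result,
            # so these tasks form a strictly sequential chain.
            inputs.add(f"detect-{i - 1}")
        deps[f"detect-{i}"] = inputs
    return deps


deps = build_dependencies(nz=10, ball_filter_depth=3)
order = list(TopologicalSorter(deps).static_order())
print(sum(name.startswith("ball-") for name in order))  # 8
```

The tile-mask and ball-filter tasks have no dependencies on each other within a step, so a scheduler can run them in parallel; the `detect-*` chain cannot be parallelised in this layout.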
dstansby commented 1 year ago

I've managed to get the first step (tile filtering) working with dask. This required turning TileProcessor into a function, presumably because of issues with passing the same class instance between different dask workers. From a code readability point of view this isn't a problem, as the class only had one method anyway.
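A rough sketch of the function-based shape is below. The name `get_tile_mask` and the simple thresholding logic are hypothetical stand-ins, not the real cellfinder-core code. The point is that a module-level function is pickled by reference, and `functools.partial` can bind configuration once while staying picklable as long as its arguments are.

```python
import pickle
from functools import partial


def get_tile_mask(plane, threshold):
    """Hypothetical stand-in for the tile filtering logic."""
    return [value > threshold for value in plane]


# Bind the configuration once; the partial holds a reference to the
# module-level function plus its (picklable) keyword arguments.
task = partial(get_tile_mask, threshold=0.5)
restored = pickle.loads(pickle.dumps(task))
print(restored([0.1, 0.7, 0.5, 0.9]))  # [False, True, False, True]
```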

dstansby commented 1 year ago

Sigh, the current blocker is now in the final step, structure detection. This is done using the CellDetector class, which is JIT-compiled by numba. It's not currently possible to pickle numba jitclasses or numba typed dictionaries, which prevents both the class and the typed dictionaries from being used with dask.
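For plain Python classes, Python's `__getstate__`/`__setstate__` hooks let an object choose what gets serialised. The sketch below is a hypothetical `StructureStore`, not cellfinder-core code, and it does not make a numba jitclass picklable. It assumes structure state could be held in (or converted to) plain Python data at task boundaries, with a `threading.Lock` standing in for the unpicklable member that is dropped on pickling and rebuilt on unpickling. Any such conversion from numba typed containers would add overhead.

```python
import pickle
import threading


class StructureStore:
    """Hypothetical plain-Python container for detected structures."""

    def __init__(self):
        self.structures = {}
        # Stand-in for an unpicklable member (e.g. a compiled object)
        self._lock = threading.Lock()

    def add_point(self, structure_id, point):
        with self._lock:
            self.structures.setdefault(structure_id, []).append(point)

    def __getstate__(self):
        # Serialise only plain Python data.
        return {"structures": self.structures}

    def __setstate__(self, state):
        # Restore the data and rebuild the unpicklable member.
        self.structures = state["structures"]
        self._lock = threading.Lock()


store = StructureStore()
store.add_point(1, (0, 4, 5))
store.add_point(1, (1, 4, 6))
restored = pickle.loads(pickle.dumps(store))
print(restored.structures)  # {1: [(0, 4, 5), (1, 4, 6)]}
```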