AllenInstitute / aics-segmentation

AICS Segmentation (One-Way) Mirror

Running aicssegmentation code with a dask array and compute? #5

Open cudmore opened 4 years ago

cudmore commented 4 years ago

I am trying to run a few aics-segmentation functions on a dask array so I can process a number of stacks in parallel.

For example, with aicssegmentation.core.vessel.filament_3d_wrapper:

1) If I run it on a dask array of length 1, it finishes one stack in ~20 seconds with minimal CPU usage. This is about the same as running it without a wrapping dask array, which is good.
2) If I run it on a dask array of length 4, each stack takes ~600 seconds and CPU usage looks the same as in the single-stack case. The 4 stacks do run in parallel, but CPU usage does not increase and each one is ~30 times slower than a single stack.

[Update] I ran it again with np.float as the dtype, and each call to filament_3d_wrapper took ~1240 seconds when run across 4 stacks. Yikes!

I started looking at the source and, after some tracing, found no obvious reason; all I see is normal Python/NumPy/SciPy code. I seem to remember that aics-segmentation has a set of batch functions. Should I use those instead? Are there links to example code?

Here is some sample code. Notably, scipy.ndimage.median_filter behaves as expected (the 4 stacks run in parallel and max out the CPU), but filament_3d_wrapper runs >30x slower and does not max out the CPU (usage looks like a single stack).

import time

import numpy as np
import scipy.ndimage  # import the submodule explicitly; plain `import scipy` may not load ndimage

import dask
import dask.array as da

from aicssegmentation.core.vessel import filament_3d_wrapper

def myRun(path, commonShape, common_dtype):
    """Median-filter then filament-filter one fake stack; returns a uint8 mask."""

    # create fake float64 data (stands in for loading the file at `path`)
    stackData = np.random.normal(loc=100, scale=10, size=commonShape)
    #stackData = stackData.astype(common_dtype)

    # takes about 9 seconds with 1x stack in the dask array,
    # and still about 9 seconds with 4x stacks
    medianKernelSize = (3, 4, 4)
    print('  median filter', path)
    startTime = time.time()
    smoothData = scipy.ndimage.median_filter(stackData, size=medianKernelSize)
    stopTime = time.time()
    print('    median filter done in', round(stopTime - startTime, 2), 'seconds', path)

    # takes about 19 seconds with 1x stack in the dask array,
    # but 500+ seconds with 4x stacks
    print('  filament_3d_wrapper', path)
    startTime = time.time()
    f3_param = [[1, 0.01]]  # one [sigma, cutoff] pair per filter scale
    filamentData = filament_3d_wrapper(smoothData, f3_param)
    filamentData = filamentData.astype(np.uint8)
    stopTime = time.time()
    print('    filament_3d_wrapper done in', round(stopTime - startTime, 2), 'seconds', path)

    # return the result so da.from_delayed() has an array to wrap
    return filamentData

if __name__ == '__main__':

    # with 1x stack, filament_3d_wrapper returns in about 19 seconds (per stack)
    filenames = ['1']

    # with 4x stacks, filament_3d_wrapper runs all 4 in parallel, but CPU usage
    # does not increase by 4x (it looks like just 1x) and each call returns
    # in about 550-650 seconds
    filenames = ['1', '2', '3', '4']

    # da.from_delayed() needs to know the shape and dtype of each result
    commonShape = (64, 512, 512)
    common_dtype = np.float64  # np.float was removed in NumPy 1.24; or np.uint8
    output_dtype = np.uint8    # myRun() returns a uint8 mask

    # wrap myRun() as a dask.delayed function (wrapping once is enough)
    myRun_Dask = dask.delayed(myRun)

    lazy_results = [myRun_Dask(filename, commonShape, common_dtype) for filename in filenames]

    lazy_arrays = [da.from_delayed(x, shape=commonShape, dtype=output_dtype) for x in lazy_results]

    # stack into one (n_files, 64, 512, 512) array
    x = da.stack(lazy_arrays)

    x.compute()
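One plausible explanation, which is an assumption and not something confirmed in this thread: Dask runs delayed tasks on a thread pool by default, so Python-heavy code that holds the GIL can end up serialized (and contending) across threads, while compiled code that releases the GIL scales. A minimal sketch of the usual workaround is to ask for the processes scheduler. `fake_segment` is a hypothetical stand-in for filament_3d_wrapper so the sketch runs without aicssegmentation installed.

```python
import time

import dask
import numpy as np


def fake_segment(seed, shape=(8, 64, 64)):
    """Hypothetical stand-in for filament_3d_wrapper: per-voxel Python work
    that holds the GIL while it runs."""
    rng = np.random.default_rng(seed)
    stack = rng.normal(loc=100, scale=10, size=shape)
    flat = stack.ravel().tolist()                 # plain Python floats
    mask = [1 if v > 110 else 0 for v in flat]    # Python-level loop
    return np.asarray(mask, dtype=np.uint8).reshape(shape)


def run_all(seeds, scheduler):
    """Build one delayed task per stack and compute them with the given scheduler."""
    tasks = [dask.delayed(fake_segment)(s) for s in seeds]
    # dask.compute accepts scheduler="threads", "processes", or "sync"
    return dask.compute(*tasks, scheduler=scheduler)


if __name__ == "__main__":
    # the guard matters: worker processes may re-import this module
    for scheduler in ("threads", "processes"):
        start = time.time()
        results = run_all([1, 2, 3, 4], scheduler)
        print(scheduler, round(time.time() - start, 2), "seconds", len(results))
```

The stand-in is deliberately tiny, so timings here will not reproduce the numbers above. Processes sidestep the GIL but each result is pickled back to the parent; for 64×512×512 stacks that transfer is not free, which is one reason saving results inside each task can be preferable.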
evamaxfield commented 4 years ago

Hey! I am doing some initial checking as to what may be going on but hard to tell right now. Can I ask how much memory your machine has?
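For a sense of scale when answering the memory question, the size of one 64×512×512 stack can be computed directly. This back-of-the-envelope sketch counts only input arrays; it assumes nothing about how many full-size intermediates filament_3d_wrapper allocates, beyond noting that each float64 one costs as much as the input.

```python
import numpy as np


def stack_mib(shape, dtype):
    """Size in MiB of one array with the given shape and dtype."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize / 2**20


shape = (64, 512, 512)
per_stack_f64 = stack_mib(shape, np.float64)  # 128.0 MiB
per_stack_u8 = stack_mib(shape, np.uint8)     # 16.0 MiB

print(per_stack_f64, per_stack_u8)
print("4 stacks, float64 input only:", 4 * per_stack_f64, "MiB")  # 512.0 MiB
```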

jxchen01 commented 4 years ago

https://github.com/AllenInstitute/aics-segmentation/blob/master/aicssegmentation/bin/batch_processing.py

This is the batch processing function provided in aics-segmentation. It is nothing more than looping through the files one by one.
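The one-file-at-a-time pattern described above can be sketched as follows. This is a hypothetical illustration, not the contents or arguments of batch_processing.py: `fake_segment`, `make_fake_inputs`, and `batch_process` are made-up names, and .npy files stand in for real image files.

```python
import tempfile
from pathlib import Path

import numpy as np


def fake_segment(stack):
    """Hypothetical stand-in for a segmentation call such as filament_3d_wrapper."""
    return (stack > 110).astype(np.uint8)


def make_fake_inputs(in_dir, names, shape=(8, 64, 64)):
    """Write small random float64 stacks so the sketch has something to read."""
    in_dir.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(0)
    for name in names:
        np.save(in_dir / f"{name}.npy", rng.normal(loc=100, scale=10, size=shape))


def batch_process(in_dir, out_dir):
    """Process every .npy stack in in_dir sequentially, saving a uint8 mask per file."""
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for path in sorted(in_dir.glob("*.npy")):
        stack = np.load(path)          # only one stack is held in memory at a time
        mask = fake_segment(stack)
        out_path = out_dir / f"{path.stem}_seg.npy"
        np.save(out_path, mask)
        written.append(out_path)
    return written


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        make_fake_inputs(tmp / "in", ["1", "2", "3", "4"])
        for p in batch_process(tmp / "in", tmp / "out"):
            print(p.name)
```

A sequential loop trades wall-clock time for a predictable, low memory footprint, which can matter for large stacks.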

Adding dask support seems to be an important feature for the next release. I will look into it.