dask / dask-image

Distributed image processing
http://image.dask.org/en/latest/
BSD 3-Clause "New" or "Revised" License

Improve dask-image imread #179

Closed: joaomamede closed this issue 3 years ago

joaomamede commented 3 years ago

The function that opens each file is defined as

import numpy
import pims

def _read_frame(fn, i, *, arrayfunc=numpy.asanyarray):
    # Opens the file again on every call, reads frame i, then closes it
    with pims.open(fn) as imgs:
        return arrayfunc(imgs[i])

In my particular case, each file has time, channel, z (currently single-z projections), y and x dimensions. The data set is a 6-by-5 panel of tiles with 3 channels and 2044x2048 pixels per tile, totalling 30 files (Image{i}.ome.tiff, with i ranging from 0 to 29).
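For a sense of scale, the numbers above can be turned into a rough memory estimate. This is back-of-the-envelope arithmetic, not a measurement, and it assumes uint16 pixels (the dtype passed to `da.from_delayed` further down) and one z plane per file.

```python
# Rough size of the full data set, assuming uint16 pixels
# and one z plane (single-z projection) per file.
n_files = 6 * 5                      # 6 x 5 panel of tiles
times, channels = 143, 3
height, width = 2044, 2048
bytes_per_pixel = 2                  # uint16

bytes_per_file = times * channels * height * width * bytes_per_pixel
bytes_total = n_files * bytes_per_file

print(f"per file: {bytes_per_file / 2**30:.2f} GiB")
print(f"all {n_files} files: {bytes_total / 2**30:.1f} GiB")
```

That comes to roughly 100 GiB in total, which is why loading everything into RAM stops scaling quickly.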

I want to load the data lazily from disk (at least one time frame at a time) and display the assembled mosaic.

For now, I am testing by assembling the tiles into one large array in RAM and viewing it in napari in live-cell mode, but that array cannot grow beyond a certain size. Currently I am only trying a horizontal stitch of 1x4:

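import numpy as np

# `stack` comes from imread (see below): shape (n_files, t, y, x, c) = (4, 143, 2044, 2048, 3)
print(stack.shape)
n_cols = 1   # number of tiles along y
n_rows = 4   # number of tiles along x
times = 143
# One full in-memory array per channel; this is what limits the total size.
# np.zeros defaults to float64, so keep the source dtype instead.
tet = np.zeros((times, 2044*n_cols, 2048*n_rows), dtype=stack.dtype)
integrase = np.zeros((times, 2044*n_cols, 2048*n_rows), dtype=stack.dtype)
nuclei = np.zeros((times, 2044*n_cols, 2048*n_rows), dtype=stack.dtype)

i = 0
for col in range(n_cols):
    for row in range(n_rows):
        ys = slice(2044*col, 2044*(col+1))
        xs = slice(2048*row, 2048*(row+1))
        tet[:times, ys, xs] = stack[i, :times, :, :, 0]
        integrase[:times, ys, xs] = stack[i, :times, :, :, 1]
        nuclei[:times, ys, xs] = stack[i, :times, :, :, 2]
        i += 1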
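The placement done by the nested loop can also be written as a nested list of tiles handed to a block function, which removes the hand-written index arithmetic. This is only a sketch, not the approach taken in this thread: `assemble_mosaic` is a hypothetical helper, it assumes a flat, row-major list of equally shaped (t, y, x) tiles, and it relies on `numpy.block` and `dask.array.block` sharing the same semantics, so passing `xp=dask.array` should build the mosaic lazily instead of allocating it. Note that here `n_rows` counts tiles along y and `n_cols` along x, the opposite of the loop above.

```python
import numpy as np

def assemble_mosaic(tiles, n_rows, n_cols, xp=np):
    """Arrange a flat, row-major list of tiles into an n_rows x n_cols mosaic.

    Each tile has shape (t, y, x); the result has shape
    (t, n_rows * y, n_cols * x). Pass xp=dask.array to build it lazily.
    """
    if len(tiles) != n_rows * n_cols:
        raise ValueError("expected n_rows * n_cols tiles")
    # Nested lists: each inner list is joined along the last axis (x),
    # and the outer list along the second-to-last axis (y).
    layout = [list(tiles[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]
    return xp.block(layout)

# Six small (t=2, y=3, x=4) tiles, each filled with its own index.
tiles = [np.full((2, 3, 4), k) for k in range(6)]
mosaic = assemble_mosaic(tiles, n_rows=2, n_cols=3)
```

For the 1x4 case above, something like `assemble_mosaic([stack[i, :, :, :, 0] for i in range(4)], 1, 4, xp=dask.array)` would give the `tet` channel as a lazy (143, 2044, 8192) array.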

For example:

The way imread (or plain dask, below) is set up, it reads each file and opens it whenever [i] is requested. However, reshaping does not work when I try to assemble the tiles into a larger X and Y; I assume it fails at the channel level (see the errors below).

from dask_image.imread import imread

stack = imread('/home/jmamede/Data/tet/tetMoon20201127/*ome.tiff')

print(stack.shape)
stack = stack.reshape(143,2044, 2048*4,3)
> (4, 143, 2044, 2048, 3)
> ---------------------------------------------------------------------------
> ValueError                                Traceback (most recent call last)
> <ipython-input-14-98a1b37ead9a> in <module>
>      10 
>      11 print(stack.shape)
> ---> 12 stack = stack.reshape(143,2044, 2048*4,3)
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/core.py in reshape(self, *shape)
>    1795         if len(shape) == 1 and not isinstance(shape[0], Number):
>    1796             shape = shape[0]
> -> 1797         return reshape(self, shape)
>    1798 
>    1799     def topk(self, k, axis=-1, split_every=None):
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/reshape.py in reshape(x, shape)
>     193 
>     194     # Logic for how to rechunk
> --> 195     inchunks, outchunks = reshape_rechunk(x.shape, shape, x.chunks)
>     196     x2 = x.rechunk(inchunks)
>     197 
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/reshape.py in reshape_rechunk(inshape, outshape, inchunks)
>      42                 ileft -= 1
>      43             if reduce(mul, inshape[ileft : ii + 1]) != dout:
> ---> 44                 raise ValueError("Shapes not compatible")
>      45 
>      46             for i in range(ileft + 1, ii + 1):  # need single-shape dimensions
> 
> ValueError: Shapes not compatible

Or, building the stack with dask.delayed directly:

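import dask.array as da
from dask import delayed

# `imread` here is an eager image reader and `filelist` the sorted list of
# .ome.tiff paths; each delayed call reads one whole file.
lazy_imread = delayed(imread)
lazy_arrays = [lazy_imread(fn) for fn in filelist]
dask_arrays = [da.from_delayed(delayed_reader, shape=(143, 2044, 2048, 3), dtype='uint16')
               for delayed_reader in lazy_arrays]
stack = da.stack(dask_arrays, axis=0)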

print(stack.shape)
stack = stack.reshape(143,2044, 2048*4,3)
> (4, 143, 2044, 2048, 3)
> ---------------------------------------------------------------------------
> ValueError                                Traceback (most recent call last)
> <ipython-input-15-91398ced57b4> in <module>
>      10 #
>      11 print(stack.shape)
> ---> 12 stack = stack.reshape(143,2044, 2048*4,3)
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/core.py in reshape(self, *shape)
>    1795         if len(shape) == 1 and not isinstance(shape[0], Number):
>    1796             shape = shape[0]
> -> 1797         return reshape(self, shape)
>    1798 
>    1799     def topk(self, k, axis=-1, split_every=None):
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/reshape.py in reshape(x, shape)
>     193 
>     194     # Logic for how to rechunk
> --> 195     inchunks, outchunks = reshape_rechunk(x.shape, shape, x.chunks)
>     196     x2 = x.rechunk(inchunks)
>     197 
> 
> ~/anaconda3/envs/pycuda/lib/python3.7/site-packages/dask/array/reshape.py in reshape_rechunk(inshape, outshape, inchunks)
>      42                 ileft -= 1
>      43             if reduce(mul, inshape[ileft : ii + 1]) != dout:
> ---> 44                 raise ValueError("Shapes not compatible")
>      45 
>      46             for i in range(ileft + 1, ii + 1):  # need single-shape dimensions
> 
> ValueError: Shapes not compatible
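Both tracebacks come from the same limitation. Judging from `reshape_rechunk` in the traceback, dask's `reshape` matches products of neighbouring axes working from the right, so it can split or merge adjacent axes but cannot move the tile axis (axis 0, length 4) next to x. NumPy would accept the same reshape because the total size matches, but it would silently scramble the tiles. One way around it is to move the tile axis next to x first, or to concatenate the tiles along x. A minimal sketch, assuming `stack` has shape (n_tiles, t, y, x, c) as printed above; `tile_row` and `tile_row_concat` are hypothetical helpers, and `xp` can be `numpy` or `dask.array`, both of which provide `moveaxis`, `reshape` and `concatenate`:

```python
import numpy as np

def tile_row(stack, xp=np):
    """Place tiles side by side along x.

    stack:   shape (n_tiles, t, y, x, c), e.g. (4, 143, 2044, 2048, 3)
    returns: shape (t, y, n_tiles * x, c), e.g. (143, 2044, 8192, 3)
    """
    n_tiles, t, y, x, c = stack.shape
    # Put the tile axis directly in front of x so the two are adjacent:
    # (n_tiles, t, y, x, c) -> (t, y, n_tiles, x, c)
    moved = xp.moveaxis(stack, 0, 2)
    # Merging the adjacent (n_tiles, x) axes is now a valid reshape.
    return xp.reshape(moved, (t, y, n_tiles * x, c))

def tile_row_concat(stack, xp=np):
    """Same result, by concatenating the tiles along the x axis."""
    return xp.concatenate([stack[k] for k in range(stack.shape[0])], axis=2)

# Two tiny tiles of shape (t=1, y=2, x=3, c=1).
small = np.arange(2 * 1 * 2 * 3 * 1).reshape(2, 1, 2, 3, 1)
row = tile_row(small)
```

With dask, the concatenate version is arguably the more direct of the two, since each tile simply becomes a block of the result.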

My question is:

Each time imread reads a different i for a given filename, is a new pims.open() call made, or is the same reader reused for each file?

Would something like this work (called as imread(fnames_list, channel_to_be_picked_up))?

import numpy
import pims

def _read_frame_improved_JM(fn, i, ch, bundle='zyx', iter='t', *, arrayfunc=numpy.asanyarray):
    with pims.open(fn) as imgs:
        imgs.iter_axes = iter            # axis indexed by i (time)
        imgs.bundle_axes = bundle        # axes returned in each frame
        imgs.default_coords['c'] = ch    # fixed channel
        return arrayfunc(imgs[i])

Or should I create two functions in imread/__init__.py: one to initialize the file reader (basically like the current _read_frame()) and another to read each frame with a certain shape?

import numpy
import pims

def _initialize_pims(fn):
    # Return the open reader itself; a `with` block here would close it
    # before it is used.
    return pims.open(fn)

def _read_frame(pims_object, t, ch, bundle='zyx', iter='t', *, arrayfunc=numpy.asanyarray):
    pims_object.iter_axes = iter
    pims_object.bundle_axes = bundle
    pims_object.default_coords['c'] = ch
    return arrayfunc(pims_object[t])

Any guidance is appreciated before I start coding something that might be simpler than what I have in mind.

Thanks!

GenevieveBuckley commented 3 years ago

> My question is: Each time we call a different i for a given filename with imread, is there a new instance of pims.open() created, or does it reuse the same one for each file?

pims.open() is a function; it's not creating a class instance or anything like that here. I hope that helps clear up some of the confusion.

It sounds like you're trying to combine a lot of smaller image files into one large image volume. It might be more straightforward to use the block_info keyword argument of dask.array.map_blocks(), which tells your function where each block sits in the output array. This is likely to perform better than reshaping a Dask array after constructing it. The Dask array API docs have a small section on block_info, and there is also a page on combining Dask arrays.
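A small sketch of the block_info idea for a 2-D mosaic (time and channel axes left out for brevity). It assumes one block per tile and tiles numbered row by row; `load_tile` is a hypothetical placeholder for the real per-file reader. It uses `block_info[None]["chunk-location"]`, which `map_blocks` fills in for functions that accept a `block_info` keyword.

```python
import numpy as np
import dask.array as da

def load_tile(tile_index, shape, dtype):
    # Hypothetical stand-in for reading one tile from disk; a real version
    # would open file number `tile_index` and return its pixels.
    return np.full(shape, tile_index, dtype=dtype)

def read_block(template_block, n_tile_cols, block_info=None):
    # block_info[None]["chunk-location"] is this block's (row, col) index in
    # the grid of output blocks; with one block per tile it picks the tile.
    row, col = block_info[None]["chunk-location"]
    tile_index = row * n_tile_cols + col
    return load_tile(tile_index, template_block.shape, template_block.dtype)

def lazy_mosaic(n_tile_rows, n_tile_cols, tile_shape, dtype=np.uint16):
    # Lazy all-zeros template whose chunks are exactly one tile each.
    template = da.zeros(
        (n_tile_rows * tile_shape[0], n_tile_cols * tile_shape[1]),
        chunks=tile_shape,
        dtype=dtype,
    )
    return template.map_blocks(
        read_block,
        n_tile_cols=n_tile_cols,
        dtype=dtype,
        meta=np.empty((0, 0), dtype=dtype),
    )

mosaic = lazy_mosaic(2, 3, (4, 5))   # 2 x 3 tiles of 4 x 5 pixels each
```

Nothing is read until a block is computed, and each tile's loader only runs for the blocks that are actually requested, for example the ones napari currently displays.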

joaomamede commented 3 years ago

If I understand correctly, it calls pims.open again for every i as it is needed. With the first modification (setting the iterator and bundler), it should work by selecting the default channel through an argument.
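The alternative is to open each file once per chunk of frames rather than once per frame. A minimal sketch of that idea, assuming the object returned by the opener supports `with` and integer indexing (as the `pims.open` readers in the snippets above do); `read_frames` and `fake_open` are hypothetical names used only for illustration:

```python
import contextlib
import numpy as np

def read_frames(open_fn, fn, start, stop):
    """Read frames [start, stop) of one file, opening it only once.

    `open_fn` is any callable usable as a context manager that returns an
    indexable sequence of frames, e.g. pims.open.
    """
    with open_fn(fn) as frames:
        return np.stack([np.asanyarray(frames[i]) for i in range(start, stop)])

# Demo with a hypothetical in-memory "file" that records each open call.
open_calls = []

@contextlib.contextmanager
def fake_open(fn):
    open_calls.append(fn)
    yield np.arange(10 * 2 * 2).reshape(10, 2, 2)   # 10 frames of 2 x 2 pixels

chunk = read_frames(fake_open, "Image0.ome.tiff", 0, 5)
```

With dask, each chunk would become one `dask.delayed(read_frames)(pims.open, fn, start, stop)` task, so the number of opens scales with the number of chunks rather than the number of frames.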

My problem is the lazy loading into dask. I tried to concatenate before and it also failed. I'll try map_blocks again just in case.

Thanks

joaomamede commented 3 years ago

This works great!

import pims
import numpy as np
import dask
import dask.array
import warnings
import glob

def initialize_reader(fn, iterator='t', bundler='yx', ch=0, **kwargs):
    """
    Open a file and return a pims reader, setting the iteration axis and
    the frame shape it outputs.

    Parameters
    ----------
    fn : str
        A string with one or multiple filenames; see the pims multireader
        for details.
    iterator : str
        Axis to iterate over, defaults to time 't'.
    bundler : str
        Axes returned in each frame, defaults to 'yx'.
    ch : int
        Channel to output, defaults to 0.

    Returns
    -------
    reader : pims.bioformats.BioformatsReader
        A pims reader giving access to the contents of the image file(s)
        in the selected channel.
    """

    reader = pims.bioformats.BioformatsReader(fn)
    reader.iter_axes = iterator
    reader.bundle_axes = bundler
    reader.default_coords['c'] = ch
    return reader

def _read_frame(pims_reader, i, arrayfunc=np.asanyarray, **kwargs):
    """
    Read frame(s) i from an already-initialized pims reader.

    Parameters
    ----------
    pims_reader : object
        Pims reader object
    i : int or slice
        Frame coordinate(s) along the reader's iterator axis, in the
        reader's default channel.

    Returns
    -------
    array : numpy.ndarray (or cupy.ndarray)
        Array with the data in the reader's current
        bundle shape and default_coords.
    """
    return arrayfunc(pims_reader[i])

def time_stack(fn, nframes=1, ch=0, iterator='t', bundler='yx', arraytype="numpy", **kwargs):
    """Lazily load one file as a dask array of shape (n_chunks, nframes) + frame shape."""
    if arraytype == "numpy":
        arrayfunc = np.asanyarray
    elif arraytype == "cupy":   # pragma: no cover
        import cupy
        arrayfunc = cupy.asanyarray
    else:
        raise ValueError(f"arraytype must be 'numpy' or 'cupy', got {arraytype!r}")

    reader = initialize_reader(fn, ch=ch, iterator=iterator, bundler=bundler)
    shape = (len(reader),) + reader.frame_shape
    dtype = np.dtype(reader.pixel_type)

    if nframes == -1:
        nframes = shape[0]

    if nframes > shape[0]:
        warnings.warn(
            "`nframes` larger than number of frames in file."
            " Will truncate to number of frames in file.",
            RuntimeWarning
        )
    elif shape[0] % nframes != 0:
        warnings.warn(
            "`nframes` does not nicely divide number of frames in file."
            " Last chunk will contain the remainder.",
            RuntimeWarning
        )

    import itertools
    # Chunk boundaries: (0, nframes), (nframes, 2 * nframes), ..., (..., shape[0])
    lower_iter, upper_iter = itertools.tee(itertools.chain(
        range(0, shape[0], nframes),
        [shape[0]]
    ))
    next(upper_iter)

    a = []
    for i, j in zip(lower_iter, upper_iter):
        # One delayed task per chunk; note that every task receives the reader
        a.append(dask.array.from_delayed(
            dask.delayed(_read_frame)(reader, slice(i, j), arrayfunc=arrayfunc),
            (j - i,) + shape[1:],
            dtype,
            meta=arrayfunc([])
        ))

    # stack needs equally sized chunks, so nframes should divide the frame count
    a = dask.array.stack(a)
    return a

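# Script part (runs like __main__)
import re

# Sort numerically so that Image2 comes before Image10 (a plain sort() does not).
filelist = sorted(
    glob.glob('/home/jmamede/Data/tet/tetMoon20201127/*ome.tiff'),
    key=lambda fn: int(re.search(r'(\d+)\.ome\.tiff$', fn).group(1)),
)

all = []
for ch in range(3):
    # Files side by side along x: axis 3 of each (t, 1, y, x) time stack
    all.append(dask.array.concatenate(
        [time_stack(filename, ch=ch) for filename in filelist],
        axis=3))
# Stack the channels: (t, c, 1, y, x)
all = dask.array.stack(all, axis=1)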

import napari

# IPython/Jupyter magic: start the Qt event loop
%gui qt

# `all` has shape (t, c, 1, y, x); show each channel as its own layer.
v = napari.Viewer(show=False)
v.add_image(all[:, 0, :, :, :], contrast_limits=[0, 5000],
            blending='additive',
            colormap='green',
            name='tetMoon-gp41GFP',
            )
v.add_image(all[:, 1, :, :, :], contrast_limits=[0, 5000],
            blending='additive',
            colormap='red',
            name='IN-mRuby3',
            )
v.add_image(all[:, 2, :, :, :], contrast_limits=[0, 5000],
            blending='additive',
            colormap='blue',
            name='Nucspot650',
            )
v.show()

Basically, I build a time stack for each file, concatenate the files, and then stack the channels. We could probably set the "bundler" to 'cyx' to get the channel axis directly and avoid the last step.

GenevieveBuckley commented 3 years ago

Glad to hear it, that's good news 😄

joaomamede commented 3 years ago

pims.bioformats.BioformatsReader is extremely slow, so I had to use pims.TiffStack_tifffile (the same reader pims.open gives me). The problem is that this reader does not handle the shape of the .ome.tiff: tifffile's imread reports a shape of (143, 3, 2044, 2048), but pims does not reproduce it correctly, and even after reshaping to the right shape the channels get mixed up ("channel hopping"). For reference, in case someone wants to implement this: I had to use NumPy's slice stepping (start:stop:step) to pull out each channel.

Note: the iterator and "bundler" options have no effect with the TiffStack_tifffile reader, because it is sequential and does not support multiple dimensions.
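The stepping trick relies on the frames being interleaved by channel within each time point. Under that assumption (consistent with the (143, 3, 2044, 2048) shape tifffile reports), `frames[c::3]` is channel `c`, and the same selection can be expressed as a single reshape of the leading axis. A sketch; `split_channels` is a hypothetical helper and `xp` can be `numpy` or, for lazy arrays, `dask.array`:

```python
import numpy as np

def split_channels(frames, n_channels, xp=np):
    """Reshape a channel-interleaved frame sequence into (t, c, y, x).

    Assumes frame k holds time point k // n_channels and channel
    k % n_channels, i.e. the (t, c, y, x) order tifffile reports.
    """
    n_frames = frames.shape[0]
    if n_frames % n_channels != 0:
        raise ValueError("number of frames is not a multiple of n_channels")
    new_shape = (n_frames // n_channels, n_channels) + tuple(frames.shape[1:])
    return xp.reshape(frames, new_shape)

# Six 2 x 2 frames, each filled with its own frame index.
frames = np.repeat(np.arange(6), 4).reshape(6, 2, 2)
by_channel = split_channels(frames, n_channels=3)
```

Because this reshape only splits the first axis, it is the kind of reshape dask supports, so `split_channels(all, 3, xp=dask.array)[:, c]` should match `all[c::3]`.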

import pims
import numpy as np
import dask
import dask.array
import warnings

def initialize_reader(fn, iterator='t', bundler='cyx', ch=0, **kwargs):
    """
    Open a file with pims.TiffStack_tifffile and return the reader.

    Parameters
    ----------
    fn : str
        Filename to open.
    iterator, bundler, ch :
        Kept for compatibility with the BioformatsReader version;
        TiffStack_tifffile is sequential, so these are ignored.

    Returns
    -------
    reader : pims.TiffStack_tifffile
        A pims reader yielding the file's 2-D frames in order.
    """
    # reader = pims.open(fn)
    reader = pims.TiffStack_tifffile(fn)
    # Bioformats alternative (much slower for these files):
    # reader = pims.bioformats.BioformatsReader(fn, meta=False, java_memory='2048m')
    # reader.iter_axes = iterator
    # reader.bundle_axes = bundler
    # reader.default_coords['c'] = ch
    print(reader)
    return reader

def _read_frame(pims_reader, i, arrayfunc=np.asanyarray, **kwargs):
    """
    Read frame(s) i from an already-initialized pims reader.

    Parameters
    ----------
    pims_reader : object
        Pims reader object
    i : int or slice
        Frame coordinate(s) in the reader's frame order.

    Returns
    -------
    array : numpy.ndarray (or cupy.ndarray)
        Array with the data in the reader's current frame shape.
    """
    return arrayfunc(pims_reader[i])

def time_stack(fn, nframes=1, ch=0, iterator='t', bundler='cyx', arraytype="numpy", **kwargs):
    """Lazily load one file as a dask array of shape (n_chunks, nframes) + frame shape."""
    if arraytype == "numpy":
        arrayfunc = np.asanyarray
    elif arraytype == "cupy":   # pragma: no cover
        import cupy
        arrayfunc = cupy.asanyarray
    else:
        raise ValueError(f"arraytype must be 'numpy' or 'cupy', got {arraytype!r}")

    reader = initialize_reader(fn, ch=ch, iterator=iterator, bundler=bundler)

    # With TiffStack_tifffile this is a flat sequence of 2-D frames, e.g.
    # (429, 2044, 2048) = 143 time points x 3 channels, interleaved.
    shape = (len(reader),) + reader.frame_shape
    dtype = np.dtype(reader.pixel_type)

    if nframes == -1:
        nframes = shape[0]

    if nframes > shape[0]:
        warnings.warn(
            "`nframes` larger than number of frames in file."
            " Will truncate to number of frames in file.",
            RuntimeWarning
        )
    elif shape[0] % nframes != 0:
        warnings.warn(
            "`nframes` does not nicely divide number of frames in file."
            " Last chunk will contain the remainder.",
            RuntimeWarning
        )

    import itertools
    # Chunk boundaries: (0, nframes), (nframes, 2 * nframes), ..., (..., shape[0])
    lower_iter, upper_iter = itertools.tee(itertools.chain(
        range(0, shape[0], nframes),
        [shape[0]]
    ))
    next(upper_iter)

    a = []
    for i, j in zip(lower_iter, upper_iter):
        # One delayed task per chunk; note that every task receives the reader
        a.append(dask.array.from_delayed(
            dask.delayed(_read_frame)(reader, slice(i, j), arrayfunc=arrayfunc),
            (j - i,) + shape[1:],
            dtype,
            meta=arrayfunc([])
        ))

    # stack needs equally sized chunks, so nframes should divide the frame count
    a = dask.array.stack(a)
    return a

## MAIN HERE
import glob
import re

# Sort numerically so that Image2 comes before Image10 (a plain sort() does not).
filelist = sorted(
    glob.glob('/home/jmamede/Data/tet/tetMoon20201127/*ome.tiff'),
    key=lambda fn: int(re.search(r'(\d+)\.ome\.tiff$', fn).group(1)),
)

# 30 files = 5 rows of 6 tiles; each row uses the next 6 consecutive files.
n_tiles_per_row = 6
rows = []
for start in range(0, len(filelist), n_tiles_per_row):
    row = dask.array.concatenate(
        [time_stack(fn, bundler='cyx') for fn in filelist[start:start + n_tiles_per_row]],
        axis=3)                 # (429, 1, 2044, 6 * 2048): tiles side by side along x
    rows.append(row[:, 0])      # drop the length-1 axis: (429, 2044, 6 * 2048)

all = dask.array.concatenate(rows, axis=1)   # rows stacked along y

import napari

# IPython/Jupyter magic: start the Qt event loop
%gui qt

# Frames are interleaved by channel (t0c0, t0c1, t0c2, t1c0, ...), so a step
# of 3 selects one channel. `c::3` keeps every time point; a stop such as
# in `0:-3:3` would drop the last one.
v = napari.Viewer(show=False)
v.add_image(all[0::3], contrast_limits=[0, 500],
            blending='additive',
            colormap='green',
            name='tetMoon-gp41GFP',
            )
v.add_image(all[1::3], contrast_limits=[0, 500],
            blending='additive',
            colormap='red',
            name='IN-mRuby3',
            )
v.add_image(all[2::3], contrast_limits=[0, 1500],
            blending='additive',
            colormap='blue',
            name='Nucspot650',
            )
v.show()

m-albert commented 3 years ago

Hey @joaomamede I just saw this!

I agree with you that the repeated calls to pims.open create quite a bit of overhead (I proposed an improvement here: https://github.com/dask/dask-image/pull/182). However, the problem I see with passing a reader object to each call of _read_frame (and therefore to each dask task) is that, in my understanding, the reader object can be very large when there are many files, as it contains the list of individual files. For a sequence of, e.g., 10000 files, which is not uncommon, this would mean considerable communication overhead between the scheduler and the workers.
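A rough illustration of the arithmetic behind this concern, under the simplifying assumption that each task's arguments are serialized on their own (as when tasks are shipped to workers individually). `SequenceReaderStandIn` is a hypothetical class standing in for a real pims reader; only its stored file list matters here.

```python
import pickle

class SequenceReaderStandIn:
    """Hypothetical stand-in for a reader over an image sequence: like such
    readers, it keeps the full list of file names it was opened with."""

    def __init__(self, filenames):
        self.filenames = list(filenames)

filenames = [f"/data/sequence/frame_{i:05d}.tif" for i in range(10_000)]
reader = SequenceReaderStandIn(filenames)

# Bytes needed to serialize the arguments of ONE read task.
task_args_with_reader = len(pickle.dumps((reader, 0)))
task_args_with_filename = len(pickle.dumps((filenames[0], 0)))

# One task per frame: with the reader, the whole file list goes along every time.
total_with_reader = task_args_with_reader * len(filenames)
total_with_filename = task_args_with_filename * len(filenames)
print(f"{total_with_reader / 2**20:.0f} MiB vs {total_with_filename / 2**20:.1f} MiB")
```

The per-task cost grows with the number of files, so the total grows roughly quadratically, while passing only a filename keeps it linear.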

joaomamede commented 3 years ago

I understand. In my case the files are only split by visit point (each file has t, z, c, y, x).

Since my previous posts, I got this working a bit better with a new function that outputs a time stack. I'll try to share it soon.

Thanks for dask.

GenevieveBuckley commented 3 years ago

@joaomamede I've just merged https://github.com/dask/dask-image/pull/182 from @m-albert

Does this work for your use case? You can install dask-image from the main branch with pip install git+https://github.com/dask/dask-image.git to try it out.

joaomamede commented 3 years ago

I will try it eventually; feel free to close this if you think it is fixed. I wrote a new function for my use case, based on the one I posted above.