fractal-analytics-platform / fractal-client

Command-line client for Fractal
https://fractal-analytics-platform.github.io/fractal-client
BSD 3-Clause "New" or "Revised" License

Move create_pyramid to write_pyramid, to avoid repeated calls of heavy functions #97

Closed tcompa closed 2 years ago

tcompa commented 2 years ago

At the moment, the pyramid-creation function returns a list of delayed arrays, one per level, and each of them includes the full computation graph (including, for instance, the cellpose segmentation part). This means that everything is recomputed num_levels times, which is not an option, since the first part of the calculation may be very heavy.
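A minimal sketch of the problem, using a toy array and a call counter instead of the real task code (heavy_function, the array sizes, and the use of np.max are illustrative assumptions, not the actual Fractal implementation). Each lazy level carries the whole graph, so computing the levels one after the other re-runs the heavy step every time:

```python
import dask
import dask.array as da
import numpy as np

calls = []  # one entry per execution of the "heavy" function


def heavy_function(block):
    # Hypothetical stand-in for an expensive step such as segmentation.
    calls.append(1)
    return block + 1


x = da.zeros((8, 8), chunks=(4, 4), dtype=np.uint16)  # 4 chunks
# Passing meta explicitly avoids an extra dummy call for type inference.
level0 = x.map_blocks(heavy_function, meta=np.array((), dtype=np.uint16))
level1 = da.coarsen(np.max, level0, {0: 2, 1: 2})
level2 = da.coarsen(np.max, level1, {0: 2, 1: 2})

# Each level is computed on its own, so each one replays the full graph.
with dask.config.set(scheduler="synchronous"):
    for level in (level0, level1, level2):
        level.compute()

print(len(calls))  # 4 chunks x 3 levels = 12 calls
```

With one computation per level, the number of heavy calls grows linearly with the number of levels, which matches the behaviour described above.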

tcompa commented 2 years ago

With @mfranzon, we came up with this minimal working example that seems to solve the issue.

import dask.array as da
import numpy as np
import sys

def heavy_cellpose_function(block, block_info=None):
    # Stand-in for the real segmentation: returns random labels for one chunk
    return np.random.randint(0, 10000, size=(1, 2000, 2000)).astype(np.uint16)

x = da.zeros((20, 10*2000, 10*2000), chunks=(1, 2000, 2000), dtype=np.uint16)

mask = x.map_blocks(heavy_cellpose_function, chunks=(1, 2000, 2000), meta=np.array((), dtype=np.uint16))

print(mask)
print(mask.size, mask.itemsize, sys.getsizeof(mask))
print()

level0 = mask.to_zarr("mask.zarr", return_stored=True, component="0")
print(level0)
print(level0.size, level0.itemsize, sys.getsizeof(level0))
print()

level1 = da.coarsen(np.mean, level0, {0: 2}).astype(np.uint16)
print(level1)
print(level1.size, level1.itemsize, sys.getsizeof(level1))
print()

level1.to_zarr("mask.zarr", component="1")

print(level1)
print(level1.size, level1.itemsize, sys.getsizeof(level1))
print()

This writes a large array of shape (20, 10*2000, 10*2000) (roughly corresponding to the labels for a single well) to disk, and it never uses more than ~1 GB of RAM.

jluethi commented 2 years ago

Ah, that would explain why cellpose is still fairly slow. This sounds very interesting. Making sure cellpose only gets computed once certainly is important!

Also, if you're in the process of changing the pyramid building for the labels: da.coarsen with np.mean does not work well for label images. The full-resolution level looks fine, but the lower-resolution levels have very weird edges.

Here's a zoom-in when looking at labels zoomed out: [image: LabelPyramidIssue]

The problem is: When coarsening a label image with mean, the edges of nuclei are averaged between the label value & background (0).

We can investigate a nicer coarsening approach later. But for the moment, one thing that would already solve this is using np.max as the aggregation function, always using the max label to represent a pixel. That would maintain constant label values along the pyramid and avoid any combination of foreground & background.
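To make the difference concrete, here is a small NumPy-only sketch on a toy 4x4 label image (the coarsen_2x2 helper and the label value 7 are made up for illustration; the real pipeline uses da.coarsen). Averaging produces label values that exist nowhere in the original image, while the maximum keeps the original label:

```python
import numpy as np

# Toy label image: background 0 and one object with label 7.
labels = np.zeros((4, 4), dtype=np.uint16)
labels[0:3, 0:3] = 7


def coarsen_2x2(image, reduction):
    # Group pixels into 2x2 blocks and reduce each block to one value.
    h, w = image.shape
    blocks = image.reshape(h // 2, 2, w // 2, 2)
    return reduction(blocks, axis=(1, 3)).astype(image.dtype)


mean_level = coarsen_2x2(labels, np.mean)
max_level = coarsen_2x2(labels, np.max)

print(mean_level)  # [[7 3] [3 1]]: labels 3 and 1 are artifacts of averaging
print(max_level)   # [[7 7] [7 7]]: only the original label survives
```

Note that np.max slightly grows objects at each level and, where two objects touch, the larger label wins, so it is a pragmatic choice rather than a perfect one.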

Can you implement this for the label coarsening when applying the above changes?

tcompa commented 2 years ago

> Ah, that would explain why cellpose is still fairly slow. This sounds very interesting. Making sure cellpose only gets computed once certainly is important!

Not really. This issue concerns our local changes (where we try to use dask's map_blocks), while the committed version scans sequentially through all sites.

A typical single-site runtime is this one:

Now running cellpose
column_data: <class 'numpy.ndarray'>, (10, 2160, 2560)
End, dtype=uint16 shape=(10, 2160, 2560)
Elapsed: 234.8875 seconds

If you have a feeling that this is slow, let's look more carefully into it.

tcompa commented 2 years ago

> one thing that would already solve this is using np.max as the aggregation function, always using the max label to represent a pixel. That would maintain constant label values along the pyramid and avoid any combination of foreground & background.
>
> Can you implement this for the label coarsening when applying the above changes?

Sure

jluethi commented 2 years ago

Hmm, good point. From my experience, 4 minutes per site is actually fine speed-wise for cellpose. I was more surprised that the 72-site (9x8) example took some 5 hours. But actually, if it just runs them sequentially, that does make sense (72*4 min = 4.8h). So good that you're testing the parallel running approach here then :)

tcompa commented 2 years ago

> Hmm, good point. From my experience, 4 minutes per site is actually fine speed-wise for cellpose. I was more surprised that the 72-site (9x8) example took some 5 hours. But actually, if it just runs them sequentially, that does make sense (72*4 min = 4.8h). So good that you're testing the parallel running approach here then :)

Let's continue under https://github.com/fractal-analytics-platform/fractal/issues/64#issuecomment-1177514846

tcompa commented 2 years ago

Slightly improved example (more similar to what we actually aim at):

import dask.array as da
import numpy as np
import sys

sys.path.append("/data/homes/fractal/mwe_fractal/fractal/tasks")
from lib_pyramid_creation import create_pyramid_3D

def heavy_cellpose_function(block, block_info=None):
    print(f"Now running heavy cellpose function for chunk {block_info[None]['chunk-location']}")
    return np.random.randint(0, 10000, size=(1, 2000, 2000)).astype(np.uint16)

x = da.zeros((10, 10*2000, 10*2000), chunks=(1, 2000, 2000), dtype=np.uint16)

mask = x.map_blocks(heavy_cellpose_function, chunks=(1, 2000, 2000), meta=np.array((), dtype=np.uint16))

level0 = mask.to_zarr("mask.zarr", return_stored=True, component="0")

num_levels = 5
pyramid = create_pyramid_3D(level0, coarsening_xy=2, num_levels=num_levels)

for ind_level in range(1, num_levels):
    pyramid[ind_level].to_zarr("mask.zarr", component=f"{ind_level}")

tcompa commented 2 years ago

Note that this issue applies to illumination_correction as well (almost sure), and perhaps even to yokogawa_to_zarr (less likely). To be checked.

EDIT: I confirm that illumination_correction calls the correct function too many times. For a dataset with a single 2x2 well, 3 channels, and 10 Z planes, the expected number of calls is 4*3*10=120 but I observe 600 calls, where the additional factor of 5 comes from the 5 levels.
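The arithmetic behind these numbers, spelled out as a short check (variable names are only for illustration):

```python
fovs = 2 * 2       # a single 2x2 well
channels = 3
z_planes = 10
num_levels = 5

# One call per (FOV, channel, Z plane) image is what we expect.
expected_calls = fovs * channels * z_planes
# Recomputing level 0 once per pyramid level multiplies this by num_levels.
observed_calls = expected_calls * num_levels

print(expected_calls, observed_calls)  # 120 600
```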

tcompa commented 2 years ago

As of https://github.com/fractal-analytics-platform/fractal/commit/9ba643ec6c1cba688a596a37efad546601c53a6d there is a new write_pyramid function that in principle fixes the bug for illumination_correction. Note that illumination_correction is trickier than image_labeling, because it also includes overwriting a zarr file, and then two tricky points are mixed (pyramids leading to repeated executions, and dask failing to correctly apply the from_zarr/map_blocks/to_zarr sequence).

More tests are underway before I'm sure that everything works (I'm also taking this opportunity to write some actual unit tests for some of these tasks).

To reiterate the goals, we aim at a situation where:

  1. Building a pyramid of 5 levels does not trigger 5 executions of the (lazy) code that generates level 0 (this is the current issue, discovered for image_labeling but also relevant for illumination_correction).
  2. Building the 3rd level of a pyramid does not require re-execution of the previous coarsening operations (the ones for levels 0->1 and 1->2). This is a minor issue, compared to the previous one, but let's make sure that it's not there.
  3. Overwriting a zarr array is still possible, albeit with our quick&dirty approach of https://github.com/fractal-analytics-platform/fractal/issues/62#issuecomment-1152171893 and https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/53.
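Goals 1 and 2 can be sketched as follows. This is not the actual write_pyramid code: build_levels and heavy_function are hypothetical names, and persist() is used here only as an in-memory stand-in for writing each level with to_zarr(..., return_stored=True) and continuing from the stored array. Goal 3 (overwriting an existing zarr array) is not covered by this sketch.

```python
import dask
import dask.array as da
import numpy as np

calls = []  # one entry per execution of the "heavy" function


def heavy_function(block):
    # Hypothetical stand-in for the expensive level-0 computation.
    calls.append(1)
    return block + 1


def build_levels(level0, num_levels, factor=2):
    """Hypothetical helper: materialize each level before deriving the next."""
    levels = [level0.persist()]  # stand-in for to_zarr(..., return_stored=True)
    for _ in range(1, num_levels):
        # The next level starts from stored data, not from the lazy graph.
        coarser = da.coarsen(np.max, levels[-1], {0: factor, 1: factor})
        levels.append(coarser.persist())
    return levels


x = da.zeros((8, 8), chunks=(4, 4), dtype=np.uint16)  # 4 chunks
lazy_level0 = x.map_blocks(heavy_function, meta=np.array((), dtype=np.uint16))

with dask.config.set(scheduler="synchronous"):
    levels = build_levels(lazy_level0, num_levels=3)
    shapes = [level.compute().shape for level in levels]

print(len(calls), shapes)  # heavy_function runs once per chunk: 4 calls
```

Because every level is derived from already-materialized data, the heavy step runs once per chunk regardless of the number of levels, and level N only depends on the stored level N-1.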

Something else to check/fix: is yokogawa_to_zarr affected by a similar problem?

Take-home message: lazy computations are tricky! Let's make sure that whenever to_zarr is in the final part of a task, it gets (when appropriate) the compute=True option.

jluethi commented 2 years ago

Really valuable learnings, thanks for working through this & summarizing the take-aways @tcompa !

tcompa commented 2 years ago

Now testing the full workflow, after several changes here and there (both in the pyramid-writing function and in the two labeling/correction tasks), and after unit tests for both tasks are in place.

I am afraid we are hitting memory problems again (some tasks in the monitoring seem to be using much more CPU memory than expected). I will look at it again. Hopefully the new unit tests will allow smoother debugging.

tcompa commented 2 years ago

The issue of multiple calls within illumination correction or within per-FOV labeling has already been addressed, and it is now tested in test_unit_illumination_correction.py and test_unit_image_labeling.py.

The only missing part is the unit test of yokogawa_to_zarr.

tcompa commented 2 years ago

The unit test added in https://github.com/fractal-analytics-platform/fractal/commit/e3a029eafef8429b420fd8fcb57c3ed13c6873bf confirms that the bug also affects yokogawa_to_zarr.

tcompa commented 2 years ago

The fix in https://github.com/fractal-analytics-platform/fractal/commit/32a9a92bbe302984b4dcbeac9e55f3d5d07cbc10 solves the issue (i.e. the number of imread calls matches the number of images, up to the few dummy initial calls), but we still need a more thorough check of the output before closing.

By the way, this also solves https://github.com/fractal-analytics-platform/fractal/issues/53#issuecomment-1140976533, by always using the same write_pyramid function everywhere.

tcompa commented 2 years ago

Fixed with https://github.com/fractal-analytics-platform/fractal/pull/133.