With @mfranzon, we came up with this minimal working example, which seems to solve the issue:
```python
import dask.array as da
import numpy as np
import sys


def heavy_cellpose_function(block, block_info=None):
    # Stand-in for the expensive cellpose call: returns one random chunk
    return np.random.randint(0, 10000, size=(1, 2000, 2000)).astype(np.uint16)


x = da.zeros((20, 10 * 2000, 10 * 2000), chunks=(1, 2000, 2000), dtype=np.uint16)
mask = x.map_blocks(
    heavy_cellpose_function, chunks=(1, 2000, 2000), meta=np.array((), dtype=np.uint16)
)
print(mask)
print(mask.size, mask.itemsize, sys.getsizeof(mask))
print()

# Write level 0 and keep a handle to the *stored* array
level0 = mask.to_zarr("mask.zarr", return_stored=True, component="0")
print(level0)
print(level0.size, level0.itemsize, sys.getsizeof(level0))
print()

# Coarsen the stored level 0 (not the original graph) to build level 1
level1 = da.coarsen(np.mean, level0, {0: 2}).astype(np.uint16)
print(level1)
print(level1.size, level1.itemsize, sys.getsizeof(level1))
print()

level1.to_zarr("mask.zarr", component="1")
print(level1)
print(level1.size, level1.itemsize, sys.getsizeof(level1))
print()
```
This writes a large array of shape (20, 10*2000, 10*2000)
(roughly corresponding to the labels for a single well) to disk, but it never uses more than ~1GB of RAM.
Ah, that would explain why cellpose is still fairly slow. This sounds very interesting. Making sure cellpose only gets computed once certainly is important!
Also, if you're in the process of changing the pyramid building for the labels: the `dask.coarsen` with `np.mean` does not scale well. The full-res pyramid level looks fine, but lower-res levels then have very weird edges.

Here's a zoom-in when looking at the labels zoomed out:

The problem is: when coarsening a label image with `np.mean`, the edges of nuclei are averaged between the label value and the background (0).

We can investigate a nicer coarsening approach at some point. But for the moment, one thing that would already solve this is using `np.max` as the aggregation function, always using the max label to represent a pixel. That would maintain constant label values along the pyramid and avoid any combination of foreground and background.

Can you implement this for the label coarsening when applying the above changes?
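As a hedged illustration (not code from the thread), the sketch below shows the one-line swap of `np.max` for `np.mean` in `da.coarsen`, assuming a 3D dask label array coarsened by 2 in x and y:

```python
import dask.array as da
import numpy as np

# Toy 3D label image: axes (z, y, x), chunked per z-plane
labels = da.zeros((4, 4000, 4000), chunks=(1, 2000, 2000), dtype=np.uint16)

# Mean-based coarsening blends label values with the background (0) at object edges
level1_mean = da.coarsen(np.mean, labels, {1: 2, 2: 2}).astype(np.uint16)

# Max-based coarsening keeps each output pixel equal to one of the input labels,
# so label values stay constant across pyramid levels (and the dtype is preserved)
level1_max = da.coarsen(np.max, labels, {1: 2, 2: 2})
```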
> Ah, that would explain why cellpose is still fairly slow. This sounds very interesting. Making sure cellpose only gets computed once certainly is important!

Not really. This issue concerns our local changes (where we try to use dask's `map_blocks`), while the committed version scans sequentially through all sites.
A typical single-site runtime is this one:

```
Now running cellpose
column_data: <class 'numpy.ndarray'>, (10, 2160, 2560)
End, dtype=uint16 shape=(10, 2160, 2560)
Elapsed: 234.8875 seconds
```
If you have a feeling that this is slow, let's look more carefully into it.
> One thing that would already solve this is using `np.max` as the aggregation function, always using the max label to represent a pixel. That would maintain constant label values along the pyramid and avoid any combination of foreground & background. Can you implement this for the label coarsening when applying the above changes?
Sure
Hmm, good point. 4 minutes per site is actually fine speed-wise for cellpose I'd say from my experience. I was more a bit surprised that the 72 site (9x8) example took some 5 hours. But actually, if it just runs them sequentially, that does make sense (72*4 min = 4.8h). So good that you're testing the parallel running approach here then :)
> Hmm, good point. 4 minutes per site is actually fine speed-wise for cellpose I'd say from my experience. I was more a bit surprised that the 72 site (9x8) example took some 5 hours. But actually, if it just runs them sequentially, that does make sense (72*4 min = 4.8h). So good that you're testing the parallel running approach here then :)
Let's continue under https://github.com/fractal-analytics-platform/fractal/issues/64#issuecomment-1177514846
Slightly improved example (more similar to what we actually aim at):

```python
import dask.array as da
import numpy as np
import sys

sys.path.append("/data/homes/fractal/mwe_fractal/fractal/tasks")
from lib_pyramid_creation import create_pyramid_3D


def heavy_cellpose_function(block, block_info=None):
    # Stand-in for the expensive cellpose call: one random chunk per block
    print(f"Now running heavy cellpose function for chunk {block_info[None]['chunk-location']}")
    return np.random.randint(0, 10000, size=(1, 2000, 2000)).astype(np.uint16)


x = da.zeros((10, 10 * 2000, 10 * 2000), chunks=(1, 2000, 2000), dtype=np.uint16)
mask = x.map_blocks(
    heavy_cellpose_function, chunks=(1, 2000, 2000), meta=np.array((), dtype=np.uint16)
)

# Write level 0 and keep a handle to the stored array
level0 = mask.to_zarr("mask.zarr", return_stored=True, component="0")

# Build the lower-resolution levels from the stored level 0
num_levels = 5
pyramid = create_pyramid_3D(level0, coarsening_xy=2, num_levels=num_levels)
for ind_level in range(1, num_levels):
    pyramid[ind_level].to_zarr("mask.zarr", component=f"{ind_level}")
```
Note that this issue applies to `illumination_correction` as well (almost sure), and perhaps even to `yokogawa_to_zarr` (less likely). To be checked.
EDIT:

I confirm that `illumination_correction` calls the `correct` function too many times. For a dataset with a single 2x2 well, 3 channels, and 10 Z planes, the expected number of calls is `4*3*10 = 120`, but I observe 600 calls, where the additional factor of 5 comes from the 5 levels.
As of https://github.com/fractal-analytics-platform/fractal/commit/9ba643ec6c1cba688a596a37efad546601c53a6d there is a new `write_pyramid` function that in principle fixes the bug for `illumination_correction`. Note that `illumination_correction` is trickier than `image_labeling`, because it also includes overwriting a zarr file, so two tricky points get mixed: pyramids leading to repeated executions, and dask failing to correctly apply the `from_zarr`/`map_blocks`/`to_zarr` sequence.

More tests are underway (I'm also taking this opportunity to write some actual unit tests for some of these tasks) before I can be sure that everything works.
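For context, here is a minimal sketch (not the actual task code; the `plate.zarr` paths and the doubling correction are made up) of the `from_zarr`/`map_blocks`/`to_zarr` pattern mentioned above. Writing the result to a separate store or component, instead of overwriting the one still being read lazily, is what keeps the sequence well-defined:

```python
import dask.array as da


def correct_block(block):
    # Hypothetical per-block correction, standing in for the real correct() call
    return block * 2


# Lazy read of an existing zarr component (path and component are made up)
data = da.from_zarr("plate.zarr", component="0")

# Lazy per-chunk transformation
corrected = data.map_blocks(correct_block, dtype=data.dtype)

# Eager write to a *different* store, so the input is never overwritten
# while its chunks are still being read
corrected.to_zarr("plate_corrected.zarr", component="0", compute=True)
```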
To reiterate the goals, we aim at a situation where:

1. The heavy part of a task (e.g. the cellpose call in `image_labeling`) is computed only once, not once per pyramid level (this is the main point of this issue, but it is also relevant for `illumination_correction`).
2. The coarsening steps themselves are not repeated (e.g. `0->1` and `1->2`). This is a minor issue, compared to the previous one, but let's make sure that it's not there.

Something else to check/fix: is `yokogawa_to_zarr` affected by a similar problem?
Take-home message: lazy computations are tricky! Let's make sure that whenever `to_zarr` is in the final part of a task, it gets (when appropriate) the `compute=True` option.
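As an editorial aside (not from the thread), a small sketch of the difference, using a toy array:

```python
import dask.array as da
import numpy as np

arr = da.zeros((4, 1000, 1000), chunks=(1, 1000, 1000), dtype=np.uint16)

# compute=True (the default): the graph is executed here, exactly once,
# and the data lands on disk before the task returns.
arr.to_zarr("out_eager.zarr", component="0", compute=True)

# compute=False: nothing is executed yet; the returned Delayed object must be
# computed explicitly, otherwise the write silently never happens.
delayed_store = arr.to_zarr("out_delayed.zarr", component="0", compute=False)
delayed_store.compute()
```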
Really valuable learnings, thanks for working through this & summarizing the take-aways @tcompa !
Now testing the full workflow, after several changes here and there (both in the pyramid-writing function and in the two labeling/correction tasks), and now that unit tests for both tasks are in place.
I am afraid we are hitting memory problems again (some tasks in the monitoring seem to be using much more CPU memory than expected); I will look at it again. Hopefully the new unit tests will allow smoother debugging.
The issue of multiple calls within illumination correction or within per-FOV labeling has already been addressed, and it is now tested in `test_unit_illumination_correction.py` and `test_unit_image_labeling.py`.

The only missing part is the unit test of `yokogawa_to_zarr`.
The unit test added in https://github.com/fractal-analytics-platform/fractal/commit/e3a029eafef8429b420fd8fcb57c3ed13c6873bf confirms that the bug also affects `yokogawa_to_zarr`.

The fix in https://github.com/fractal-analytics-platform/fractal/commit/32a9a92bbe302984b4dcbeac9e55f3d5d07cbc10 solves the issue (i.e. the number of `imread` calls matches the number of images, up to the few dummy initial calls), but we still need a more thorough check of the output before closing.
By the way, this also solves https://github.com/fractal-analytics-platform/fractal/issues/53#issuecomment-1140976533, by always using the same `write_pyramid` function everywhere.
At the moment, the pyramid-creation function returns a list of delayed arrays, one per level, and they all include the full computation graph (including, for instance, the cellpose segmentation part). This means that everything is repeated `num_levels` times, which is obviously not an option, since the initial part of the calculation may be very heavy.
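A hedged sketch of the workaround discussed earlier in the thread (not the actual `write_pyramid` code): write each level to zarr and feed the stored array, rather than the original graph, into the next coarsening step, so that the heavy part of the graph is executed only once.

```python
import dask.array as da
import numpy as np

# Stand-in for the expensive level-0 graph (e.g. the cellpose segmentation)
heavy = da.zeros((4, 4000, 4000), chunks=(1, 2000, 2000), dtype=np.uint16)

# Writing with return_stored=True gives back an array backed by the zarr store,
# so downstream levels no longer carry the heavy graph with them.
current = heavy.to_zarr("pyramid.zarr", component="0", return_stored=True)

num_levels = 4
for level in range(1, num_levels):
    coarsened = da.coarsen(np.max, current, {1: 2, 2: 2})
    current = coarsened.to_zarr(
        "pyramid.zarr", component=str(level), return_stored=True
    )
```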