fractal-analytics-platform / fractal-client

Command-line client for Fractal
https://fractal-analytics-platform.github.io/fractal-client
BSD 3-Clause "New" or "Revised" License

Labeling-tasks integration into fractal + (Torch) memory errors #109

Closed tcompa closed 2 years ago

tcompa commented 2 years ago

Integrate image_labeling and image_labeling_whole_well into fractal (for the current working version, not for the upcoming server-based one).

tcompa commented 2 years ago

As of 3c1c431d5919c3901d4d88e26160b097b098110a, there are a few working examples of this integration; see for instance examples/UZH_1_well_2x2_sites (single-well) or examples/UZH_4_well_2x2_sites (multi-well).

A parameter file looks like this:

$ cat wf_params_uzh_4_well_2x2_sites.json
{
  "workflow_name": "uzh_4_well_2x2_sites",
  "dims": [2, 2],
  "coarsening_xy": 2,
  "coarsening_z": 1,
  "num_levels": 5,
  "channel_file": "../wf_params_uzh_cardiac_channels.json",
  "path_dict_corr": "../wf_params_uzh_cardiac_illumination.json",
  "image_labeling": {"coarsening_xy": 2, "labeling_level": 0, "labeling_channel": "A01_C01", "num_threads": 1, "relabeling": 1, "anisotropy": 6.1538, "diameter": 35.0, "cellprob_threshold": 0.0},
  "image_labeling_whole_well": {"coarsening_xy": 2, "labeling_level": 2, "labeling_channel": "A01_C01", "diameter_level0": 35.0, "cellprob_threshold": 0.0}
}

(Note: this structure has become somewhat similar to the idea of a pipeline JSON file, but that is not part of this work; it will only happen in the upcoming server-based version.)
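For reference, a minimal sketch (not the actual fractal code) of how such a parameter file can be consumed: load the JSON and pass the task-specific sub-dictionary as keyword arguments to the corresponding task. The call at the end is hypothetical and only illustrates the mapping.

```python
import json

# Minimal sketch, assuming the parameter file shown above.
with open("wf_params_uzh_4_well_2x2_sites.json") as f:
    params = json.load(f)

# Task-specific parameters become keyword arguments for that task
labeling_kwargs = params["image_labeling"]
# image_labeling(zarrurl, **labeling_kwargs)  # hypothetical call site
```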

These examples cover the whole set of current tasks (illumination correction, per-FOV labeling, MIP creation, per-well MIP labeling). Note that whole-well labeling only takes place on MIP images (this is hardcoded at the moment).

For the per-FOV labeling in the multi-well case, parsl opens as many GPU jobs as possible (ideally as many as the number of wells).

In the running folder, there are some log files named like LOG_image_labeling_B_03 or LOG_image_labeling_B_03_whole_well, which at the moment are only for debugging.

tcompa commented 2 years ago

Let's run a couple more tests, including those with larger (9x8) wells, before closing this issue.

tcompa commented 2 years ago

A quick comment while running multi-well workflows: unsurprisingly, they seem to take a bit longer than one would expect from similar single-well examples (this also applies to labeling tasks), which is likely related to https://github.com/fractal-analytics-platform/fractal/issues/57.

EDIT: what I wrote above is likely true, but in the specific example I was looking at (UZH_4_well_2x2_sites) the issue was rather due to the number of Z planes (19, 42, 29, 42, for the four wells).

tcompa commented 2 years ago

The following Fractal examples ran through.

A run with 10 5x5 wells (and 19 Z planes) led to a memory error when run with num_threads=2. I will retry with num_threads=1.

tcompa commented 2 years ago

There's still something wrong with the 10-wells example, which successfully completes only 8 out of 10 labeling executions:

LOG_image_labeling_B_09: 25/25 completed
LOG_image_labeling_B_11: 25/25 completed
LOG_image_labeling_C_08: 25/25 completed
LOG_image_labeling_C_10: 25/25 completed
LOG_image_labeling_D_09: 25/25 completed
LOG_image_labeling_D_11: 25/25 completed
LOG_image_labeling_E_08: 25/25 completed
LOG_image_labeling_E_10: 25/25 completed
LOG_image_labeling_F_09: 9/25 completed
LOG_image_labeling_F_11: 18/25 completed

Total: 227/250 completed

The parsl error is of the ManagerLost kind, which is not very informative. Somewhere in the logs, I observe:

$ cat runinfo/000/submit_scripts/parsl.slurm.1658214962.7334816.submit.stderr

0: slurmstepd: Step 9259927.0 exceeded virtual memory limit (64097388 > 63543705), being killed
0: slurmstepd: *** STEP 9259927.0 ON pelkmanslab-slurm-worker-040 CANCELLED AT 2022-07-19T09:41:24 ***
srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
slurmstepd: Exceeded job memory limit
slurmstepd: Exceeded job memory limit
slurmstepd: *** JOB 9259927 ON pelkmanslab-slurm-worker-040 CANCELLED AT 2022-07-19T09:41:24 ***

which points to an explicit virtual-memory error (hitting the 64 GB available). Is this related to https://github.com/fractal-analytics-platform/fractal/issues/33?

To be verified.

tcompa commented 2 years ago

Addendum to the last comment: The two failing jobs are the two that started last (there are only 8 GPU nodes available, for 10 tasks). This is hardly a coincidence.

tcompa commented 2 years ago

One more piece of information, in trying to pinpoint the source of the memory error during per-FOV labeling of 10 wells.

To test the multi-well case and (temporarily) avoid memory errors, I ran a workflow where per-FOV segmentation of 10 5x5 wells takes place at level 1. This is allowed by the new ROI-based labeling (see https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/19). I also reduced max_blocks_gpu to 1, so that all 10 labeling tasks run (one by one) on the same node.

Here are the CPU and memory traces (the CPU trace is there just as a reference, as it is quite trivial):

[Figures: CPU trace and memory trace (screenshots from 2022-07-27)]

The first two blocks of execution (up to time ~1800 s) correspond to other (non-labeling) tasks, followed by 10 blocks of per-FOV labeling (one per well). The annoying feature is the build-up of memory usage between the 1st and 3rd segmentation tasks, from ~5 GB to ~16 GB. This is unexpected, since each task should use approximately the same memory (note that the number of Z planes is constant across wells in this dataset).

Thus the actual question becomes: why is memory accumulating when rerunning a task on the same node? Is there some garbage-collection issue at the end of the tasks (i.e. parsl's python apps)?

tcompa commented 2 years ago

Quick update after discussing with @mfranzon

We tend to think that this is related to cellpose (and probably torch) not freeing up memory, rather than parsl. Can we explicitly force a garbage collection / cache clean-up after a task? To be tested.

Possibly related: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530 https://forum.image.sc/t/napari-cellpose-out-of-memory/60742/3
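A minimal sketch of the kind of explicit clean-up we could test at the end of each task (hypothetical helper; note that torch.cuda.empty_cache() only releases torch's cached GPU memory, not host RAM):

```python
import gc

import torch


def cleanup_after_task():
    # Hypothetical helper, to be called at the end of each labeling task:
    # force Python garbage collection and ask torch to release its cached
    # GPU memory (this does not free host RAM held elsewhere).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```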

jluethi commented 2 years ago

Sounds like a good assumption to test, because we've only seen this issue come up with labeling jobs after all :)

tcompa commented 2 years ago

Quick update:

For debugging, we now set up the Cellpose model within the segment_FOV function, i.e. once per FOV (instead of passing the same model object around for all FOVs). In other words, the block

    use_gpu = core.use_gpu()
    model = models.Cellpose(gpu=use_gpu, model_type=model_type)

is always re-executed. What is unexpected, and possibly related to the memory issue, is the timing of the first line (core.use_gpu()). When first running on a given node, it takes a couple of seconds. When running again on the same node, it takes something like 0.001 seconds, meaning that some trace/cache of the torch initialization is preserved. Digging a bit further, the use_gpu function maps to these simple and apparently harmless lines:

        device = torch.device('cuda:' + str(gpu_number))
        _ = torch.zeros([1, 2, 3]).to(device)

It would be great to have a torch.reset() function (we are not the only ones who would like something like this), but at the moment we only have torch.cuda.empty_cache(), which doesn't seem to work.

This update points towards a likely problem, but not towards an obvious solution.

Note that we now tend to rule out the possibility that the problem is due to some attributes of the Cellpose model being kept in memory, since the model is now always re-initialized within segment_FOV.
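For reference, a minimal sketch of the timing check described above (using the "nuclei" model type that appears elsewhere in this thread):

```python
import time

from cellpose import core, models

# Time core.use_gpu(): on a fresh node this takes a couple of seconds,
# while on a node that already ran a task it returns almost instantly,
# suggesting that some torch initialization state is preserved.
t0 = time.perf_counter()
use_gpu = core.use_gpu()
t1 = time.perf_counter()
print(f"core.use_gpu() -> {use_gpu}, took {t1 - t0:.4f} s")

# Re-initializing the model per FOV, as described above
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")
```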

jluethi commented 2 years ago

Hmm, and with this new model initialization within each segment_FOV call, we still run into the same memory issue?

Also, is a torch.reset() something we could safely do if multiple jobs run in parallel on the same GPU?

tcompa commented 2 years ago

Hmm, and with this new model initialization within each segment_FOV call, we still run into the same memory issue?

Our current understanding is that the Cellpose model initialization is irrelevant, and that the problem comes from torch. Indeed, the multiple initializations of the model apparently do not solve the memory issue.

Also, is a torch.reset() something we could safely do if multiple jobs run in parallel on the same GPU?

This is not really an issue, since torch does not have a reset() function. It's more something we'd like to have, to free the "cached" (?) memory.

jluethi commented 2 years ago

Can we reproduce this in a synthetic setup by e.g. just running 2 or 10 cellpose models sequentially in a single script?

Could be that it's a new issue, or something related to a specific cellpose/torch/GPU combination. But I've run scripts that used cellpose for model inference in the past on 1000s of 3D images, each of them close to the memory limit. And that was all sequential on a single GPU. So I would be kind of surprised if there is a very general issue with cellpose or torch memory handling...

tcompa commented 2 years ago

Can we reproduce this in a synthetic setup by e.g. just running 2 or 10 cellpose models sequentially in a single script?

Yes.

For now I'm starting from the original image_labeling task (with a global model definition), and I only add a gc.collect() after each call (which seems to do nothing).

Here is the script that runs four wells sequentially:

import sys
import time
import os
import shutil
from fractal.tasks.image_labeling import image_labeling
import gc

zarrurl = "/data/active/fractal/tests/Temporary_data_UZH_4_well_2x2_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"

wells = ["B/03/0/", "B/05/0/", "C/04/0/", "D/05/0/"]
print(wells)
print()

for well in wells:
    label_dir = zarrurl + well + "labels"
    if os.path.isdir(label_dir):
        shutil.rmtree(label_dir)
print("Cleared label folders")
print()

with open("times.dat", "w") as f:
    t0 = time.perf_counter()
    for well in wells:
        print(well)
        sys.stdout.flush()
        image_labeling(
                zarrurl + well,
                coarsening_xy=2,
                labeling_level=2,
                labeling_channel="A01_C01",
                chl_list=["A01_C01"],
                num_threads=4,
                relabeling=1,
                )
        t1 = time.perf_counter()

        f.write(f"{t1-t0}\n")
        f.flush()

        gc.collect()

And here is the memory trace (also using https://bbengfort.github.io/2020/07/read-mprofile-into-pandas), where something is clearly accumulating between different image_labeling calls (especially between the first and second ones):

[Figure: memory trace (fig_memory)]
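For reference, a rough sketch of how the mprof output can be read into pandas for plotting (following the approach in the linked post, and assuming the standard `MEM <MiB> <timestamp>` lines written by memory_profiler's mprof; the file name below is hypothetical):

```python
import pandas as pd


def read_mprofile(path):
    # Parse "MEM <memory in MiB> <unix timestamp>" lines from an mprof .dat
    # file into a DataFrame with elapsed-time and memory columns.
    rows = []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 3 and parts[0] == "MEM":
                rows.append((float(parts[2]), float(parts[1])))
    df = pd.DataFrame(rows, columns=["timestamp", "mem_MiB"])
    df["t_sec"] = df["timestamp"] - df["timestamp"].iloc[0]
    return df


# df = read_mprofile("mprofile_20220725.dat")  # hypothetical filename
# df.plot(x="t_sec", y="mem_MiB")
```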

tcompa commented 2 years ago

By the way: this last comment confirms that parsl is not involved in this memory issue (since the script was executed directly via SLURM, with parsl never appearing).

jluethi commented 2 years ago

Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?

Reason I'm asking: if torch does some optimization in the background and decides when to clear what memory, I wouldn't care about it, as long as it does it well enough to avoid out-of-memory errors. So I'm not sure how concerning the fact that some memory is accumulating between runs is, unless it is not freed up when needed.

What is your expectation here: Does torch do any fancy optimization for when to free up which memory? Because if it was an actual memory leak, I would expect linear accumulation.

And for my understanding: we are always talking about CPU memory here, right? Or do we go into GPU memory being an issue?

tcompa commented 2 years ago

Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?

Sure, let's try. My guess is that this is exactly the memory error of https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1188958357, but let's check it explicitly.

What is your expectation here: Does torch do any fancy optimization for when to free up which memory? Because if it was an actual memory leak, I would expect linear accumulation.

It's clearly not linear, see the saturation to a plateau in the comment above: https://user-images.githubusercontent.com/3862206/181180402-409d000f-972c-4e8f-85eb-f05b0f44904b.png.

And for my understanding: we are always talking about CPU memory here, right? Or do we go into GPU memory being an issue?

Yes, this is all standard CPU memory. GPU memory errors may appear, but (in our experience) only when running many (e.g. 10) simultaneous Cellpose calculations on the same GPU.

jluethi commented 2 years ago

Hmm, ok. Let's see if that test can explicitly push it over the limit.

Also, I wonder whether there are some torch parameters that would tell it about available memory. It does seem to handle memory cleanup sometimes, but maybe it's optimized for classical systems where it could go into swap a little bit? Maybe there are torch parameters we could set to make it more aggressive in CPU memory cleanup?

tcompa commented 2 years ago

Also, I wonder whether there are some torch parameters that would tell it about available memory. It does seem to handle memory cleanup sometimes, but maybe it's optimized for classical systems where it could go into swap a little bit? Maybe there are torch parameters we could set to make it more aggressive in CPU memory cleanup?

The only related option we could find is no_grad, but this requires an explicit change to the function, and we cannot say whether it modifies Cellpose's behavior. Is there a trivial answer?

(btw, we have not tested it yet)
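For concreteness, a sketch of what such a test could look like (a hypothetical variant of the segment_FOV logic, with Cellpose's eval wrapped in torch.no_grad(); untested, and we cannot say whether it affects the results):

```python
import torch
from cellpose import core, models


def segment_no_grad(img, anisotropy=6.0, diameter=35.0):
    # Hypothetical variant of segment_FOV: run Cellpose inference inside
    # torch.no_grad(), so that no autograd state is kept around.
    model = models.Cellpose(gpu=core.use_gpu(), model_type="nuclei")
    with torch.no_grad():
        mask, flows, styles, diams = model.eval(
            img,
            channels=[0, 0],
            do_3D=True,
            anisotropy=anisotropy,
            diameter=diameter,
            cellprob_threshold=0.0,
        )
    return mask
```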

Other than that, we'd be glad to test other relevant torch options, if you discover any.

jluethi commented 2 years ago

Ok, just googling a bit: have you had a look at setting LRU_CACHE_CAPACITY=1 as an environment variable (or to some other value)? There are a few interesting discussions on memory usage in pytorch, here and here, that come to that conclusion.

e.g. testing having this included would be interesting:

import os
os.environ["LRU_CACHE_CAPACITY"] = "1"

tcompa commented 2 years ago

Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?

Sure, let's try. My guess is that this is exactly the memory error of #109 (comment), but let's check it explicitly.

I confirm what I said: the example we are looking at does yield the expected memory error (AKA there is no smart optimization of memory usage by torch, but rather a memory accumulation as several image_labeling tasks are called).

Here is the memory trace (details in the plot title); the memory error appears during processing of the third well:

[Figure: memory trace (fig_memory)]

Just as a reference, the detailed traceback is

B/09/0/
B/11/0/
C/08/0/
Traceback (most recent call last):
  File "Many_segmentations.py", line 24, in <module>
    image_labeling(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 291, in image_labeling
    write_pyramid(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_pyramid_creation.py", line 69, in write_pyramid
    level0 = to_zarr_custom(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_to_zarr_custom.py", line 64, in to_zarr_custom
    output = array.to_zarr(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 2828, in to_zarr
    return to_zarr(self, *args, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 3591, in to_zarr
    return arr.store(z, lock=False, compute=compute, return_stored=return_stored)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1752, in store
    r = store([self], [target], **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1214, in store
    store_dlyds = persist(*store_dlyds, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/base.py", line 904, in persist
    results = schedule(dsk, keys, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/threaded.py", line 89, in get
    results = get_async(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 511, in get_async
    raise_exception(exc, tb)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 319, in reraise
    raise exc
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 224, in execute_task
    result = _execute_task(task, data)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/core.py", line 119, in _execute_task
    return func(*(_execute_task(a, cache) for a in args))
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/utils.py", line 41, in apply
    return func(*args, **kwargs)
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 57, in segment_FOV
    mask, flows, styles, diams = model.eval(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 227, in eval
    masks, flows, styles = self.cp.eval(x, 
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 536, in eval
    masks, styles, dP, cellprob, p = self._run_cp(x, 
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 625, in _run_cp
    masks, p = dynamics.compute_masks(dP, cellprob, niter=niter, 
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 718, in compute_masks
    mask = get_masks(p, iscell=cp_mask)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 636, in get_masks
    h,_ = np.histogramdd(tuple(pflows), bins=edges)
  File "<__array_function__ internals>", line 180, in histogramdd
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/numpy/lib/histograms.py", line 1101, in histogramdd
    hist = hist.astype(float, casting='safe')
numpy.core._exceptions.MemoryError: Unable to allocate 2.60 GiB for an array with shape (61, 2202, 2602) and data type float64
mprof: Sampling memory every 0.1s
running new process

but the exact location is not especially relevant, since the error could equally have occurred on any other line of the Cellpose code.
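As a sanity check, the size in the MemoryError message matches a float64 array of that shape:

```python
# 61 * 2202 * 2602 elements, 8 bytes each (float64)
nbytes = 61 * 2202 * 2602 * 8
print(nbytes / 2**30)  # ~2.60 GiB, as reported in the traceback
```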

tcompa commented 2 years ago

Ok, just googling a bit: have you had a look at setting LRU_CACHE_CAPACITY=1 as an environment variable (or to some other value)? There are a few interesting discussions on memory usage in pytorch, here and here, that come to that conclusion.

e.g. testing having this included would be interesting:

import os
os.environ["LRU_CACHE_CAPACITY"] = "1"

Thanks for the links, but @mfranzon noticed that this should already be fixed in pytorch>1.5 (and we are on 1.12.0+cu102). Too bad ;)

And indeed, after adding these two lines, we find behavior which is fully compatible with https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1197929744 (note that reproducing the memory error each time takes too long, so we test on the smaller dataset).

[Figure: memory trace with LRU_CACHE_CAPACITY=1, smaller dataset (fig_memory)]

jluethi commented 2 years ago

Hmm, thanks for the thorough checks!

Does it also happen when calling the pure cellpose model inference on e.g. some synthetic test data (without our image_labeling wrapper that adds some complexity and some dask logic)? If we can make such a test case and still run out of memory, we could then report this to the cellpose repo.

tcompa commented 2 years ago

WARNING: I keep forgetting that the dataset with 4 2x2 wells has a varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").

Let's focus only on the case with 10 5x5 wells (and a constant number of Z planes).

tcompa commented 2 years ago

Does it also happen when calling the pure cellpose model inference on e.g. some synthetic test data (without our image_labeling wrapper that adds some complexity and some dask logic)? If we can make such a test case and still run out of memory, we could then report this to the cellpose repo.

We are trying to at least remove the image_labeling/dask logic while keeping the actual data (early tests with artificial data, like np.zeros, were not very informative, but we can try again).

jluethi commented 2 years ago

I keep forgetting that the dataset with 4 2x2 wells has varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").

Ah, but wouldn't that potentially explain the out-of-memory problem? Do we know whether we run out of memory in cases with "too many" Z planes (=> where the user should choose a lower pyramid level)? And yes, we'll certainly need a better (small) test case then to verify whether memory load increases. Does the 10-well 5x5 dataset work for that, or is something smaller required?

tcompa commented 2 years ago

I keep forgetting that the dataset with 4 2x2 wells has varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").

Ah, but wouldn't that potentially explain the out of memory problem?

The memory error is for a dataset with a constant number of Z planes (10), and it's not affected by this last remark. The other memory traces shown in this issue (those for the 2x2 wells) are the ones affected - not because they are wrong, but because their interpretation is incomplete.

Does the 10 well, 5x5 work for that or is something smaller required?

We are now trying with the 10-5x5 dataset, working at the per-well level in 3D and at pyramid level 3 (this corresponds to calling Cellpose on a (19, 1350, 1600) array). Let's see whether the running time is under control; otherwise we will try level 4.

tcompa commented 2 years ago

We tried to reduce complexity as far as possible, while keeping the problematic behavior. We prepared a mock of the image_labeling task, which retains only its core components (read, segment, write), in two flavours:

  1. More similar to our task: the segmentation function is applied as a dask delayed function (with num_threads=2);
  2. Even simpler: the segmentation function is applied sequentially over FOVs.

Code:

With delayed function:

```python
import os
import shutil
import sys
import itertools
import numpy as np
import time
import dask
import dask.array as da
from cellpose import core
from cellpose import models
from concurrent.futures import ThreadPoolExecutor


def fun(FOV_column, model):
    t1_start = time.perf_counter()
    print("START: shape =", FOV_column.shape)
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        FOV_column,
        channels=[0, 0],
        do_3D=True,
        net_avg=False,
        augment=False,
        diameter=(80.0 / 2**labeling_level),
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t1 = time.perf_counter()
    print(f"END: I found {np.max(mask)} labels, in {t1-t1_start:.3f} seconds")
    sys.stdout.flush()
    return mask


def image_labeling(well, labeling_level=None, labeling_channel=None, num_threads=None):
    print(well)
    use_gpu = core.use_gpu()
    print("use_gpu:", use_gpu)
    model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

    # Load full-well data
    column = da.from_zarr(zarrurl + well + f"{labeling_level}/")[labeling_channel]
    output = da.empty(column.shape, chunks=column.chunks, dtype=column.dtype)
    delayed_fun = dask.delayed(fun)

    # Select a single FOV
    for ind_FOV in itertools.product(range(2), repeat=2):
        # Define FOV indices
        ix, iy = ind_FOV
        size_x = 2560 // 2 ** labeling_level
        size_y = 2160 // 2 ** labeling_level
        start_x = size_x * ix
        end_x = size_x * (ix + 1)
        start_y = size_y * iy
        end_y = size_y * (iy + 1)
        # Select input and assign output
        FOV_column = column[:, start_y:end_y, start_x:end_x]
        FOV_mask = delayed_fun(FOV_column, model)
        output[:, start_y:end_y, start_x:end_x] = da.from_delayed(
            FOV_mask, shape=FOV_column.shape, dtype=FOV_column.dtype
        )

    # Remove output file, if needed
    outzarr = f"/tmp/{well}.zarr"
    if os.path.isdir(outzarr):
        shutil.rmtree(outzarr)

    # Write output (--> trigger execution of delayed functions)
    with dask.config.set(pool=ThreadPoolExecutor(num_threads)):
        output.to_zarr(outzarr)
    print()
    sys.stdout.flush()


root = "/data/active/fractal/tests/"
zarrurl = root + "Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"

wells = ["B/09/0/", "B/11/0/", "C/08/0/", "C/10/0/", "D/09/0/",
         "D/11/0/", "E/08/0/", "E/10/0/", "F/09/0/", "F/11/0/"]

f = open("times.dat", "w")
t0 = time.perf_counter()
num_threads = 2
labeling_level = 0
labeling_channel = 0
for well in wells:
    image_labeling(well, labeling_level=labeling_level,
                   labeling_channel=labeling_channel, num_threads=num_threads)
    t1 = time.perf_counter()
    f.write(f"{t1-t0}\n")
    f.flush()
f.close()
```
With sequential functions:

```python
import os
import shutil
import sys
import itertools
import numpy as np
import time
import dask.array as da
from cellpose import core
from cellpose import models


def fun(FOV_column, model):
    t1_start = time.perf_counter()
    print("START: shape =", FOV_column.shape)
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        FOV_column,
        channels=[0, 0],
        do_3D=True,
        net_avg=False,
        augment=False,
        diameter=(80.0 / 2**labeling_level),
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t1 = time.perf_counter()
    print(f"END: I found {np.max(mask)} labels, in {t1-t1_start:.3f} seconds")
    sys.stdout.flush()
    return mask


def image_labeling(well, labeling_level=None, labeling_channel=None):
    print(well)
    use_gpu = core.use_gpu()
    print("use_gpu:", use_gpu)
    model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

    # Load full-well data
    column = da.from_zarr(zarrurl + well + f"{labeling_level}/")[labeling_channel]
    output = da.empty(column.shape, chunks=column.chunks, dtype=column.dtype)

    # Select a single FOV
    for ind_FOV in itertools.product(range(2), repeat=2):
        # Define FOV indices
        ix, iy = ind_FOV
        size_x = 2560 // 2 ** labeling_level
        size_y = 2160 // 2 ** labeling_level
        start_x = size_x * ix
        end_x = size_x * (ix + 1)
        start_y = size_y * iy
        end_y = size_y * (iy + 1)
        # Select input and assign output
        FOV_column = column[:, start_y:end_y, start_x:end_x]
        FOV_mask = fun(FOV_column, model)
        output[:, start_y:end_y, start_x:end_x] = FOV_mask

    # Remove output file, if needed
    outzarr = f"/tmp/{well}_clean.zarr"
    if os.path.isdir(outzarr):
        shutil.rmtree(outzarr)

    # Write output (--> trigger execution of delayed functions)
    output.to_zarr(outzarr)
    print()
    sys.stdout.flush()


root = "/data/active/fractal/tests/"
zarrurl = root + "Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"

wells = ["B/09/0/", "B/11/0/", "C/08/0/", "C/10/0/", "D/09/0/",
         "D/11/0/", "E/08/0/", "E/10/0/", "F/09/0/", "F/11/0/"]

f = open("times.dat", "w")
t0 = time.perf_counter()
labeling_level = 0
labeling_channel = 0
for well in wells:
    image_labeling(well, labeling_level=labeling_level,
                   labeling_channel=labeling_channel)
    t1 = time.perf_counter()
    f.write(f"{t1-t0}\n")
    f.flush()
f.close()
```

We consider the usual 10-5x5 dataset, but we only segment a 2x2 subset of each 5x5 well. Segmentation is performed at level 0.

The memory trace of these two runs is below. Black lines in the figure are rolling averages, as a guide to the eye. Comments:

  1. With the delayed-based task, memory accumulation is clearly present;
  2. With the task that sequentially scans FOVs, accumulation is less evident but probably still there.

These examples seem robust, but we noticed that working with artificial data (AKA np.zeros) or with much smaller arrays (AKA lower resolution) doesn't always lead to clear indications of memory accumulation - and this slows down further testing. We'll think a bit more about which options we have, before running many more long tests.

[Figure: memory traces for the delayed-based and sequential variants (fig_memory)]

A final detail: the size of the datasets for different wells (after selecting level 0 and channel 0) is very homogeneous:

$ du -sh -L /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/*/*/*/0/0 | sort -k2
2.6G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/09/0/0/0
2.6G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/11/0/0/0
2.6G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/08/0/0/0
2.6G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/10/0/0/0
2.5G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/09/0/0/0
2.5G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/11/0/0/0
2.6G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/08/0/0/0
2.5G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/10/0/0/0
2.5G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/09/0/0/0
2.5G    /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/11/0/0/0

tcompa commented 2 years ago

I just updated the figure in the previous comment. The memory accumulation across the 10-well example is clear even in the case where cellpose is called sequentially (orange line), without any use of dask.delayed functions (note that dask is still used to load/write data, but that should really be irrelevant).

tcompa commented 2 years ago

Comment from me and @mfranzon: this memory issue is taking a lot of time and effort, but it's likely due to something in external libraries (cellpose and/or torch). Perhaps we should move the discussion to a cellpose issue (provided we manage to write a reasonably simple test, possibly one that doesn't require our actual data)?

jluethi commented 2 years ago

Wow, yeah! I am still surprised by the issue, but now that it doesn't seem to be a dask issue: if we can generate a test case that reproduces this, I'd very much be in favor of opening a cellpose issue! If you need some example data from my UZH test set: that data isn't a precious secret, so feel free to use it.

And then that's probably a good point to stop digging into it. Either there is a fix from the external library (cellpose in this case), or we need to think bigger-picture: how do we run external libraries that may not manage memory well (or that may have weird dependencies, etc.)?

jluethi commented 2 years ago

For a simple test case, could you just load the same image repeatedly (e.g. even just as a PNG image using imageio; it doesn't need to be Zarr if that makes the test complicated) and loop cellpose over it in a basic for loop, without even saving the results?
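Something along these lines, as a sketch (hypothetical file name and loop count; the Cellpose keyword arguments mirror the ones used elsewhere in this thread):

```python
import imageio
import numpy as np
from cellpose import core, models

# Load one image from disk and run Cellpose on it repeatedly, discarding
# the results; memory usage can be recorded with mprof while this runs.
img = imageio.imread("test_image.png")  # hypothetical path
model = models.Cellpose(gpu=core.use_gpu(), model_type="nuclei")

for i in range(10):
    mask, flows, styles, diams = model.eval(
        img, channels=[0, 0], diameter=35.0, cellprob_threshold=0.0
    )
    print(f"iteration {i}: found {np.max(mask)} labels")
```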

tcompa commented 2 years ago

The discussion should probably continue in https://github.com/MouseLand/cellpose/issues/539.

jluethi commented 2 years ago

Great that we have this escalated now. I'd say work on the segmentation milestone is then done from our side. Let's follow the cellpose discussion to see if there is a good workaround, and otherwise think more broadly about handling libraries with potential memory issues :)

tcompa commented 2 years ago

The labeling task is now part of fractal, and the remaining issues are unrelated to this one. Closing.