MouseLand / cellpose

a generalist algorithm for cellular segmentation with human-in-the-loop capabilities
https://www.cellpose.org/
BSD 3-Clause "New" or "Revised" License

Memory usage increases during subsequent evaluations of cellpose model #539

Open tcompa opened 2 years ago

tcompa commented 2 years ago

Hi there, and thanks for your support.

While working on a different project, @mfranzon, @jluethi and I noticed an unexpected increase in RAM usage across subsequent runs of cellpose segmentation with the nuclei model. Below is an example that is as self-contained as possible; further details are scattered across our original issues. We would like to know whether this behavior is expected, whether it can be mitigated, and whether it originates in cellpose or in torch.

Context

Our goal is to segment 3D images with the cellpose pre-trained nuclei model. We need to segment a certain number of arrays (say 20), each with a shape like (30, 2160, 2560) and dtype uint16. The arrays are processed sequentially (one cellpose call per array) on a node with 64 GB of memory and access to a GPU. GPU memory stays under control throughout the entire run (around 4 GiB out of 16 are used); this issue concerns standard RAM usage, which we monitor via mprof.
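As a rough scale reference (a back-of-the-envelope sketch, not a measurement), the raw footprint of one such stack follows from shape times itemsize. `array_footprint` is a hypothetical helper, and the float32 line is an assumption about a possible intermediate copy, not a statement about what cellpose actually allocates internally.

```python
import numpy as np


def array_footprint(shape, dtype):
    """Bytes needed to hold one dense array of the given shape and dtype."""
    # int64 product avoids overflow on platforms where the default int is 32-bit
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


shape = (30, 2160, 2560)
raw = array_footprint(shape, np.uint16)   # the input stack as loaded
f32 = array_footprint(shape, np.float32)  # hypothetical float32 working copy
print(f"uint16 stack : {raw / 2**30:.3f} GiB")  # 0.309 GiB
print(f"float32 copy : {f32 / 2**30:.3f} GiB")  # 0.618 GiB
print(f"2 GiB = {2 * 2**30 / raw:.1f} uint16 stacks")  # 6.5
```

Read this way, the roughly 2 GiB of growth described in the results corresponds to six or seven input-sized buffers, although nothing here shows that the retained memory actually consists of such copies.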

Code and results

As a minimal working example, we load a single array of shape (30, 2160, 2560) and compute its labels repeatedly. If needed, we can find a way to share the image folder, or use other data that are readily available for testing.

The code looks like

import sys
import time

from skimage.io import imread
import numpy as np
from cellpose import core
from cellpose import models

def run_cellpose(img, model):
    t_start = time.perf_counter()
    print(f"START | shape: {img.shape}")
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        img,
        do_3D=True,
        channels=[0, 0],
        net_avg=False,
        augment=False,
        diameter=80.0,
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t_end = time.perf_counter()
    print(f"END  | num_labels={np.max(mask)}, elapsed_time={t_end-t_start:.3f}")
    sys.stdout.flush()
    return mask

# Read 3D stack of images (42 Z planes available)
num_z = 30
stack = np.empty((num_z, 2160, 2560), dtype=np.uint16)
for z in range(num_z):
    stack[z, :, :] = imread(f"images_v1/20200812-CardiomyocyteDifferentiation14-Cycle1_B05_T0001F002L01A01Z{z+1:02d}C01.png")

# Initialize cellpose
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")
print(f"End of initialization: num_z={num_z}, use_gpu={use_gpu}")

nruns = 10

for run in range(nruns):
    print(run)
    run_cellpose(stack, model)

The code runs to completion; each segmentation takes approximately 320 seconds and finds around 3k labels. The memory trace for the first few iterations of the loop is shown below: each run uses more memory than the previous one, until usage saturates after a few iterations. For instance, the plateau regions of the trace sit at 12, 13.8, 14.1, 14.1, ... GiB. The memory-usage peaks at the end of each cellpose call also shift up by a similar amount, accumulating about 2 GiB over the first 2-3 iterations. The simplest explanation would be that cellpose or torch is caching something, but we could not identify what is being cached. Is this actually happening? If so, is there a way to deactivate this caching mechanism?

[Figure (fig_memory): mprof memory trace during the first few iterations of the loop]
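One way to look for what is being retained is the standard-library `tracemalloc` module: take snapshots before and after several runs and list the allocation sites whose live size grew. In the sketch below, `leaky_step` is a hypothetical stand-in for `run_cellpose` that deliberately caches a copy, and `trace_growth` is a hypothetical helper. Caveat, as we understand it: `tracemalloc` sees allocations made through Python's allocator and NumPy's data buffers, but not memory allocated directly by native libraries such as torch's CPU allocator, so a growing RSS together with a flat traced total would point towards native code rather than Python-level caching.

```python
import tracemalloc

import numpy as np

_hidden_cache = []  # hypothetical: stands in for whatever might be retaining memory


def leaky_step(img):
    """Hypothetical stand-in for run_cellpose that keeps a float32 copy alive."""
    work = img.astype(np.float32)
    _hidden_cache.append(work)  # the retained reference the diff should expose
    return work > work.mean()


def trace_growth(fn, img, nruns, top=3):
    """Run fn(img) nruns times under tracemalloc.

    Returns the traced bytes still alive after each run, and the `top`
    allocation sites (file:line) whose live size grew the most overall.
    """
    was_tracing = tracemalloc.is_tracing()
    if not was_tracing:
        tracemalloc.start()
    before = tracemalloc.take_snapshot()
    alive = []
    for _ in range(nruns):
        fn(img)
        alive.append(tracemalloc.get_traced_memory()[0])  # (current, peak)[0]
    after = tracemalloc.take_snapshot()
    if not was_tracing:
        tracemalloc.stop()  # only stop tracing if we started it
    return alive, after.compare_to(before, "lineno")[:top]


# Usage: print the sites that grew most across five runs.
alive, top = trace_growth(leaky_step, np.zeros((100, 100), np.uint16), nruns=5)
for stat in top:
    print(stat)
```

In the real script, `fn` would wrap `run_cellpose(stack, model)`; the `astype` line of the toy example shows up at the top of the diff because each run leaves one more float32 copy alive.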

Expected behavior and why it matters

We would expect subsequent runs on the exact same input to require a very similar amount of memory, unless some caching is in place. This matters to us because, even though the accumulation seems mild (only about 2 GiB more than expected), in heavier use cases (including additional parallelism) it may lead to memory errors, as we found in https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1198916009. For this reason we would really like to keep it under control, possibly by deactivating caching options (if any exist).

Environment

The Python code above is submitted to a SLURM queue and runs on a node with a GPU available.

Relevant details of the Python environment:

sys.version='3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]'
numpy.__version__='1.23.1'
torch.__version__='1.12.0+cu102'
carsen-stringer commented 2 years ago

I have no idea. Have you tried any garbage collection? You could call cellpose as a separate process, which will then clean up after itself (all those options are available on the CLI), but then you have to read the saved masks back in.

tcompa commented 2 years ago

Thanks for your comment.

I can confirm that adding gc.collect() calls (both within the run_cellpose function and, especially, right after each call to it within the loop) does not noticeably change the memory trace.
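For reference, the experiment described above (a full collection after every call, with memory recorded per run) can be reproduced with the standard library alone instead of mprof. A minimal sketch, assuming a Linux node: current RSS is read from `/proc/self/statm` and peak RSS from `resource.getrusage`; `profile_runs`, `current_rss_bytes` and `peak_rss_bytes` are hypothetical helpers.

```python
import gc
import os
import resource
import sys


def current_rss_bytes():
    """Current resident set size from /proc (Linux only; None elsewhere)."""
    try:
        with open("/proc/self/statm") as f:
            resident_pages = int(f.read().split()[1])  # 2nd field: resident pages
        return resident_pages * os.sysconf("SC_PAGE_SIZE")
    except (OSError, ValueError, IndexError):
        return None


def peak_rss_bytes():
    """Peak RSS so far; ru_maxrss is in KiB on Linux but in bytes on macOS."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return peak if sys.platform == "darwin" else peak * 1024


def profile_runs(fn, nruns):
    """Call fn() nruns times, forcing a full garbage collection after each
    call, and record (current RSS, peak RSS) in bytes after every run."""
    records = []
    for _ in range(nruns):
        fn()
        gc.collect()
        records.append((current_rss_bytes(), peak_rss_bytes()))
    return records
```

In the script above this would be called as `profile_runs(lambda: run_cellpose(stack, model), nruns)`. Since `gc.collect()` only frees unreachable Python objects, an unchanged trace is consistent with the extra memory being either still referenced somewhere or held by native allocators.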

At the moment we cannot take the CLI route, since this labeling task is part of a more complex platform for processing bio-images (https://github.com/fractal-analytics-platform/fractal), where tasks need to be Python functions.
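The process-based suggestion does not necessarily require the CLI: a Python function can itself run the segmentation in a child process, so that whatever cellpose or torch retains is returned to the OS when that process exits. A minimal sketch using only the standard library's `multiprocessing`; `run_isolated` and `_child` are hypothetical names, and the approach assumes the result (the mask) can be pickled back through a pipe. It defaults to the "spawn" start method because, per the PyTorch documentation, CUDA cannot be used in subprocesses started with "fork".

```python
import multiprocessing as mp


def _child(conn, fn, args, kwargs):
    """Runs in the child process: call fn and send the result (or error) back."""
    try:
        conn.send(("ok", fn(*args, **kwargs)))
    except BaseException as exc:  # forward any failure to the parent
        conn.send(("err", repr(exc)))
    finally:
        conn.close()


def run_isolated(fn, *args, start_method="spawn", **kwargs):
    """Call fn(*args, **kwargs) in a fresh child process and return its result.

    Memory allocated by fn is released to the OS when the child exits.
    Under "spawn", fn must be a picklable module-level function.
    """
    ctx = mp.get_context(start_method)
    recv_conn, send_conn = ctx.Pipe(duplex=False)  # (receive end, send end)
    proc = ctx.Process(target=_child, args=(send_conn, fn, args, kwargs))
    proc.start()
    send_conn.close()  # the parent keeps only the receiving end
    try:
        # receive before join, so a large result cannot block the child
        status, payload = recv_conn.recv()
    except EOFError:  # child died without sending anything (e.g. OOM-killed)
        proc.join()
        raise RuntimeError(f"child exited without a result (exit code {proc.exitcode})")
    finally:
        recv_conn.close()
    proc.join()
    if status == "err":
        raise RuntimeError(f"child process failed: {payload}")
    return payload
```

In the script above, the worker would be a hypothetical module-level `segment_stack(stack)` that builds the `models.Cellpose` object and calls `run_cellpose` inside the child, used as `mask = run_isolated(segment_stack, stack)`. The trade-off is reloading the model, and pickling the stack and the mask, on every call.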

For now we'll just keep this issue in mind, and apply mitigation strategies (e.g. working at a lower resolution) if/when needed.