tcompa closed this issue 2 years ago.
As of 3c1c431d5919c3901d4d88e26160b097b098110a, there are a few working examples of this integration; see for instance examples/UZH_1_well_2x2_sites (single-well) or examples/UZH_4_well_2x2_sites (multi-well).
A parameter file looks like this:
```
$ cat wf_params_uzh_4_well_2x2_sites.json
{
  "workflow_name": "uzh_4_well_2x2_sites",
  "dims": [2, 2],
  "coarsening_xy": 2,
  "coarsening_z": 1,
  "num_levels": 5,
  "channel_file": "../wf_params_uzh_cardiac_channels.json",
  "path_dict_corr": "../wf_params_uzh_cardiac_illumination.json",
  "image_labeling": {"coarsening_xy": 2, "labeling_level": 0, "labeling_channel": "A01_C01", "num_threads": 1, "relabeling": 1, "anisotropy": 6.1538, "diameter": 35.0, "cellprob_threshold": 0.0},
  "image_labeling_whole_well": {"coarsening_xy": 2, "labeling_level": 2, "labeling_channel": "A01_C01", "diameter_level0": 35.0, "cellprob_threshold": 0.0}
}
```
(Note: the structure has become somewhat similar to the idea of a pipeline JSON file, but this is not part of that work, which will only take place in the upcoming server-based version.)
These examples include the whole set of current tasks (including illumination correction, per-FOV labeling, MIP creation, per-well MIP labeling). Note that the whole-well labeling only takes place on MIP images (this is hardcoded, at the moment).
For the per-FOV labeling in the multi-well case, parsl opens as many GPU jobs as possible (ideally as many as the number of wells).
In the running folder, there are some log files named like LOG_image_labeling_B_03 or LOG_image_labeling_B_03_whole_well, which at the moment are only for debugging.
Let's run a couple more tests, including those with larger (9x8) wells, before closing this issue.
A quick comment while running multi-well workflows: unsurprisingly, they seem to take a bit longer than what one would expect from similar single-well examples (this also applies to labeling tasks), which is likely related to https://github.com/fractal-analytics-platform/fractal/issues/57.
EDIT: what I wrote above is likely true, but in the specific example I was looking at (UZH_4_well_2x2_sites) the issue was rather due to the number of Z planes (19, 42, 29, 42, for the four wells).
The following Fractal examples ran through
A run with 10 5x5 wells (and 19 Z planes) led to a memory error when run with num_threads=2. I will retry with num_threads=1.
There's still something wrong with the 10-wells example, which successfully completes only 8 out of 10 labeling executions:
LOG_image_labeling_B_09: 25/25 completed
LOG_image_labeling_B_11: 25/25 completed
LOG_image_labeling_C_08: 25/25 completed
LOG_image_labeling_C_10: 25/25 completed
LOG_image_labeling_D_09: 25/25 completed
LOG_image_labeling_D_11: 25/25 completed
LOG_image_labeling_E_08: 25/25 completed
LOG_image_labeling_E_10: 25/25 completed
LOG_image_labeling_F_09: 9/25 completed
LOG_image_labeling_F_11: 18/25 completed
Total: 227/250 completed
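With many wells, per-log counts like these are easy to tally programmatically. A minimal sketch, assuming the counts have already been extracted into lines with exactly the format shown above (the helper name `summarize_completion` is hypothetical, not part of fractal):

```python
import re

# Matches e.g. "LOG_image_labeling_F_09: 9/25 completed"
COUNT_LINE = re.compile(r"^(LOG_\S+): (\d+)/(\d+) completed$")


def summarize_completion(lines):
    """Return (incomplete log names, completed FOVs, total FOVs)."""
    incomplete = []
    completed = total = 0
    for line in lines:
        match = COUNT_LINE.match(line.strip())
        if match is None:
            continue  # skip anything else, e.g. the "Total: ..." line
        name = match.group(1)
        done, expected = int(match.group(2)), int(match.group(3))
        completed += done
        total += expected
        if done < expected:
            incomplete.append(name)
    return incomplete, completed, total
```

Applied to the ten LOG lines above, it flags LOG_image_labeling_F_09 and LOG_image_labeling_F_11 as incomplete and returns 227 completed out of 250, matching the Total line.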
The parsl error is of the ManagerLost kind, which is not very informative.
Somewhere in the logs, I observe:
```
$ cat runinfo/000/submit_scripts/parsl.slurm.1658214962.7334816.submit.stderr
0: slurmstepd: Step 9259927.0 exceeded virtual memory limit (64097388 > 63543705), being killed
0: slurmstepd: *** STEP 9259927.0 ON pelkmanslab-slurm-worker-040 CANCELLED AT 2022-07-19T09:41:24 ***
srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
slurmstepd: Exceeded job memory limit
slurmstepd: Exceeded job memory limit
slurmstepd: *** JOB 9259927 ON pelkmanslab-slurm-worker-040 CANCELLED AT 2022-07-19T09:41:24 ***
```
This points to an explicit virtual-memory error (reaching the 64 G available). Is this related to https://github.com/fractal-analytics-platform/fractal/issues/33? To be verified.
Addendum to the last comment: The two failing jobs are the two that started last (there are only 8 GPU nodes available, for 10 tasks). This is hardly a coincidence.
One more piece of information, in an attempt to pinpoint the source of the memory error when running per-FOV labeling on 10 wells.
To test the multi-well case and (temporarily) avoid memory errors, I ran a workflow where per-FOV segmentation of 10 5x5 wells takes place at level 1. This is allowed by the new ROI-based labeling (see https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/19). I also reduced max_blocks_gpu to 1, so that all 10 labeling tasks run (one by one) on the same node.
Here are the CPU and memory traces (the CPU trace is there just as a reference, as it is quite trivial):
The first two blocks of execution (up to time ~1800 s) correspond to other (non-labeling) tasks, followed by 10 blocks of per-FOV labeling (one for each well). The annoying feature is the build-up of memory usage between the 1st and 3rd segmentation tasks, from ~5 G to ~16 G. This is unexpected, as each task should use approximately the same memory (note that the number of Z planes is constant across wells in this dataset).
Thus the actual question becomes something like: Why is memory accumulating when rerunning a task on the same node? Is there some garbage-collection issue at the end of the tasks (aka parsl's python apps)?
Quick update after discussing with @mfranzon
We tend to think that this is related to cellpose (and probably torch) not freeing up memory, rather than parsl. Can we explicitly force a garbage collection / cache clean-up after a task? To be tested.
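A minimal sketch of the clean-up being proposed, assuming it is applied as a decorator around the function each task runs (`with_cleanup` and `_empty_cuda_cache` are hypothetical names, not part of fractal). Note the scope of the two calls: gc.collect() only frees unreachable Python objects, and torch.cuda.empty_cache() only releases unused blocks held by PyTorch's GPU caching allocator, so neither is guaranteed to hand CPU memory held by native libraries back to the operating system.

```python
import functools
import gc


def _empty_cuda_cache():
    """Release cached GPU blocks if torch with CUDA is available; no-op otherwise."""
    try:
        import torch  # optional dependency
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def with_cleanup(func):
    """Decorator: run gc.collect() and _empty_cuda_cache() after every call of func."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            # Runs whether func returns normally or raises
            gc.collect()
            _empty_cuda_cache()

    return wrapper
```

It could be applied with `@with_cleanup` on the task function, or by wrapping an existing one, e.g. `with_cleanup(image_labeling)`.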
Possibly related: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530 https://forum.image.sc/t/napari-cellpose-out-of-memory/60742/3
Sounds like a good assumption to test, because we've only seen this issue come up with labeling jobs after all :)
Quick update:
For debugging, we now initialize the Cellpose model within the segment_FOV function, i.e. once per FOV (instead of passing the same model object around for all FOVs). That is, the block

```python
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type=model_type)
```

is always re-executed.
What is unexpected, and possibly related to the memory issue, is the timing of the first line (core.use_gpu()). When first running on a given node, it takes a couple of seconds. When running again on the same node, it takes something like 0.001 seconds, meaning that some trace/cache of the torch initialization is preserved. Digging a bit further, the use_gpu function maps to these simple and apparently harmless lines:

```python
device = torch.device('cuda:' + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
```
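The first-call vs. repeated-call timing described above can be measured with a small helper like the one below (a sketch; `timed` is a hypothetical name). One plausible reading of the ~0.001 s repeat timing, if both calls happen in the same process (as with a persistent worker), is PyTorch's lazy CUDA initialization: the CUDA context is created on first use and stays alive for the lifetime of the process, so later calls skip that step.

```python
import time


def timed(func, *args, **kwargs):
    """Call func(*args, **kwargs) and return (result, elapsed seconds)."""
    t0 = time.perf_counter()
    result = func(*args, **kwargs)
    return result, time.perf_counter() - t0


# Example on a GPU node (requires cellpose; not executed here):
# from cellpose import core
# _, first = timed(core.use_gpu)   # first call in a fresh process
# _, second = timed(core.use_gpu)  # repeated call in the same process
# print(f"first: {first:.3f} s, second: {second:.3f} s")
```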
It would be great to have a torch.reset() function (we are not the only ones who wish for something similar), but at the moment we only have torch.cuda.empty_cache(), which doesn't seem to work.
This update points towards a likely problem, but not towards an obvious solution.
Note that we now tend to rule out that the problem is due to some attributes of the Cellpose model being kept in memory, since the model is now always re-initialized within segment_FOV.
Hmm, and with this new model initialization within each segment_FOV call, we still run into the same memory issue?
Also, is a torch.reset() something we could safely do if multiple jobs run in parallel on the same GPU?
> Hmm, and with this new model initialization within each segment_FOV call, we still run into the same memory issue?
Our understanding (by now) is that the Cellpose model initialization is irrelevant, and that the problem comes from torch. And indeed, re-initializing the model multiple times apparently does not solve the memory issue.
> Also, is a torch.reset() something we could safely do if multiple jobs run in parallel on the same GPU?
This is not really an issue, since torch does not have a reset() function. It's more something we'd like to have, to free the "cached" (?) memory.
Can we reproduce this in a synthetic setup by e.g. just running 2 or 10 cellpose models sequentially in a single script?
Could be that it's a new issue, or one related to a specific cellpose/torch/GPU setup. But I've run scripts that used cellpose for model inference on 1000s of 3D images in the past, each of them close to the memory limit, and all sequentially on a single GPU. So I would be somewhat surprised if there were a very general issue with cellpose or torch memory handling...
> Can we reproduce this in a synthetic setup by e.g. just running 2 or 10 cellpose models sequentially in a single script?
Yes.
For now I'm starting from the original image_labeling task (with a global model definition), and I only add a gc.collect() after each call (which seems to do nothing).
Here is the script that runs four wells sequentially:
```python
import gc
import os
import shutil
import sys
import time

from fractal.tasks.image_labeling import image_labeling

zarrurl = "/data/active/fractal/tests/Temporary_data_UZH_4_well_2x2_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"
wells = ["B/03/0/", "B/05/0/", "C/04/0/", "D/05/0/"]
print(wells)
print()

# Remove label folders left over from previous runs
for well in wells:
    label_dir = zarrurl + well + "labels"
    if os.path.isdir(label_dir):
        shutil.rmtree(label_dir)
print("Cleared label folders")
print()

with open("times.dat", "w") as f:
    t0 = time.perf_counter()
    for well in wells:
        print(well)
        sys.stdout.flush()
        image_labeling(
            zarrurl + well,
            coarsening_xy=2,
            labeling_level=2,
            labeling_channel="A01_C01",
            chl_list=["A01_C01"],
            num_threads=4,
            relabeling=1,
        )
        # Cumulative elapsed time since the start of the first well
        t1 = time.perf_counter()
        f.write(f"{t1-t0}\n")
        f.flush()
        # Explicit garbage collection after each image_labeling call
        gc.collect()
```
And here is the memory trace (also using https://bbengfort.github.io/2020/07/read-mprofile-into-pandas), where something is clearly accumulating between different image_labeling calls (especially between the first and second ones):
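For reference, traces like this one come from memory_profiler's `mprof run`, which (as far as we know) writes one `MEM <MiB> <unix timestamp>` line per sample into a `mprofile_*.dat` file. The linked post loads that file into pandas; a dependency-free sketch of the same parsing, assuming this line format (`read_mprof` and `peak_mib` are hypothetical names), is:

```python
def read_mprof(lines):
    """Parse the 'MEM <MiB> <timestamp>' lines of an mprof .dat file.

    Lines not of that form are ignored. Returns a list of
    (seconds since first sample, memory in MiB) tuples.
    """
    samples = []
    for line in lines:
        parts = line.split()
        if len(parts) == 3 and parts[0] == "MEM":
            samples.append((float(parts[2]), float(parts[1])))
    if not samples:
        return []
    t_start = samples[0][0]
    return [(t - t_start, mem) for t, mem in samples]


def peak_mib(trace, t_from, t_to):
    """Peak memory (MiB) within the time window [t_from, t_to); 0.0 if empty."""
    return max((mem for t, mem in trace if t_from <= t < t_to), default=0.0)
```

With the time windows of each image_labeling call (e.g. from times.dat), peak_mib gives one number per well, which makes an accumulation across calls easy to spot.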
By the way: this last comment confirms that parsl is not involved in this memory issue (since the script was executed directly via SLURM, with parsl never appearing).
Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?
Reason I'm asking: if torch does some optimization in the background and decides when to clear which memory, I wouldn't care about it, as long as it does it well enough to avoid out-of-memory errors. So I'm not sure how concerning it is that some memory accumulates between runs, unless it is not freed up when needed.
What is your expectation here: Does torch do any fancy optimization for when to free up which memory? Because if it was an actual memory leak, I would expect linear accumulation.
And for my understanding: we are always talking about CPU memory here, right? Or is GPU memory also an issue?
> Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?
Sure, let's try. My guess is that this is exactly the memory error of https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1188958357, but let's check it explicitly.
> What is your expectation here: Does torch do any fancy optimization for when to free up which memory? Because if it was an actual memory leak, I would expect linear accumulation.
It's clearly not linear, see the saturation to a plateau in the comment above: https://user-images.githubusercontent.com/3862206/181180402-409d000f-972c-4e8f-85eb-f05b0f44904b.png.
> And for my understanding: we are always talking about CPU memory here, right? Or do we go into GPU memory being an issue?
Yes, this is all standard CPU memory. GPU memory errors may appear, but (in our experience) only when running many (e.g. 10) simultaneous Cellpose calculations on the same GPU.
Hmm, ok. Let's see if that test can explicitly push it over the limit.
Also, I wonder whether there are some torch parameters that would tell it about available memory. It does seem to handle memory cleanup sometimes, but maybe it's optimized for classical systems where it could go into swap a little bit? Maybe there are torch parameters we could set to make it more aggressive in CPU memory cleanup?
> Also, I wonder whether there are some torch parameters that would tell it about available memory. It does seem to handle memory cleanup sometimes, but maybe it's optimized for classical systems where it could go into swap a little bit? Maybe there are torch parameters we could set to make it more aggressive in CPU memory cleanup?
The only related option we could find is no_grad, but this introduces an explicit change of the function and we are not able to say whether it modifies Cellpose behavior. Is there a trivial answer?
(btw, we have not tested it yet)
Other than that, we'd be glad to test other relevant torch options, if you discover any.
Ok, just googling a bit. Have you had a look at setting LRU_CACHE_CAPACITY=1 as an environment variable (or setting it to something different)? There are a few interesting discussions on memory usage in pytorch here and here that come to that conclusion.
For example, it would be interesting to test with this included:

```python
import os
os.environ["LRU_CACHE_CAPACITY"] = "1"
```
> Nice synthetic test and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error? (e.g. by using level 0 or 1)?
> Sure, let's try. My guess is that this is exactly the memory error of #109 (comment), but let's check it explicitly.
I confirm what I said: the example we are looking at does yield the expected memory error (i.e. there is no smart optimization of memory usage by torch, but rather a memory accumulation as several image_labeling tasks are called).
Here is the memory trace (details in the plot title), and the memory error appears during processing of the third well.
For reference, the detailed traceback is:
B/09/0/
B/11/0/
C/08/0/
Traceback (most recent call last):
File "Many_segmentations.py", line 24, in <module>
image_labeling(
File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 291, in image_labeling
write_pyramid(
File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_pyramid_creation.py", line 69, in write_pyramid
level0 = to_zarr_custom(
File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_to_zarr_custom.py", line 64, in to_zarr_custom
output = array.to_zarr(
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 2828, in to_zarr
return to_zarr(self, *args, **kwargs)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 3591, in to_zarr
return arr.store(z, lock=False, compute=compute, return_stored=return_stored)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1752, in store
r = store([self], [target], **kwargs)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1214, in store
store_dlyds = persist(*store_dlyds, **kwargs)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/base.py", line 904, in persist
results = schedule(dsk, keys, **kwargs)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/threaded.py", line 89, in get
results = get_async(
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 511, in get_async
raise_exception(exc, tb)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 319, in reraise
raise exc
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 224, in execute_task
result = _execute_task(task, data)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/core.py", line 119, in _execute_task
return func(*(_execute_task(a, cache) for a in args))
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/utils.py", line 41, in apply
return func(*args, **kwargs)
File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 57, in segment_FOV
mask, flows, styles, diams = model.eval(
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 227, in eval
masks, flows, styles = self.cp.eval(x,
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 536, in eval
masks, styles, dP, cellprob, p = self._run_cp(x,
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 625, in _run_cp
masks, p = dynamics.compute_masks(dP, cellprob, niter=niter,
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 718, in compute_masks
mask = get_masks(p, iscell=cp_mask)
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 636, in get_masks
h,_ = np.histogramdd(tuple(pflows), bins=edges)
File "<__array_function__ internals>", line 180, in histogramdd
File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/numpy/lib/histograms.py", line 1101, in histogramdd
hist = hist.astype(float, casting='safe')
numpy.core._exceptions.MemoryError: Unable to allocate 2.60 GiB for an array with shape (61, 2202, 2602) and data type float64
mprof: Sampling memory every 0.1s
running new process
The specific line is not relevant, though, as the error could be raised by any other line of Cellpose code.
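As a sanity check of the error message, the size of the failing allocation follows directly from the array shape and dtype: 61 × 2202 × 2602 float64 elements at 8 bytes each, i.e. about 2.80 × 10⁹ bytes. A tiny helper (hypothetical name) makes this kind of estimate quick when choosing a labeling level:

```python
import math


def array_gib(shape, itemsize):
    """Memory footprint in GiB of a dense array with the given shape and itemsize (bytes)."""
    return math.prod(shape) * itemsize / 2**30


# Shape and dtype (float64 -> 8 bytes) taken from the MemoryError above
print(f"{array_gib((61, 2202, 2602), 8):.2f} GiB")  # 2.60 GiB
```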
> Ok, just googling a bit. Have you had a look at setting LRU_CACHE_CAPACITY=1 as an environment variable? Or set to something different. There's a few interesting discussions on memory usage in pytorch here and here that come to that conclusion. e.g. testing having this included would be interesting:
> import os
> os.environ["LRU_CACHE_CAPACITY"] = "1"
Thanks for the links, but @mfranzon noticed that this should already be fixed in pytorch>1.5 (and we are on 1.12.0+cu102). Too bad ;)
And indeed, after adding these two lines we observe a behavior which is fully compatible with https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1197929744 (note that reproducing the memory error each time takes too long, so we test on the smaller dataset).
Hmm, thanks for the thorough checks!
Does it also happen when calling the pure cellpose model inference on e.g. some synthetic test data (without our image_labeling wrapper, which adds some complexity and some dask logic)? If we can make such a test case and still run out of memory, we could then report this to the cellpose repo.
WARNING: I keep forgetting that the dataset with 4 2x2 wells has varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").
Let's focus only on the case with 10 5x5 wells (and constant number of Z levels).
> Does it also happen when calling the pure cellpose model inference on e.g. some synthetic test data (without our image_labeling wrapper that adds some complexity and some dask logic)? If we can make such a test case and still run out of memory, we could then report this to the cellpose repo.
We are trying to at least remove the image_labeling/dask logic, while keeping actual data (early tests with artificial data, like np.zeros, were not very informative, but we can try again).
> I keep forgetting that the dataset with 4 2x2 wells has varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").
Ah, but wouldn't that potentially explain the out of memory problem? Do we know whether we run out of memory in cases with "too many" Z planes (=> where the user should choose a lower pyramid level)? And yes, we'll certainly need a better (small) test case then to verify whether memory load increases. Does the 10 well, 5x5 work for that or is something smaller required?
> I keep forgetting that the dataset with 4 2x2 wells has varying number of Z planes, which makes the interpretation of memory traces much less obvious. This information is also wrong in the plot titles (where I wrote "10 Z planes").
> Ah, but wouldn't that potentially explain the out of memory problem?
The memory error is for a dataset with a constant number of Z planes (10), and it is not affected by this last remark. Other memory traces shown in this issue (those for the 2x2 wells) are the affected ones: not because they are wrong, but because their interpretation is incomplete.
> Does the 10 well, 5x5 work for that or is something smaller required?
We are now trying with the 10-5x5 dataset, working at the per-well level in 3D and at pyramid level 3 (this corresponds to calling Cellpose on a (19, 1350, 1600) array). Let's see if running times are under control; otherwise we will try level 4.
We tried to reduce complexity as far as possible, while keeping the problematic behavior there. We prepared a mock of the image_labeling task, which has only its core components (read, segment, write), in two flavours: one where Cellpose is called through dask.delayed functions (with num_threads=2), and one where Cellpose is called sequentially.
We consider the usual 10-5x5 dataset, but we only segment a 2x2 subset of each 5x5 well. Segmentation is performed at level 0.
The memory trace of these two runs is below. Black lines in the figure are rolling averages, as a guide to the eye. Comments:
These examples seem robust, but we noticed that working with artificial data (e.g. np.zeros) or with much smaller arrays (i.e. lower resolution) doesn't always lead to clear indications of memory accumulation, and this slows down further testing. We'll think a bit more about which options we have before running many more long tests.
A final detail, the size of datasets for different wells (after selecting level 0 and channel 0) is very homogeneous:
$ du -sh -L /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/*/*/*/0/0 | sort -k2
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/09/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/11/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/08/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/10/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/09/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/11/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/08/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/10/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/09/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/11/0/0/0
I just updated the figure in the previous comment. The memory accumulation across the 10-well example is clear even in the case where cellpose is called sequentially (orange line), without any use of dask.delayed functions (note that dask is still used to load/write data, but that should really be irrelevant).
Comment by me and @mfranzon: This memory issue is taking a lot of time and effort, but it's likely due to something in external libraries (cellpose and/or torch). Perhaps we should move the discussion to a cellpose issue? (provided we manage to write a reasonably simple test, possibly without our actual data).
Wow, yeah! I am still surprised by the issue, but now that it doesn't seem to be a dask issue: if we can generate a test case of this happening, I'd be very much in favor of opening a cellpose issue about it! If you need to use some example data from my UZH test set: that data isn't a precious secret, so feel free to use it.
And then that's probably a good point to stop our digging into it. Either there is a fix from the external library (cellpose in this case) or we need to think bigger-picture: How do we run external libraries that may not manage memory well (/or that may have weird dependencies etc.)
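One generic option for the bigger-picture question (our suggestion, not something tested in this thread) is to run each memory-hungry task in its own short-lived process, so that whatever memory a library retains is returned to the operating system when that process exits. A minimal sketch with the standard library's multiprocessing; `run_isolated` and `label_one_well` are hypothetical names:

```python
import multiprocessing as mp


def run_isolated(func, args_list, start_method="spawn"):
    """Call func(arg) for each arg in args_list, each call in a fresh process.

    func must be picklable (e.g. a module-level function). With
    maxtasksperchild=1 and chunksize=1, the single worker process is
    replaced after every call, so memory retained by func (or by the
    libraries it uses) is released when that worker exits.
    """
    ctx = mp.get_context(start_method)
    with ctx.Pool(processes=1, maxtasksperchild=1) as pool:
        # chunksize=1: each argument is a separate task, hence a separate process
        return pool.map(func, args_list, chunksize=1)


# Hypothetical usage (not executed here): one process per well
# if __name__ == "__main__":
#     run_isolated(label_one_well, ["B/03/0/", "B/05/0/"])
```

With the default "spawn" start method, the calling script must guard its entry point with `if __name__ == "__main__":`. PyTorch's documentation recommends "spawn" (or "forkserver") rather than "fork" when CUDA is involved. The trade-off is that every task pays the start-up cost again, including the couple of seconds of torch initialization observed earlier in this thread.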
For a simple test case, could you just load the same image repeatedly (e.g. even just as a PNG image using imageio; it doesn't need to be Zarr if that makes the test complicated) and loop cellpose over it in a basic for loop, e.g. without even saving the results?
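A possible shape for such a minimal reproducer is sketched below, assuming a Linux node where /proc/self/statm is available (`current_rss_mib` and `memory_per_iteration` are hypothetical helpers). It calls the same function repeatedly and records the process's resident memory after each call: roughly constant values would argue against an accumulation, while steadily increasing values would reproduce it.

```python
import gc
import os
import resource
import sys


def current_rss_mib():
    """Current resident set size of this process, in MiB.

    Reads /proc/self/statm on Linux; elsewhere falls back to the *peak* RSS
    reported by getrusage (KiB on Linux, bytes on macOS).
    """
    try:
        with open("/proc/self/statm") as f:
            resident_pages = int(f.read().split()[1])
        return resident_pages * os.sysconf("SC_PAGE_SIZE") / 2**20
    except FileNotFoundError:
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        return peak / 2**20 if sys.platform == "darwin" else peak / 2**10


def memory_per_iteration(func, n_iter):
    """Call func() n_iter times and return the RSS (MiB) measured after each call."""
    trace = []
    for _ in range(n_iter):
        func()
        gc.collect()
        trace.append(current_rss_mib())
    return trace


# Hypothetical usage with Cellpose (not executed here):
# import imageio
# from cellpose import models
# img = imageio.imread("test_image.png")
# model = models.Cellpose(gpu=True, model_type="cyto")
# print(memory_per_iteration(lambda: model.eval(img, channels=[0, 0]), 10))
```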
The discussion should probably continue in https://github.com/MouseLand/cellpose/issues/539.
Great that we have this escalated now. I'd say work on the segmentation milestone is done then from our side. Let's follow the cellpose discussion to see if there is a good workaround and otherwise think broader about handling libraries with potential memory issues :)
The labeling task is part of fractal, and remaining issues are unrelated to this one. Closing.
Integrate image_labeling and image_labeling_whole_well in fractal (for the current working version, not for the upcoming server-based one).