MouseLand / cellpose

a generalist algorithm for cellular segmentation with human-in-the-loop capabilities
https://www.cellpose.org/
BSD 3-Clause "New" or "Revised" License

Memory management for parallel processing #207

Closed loomcode closed 1 month ago

loomcode commented 3 years ago

I understand that PyTorch doesn't offer the same memory management tools as TensorFlow, so you can't split GPU memory between processes. This becomes an issue if you try to run cellpose in parallel, because the first process occupies all of the memory, leading to out-of-memory errors. This is mostly a problem because cellpose runs on one image at a time and the postprocessing steps take the brunt of the processing time, all while the GPU memory stays reserved. If you could run all images through the network first and then hand those results to parallel postprocessing, it seems like cellpose could run much faster.
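
The two-phase idea proposed above could be sketched roughly as follows. This is a minimal illustration, not cellpose code: `run_network` and `postprocess` are hypothetical stand-ins for the GPU flow computation and the CPU-heavy mask computation, and the "images" are plain lists of numbers.

```python
from multiprocessing import Pool


def run_network(image):
    """Hypothetical stand-in for the GPU stage (flow computation)."""
    return [2 * px for px in image]


def postprocess(flow):
    """Hypothetical stand-in for the CPU-heavy stage (mask computation)."""
    return sum(flow)


def two_phase(images, processes=1):
    # Phase 1: run every image through the network serially, so only one
    # process ever touches the GPU.
    flows = [run_network(img) for img in images]

    # Phase 2: post-process the stored flows; with processes > 1 the work
    # is spread over a pool of CPU worker processes.
    if processes > 1:
        with Pool(processes=processes) as pool:
            return pool.map(postprocess, flows)
    return [postprocess(f) for f in flows]
```

On platforms that spawn worker processes (Windows, macOS), a call with `processes > 1` should sit under an `if __name__ == "__main__":` guard.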

loomcode commented 3 years ago

After splitting models.CellposeModel.init.eval into two functions I was able to get nearly a 6-fold improvement in run time. One function (run_gpu) generates just the flows and saves them to disk. The second function (run_cpu) calculates the masks using a modified version of the dynamics script where all references to Torch were removed. I'm running on a 20-core box with a GTX 1080 GPU, so leveraging the additional cores makes a big difference in my case. In this particular use case I was using cellpose to calculate nuclei segmentation masks on 38 images (540 x 540).

original implementation: flow+mask = 466.86s
modified implementation: flow+mask = 81.21s
(all gains were observed in the mask generation step)
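
As a quick check of the "nearly 6-fold" figure, the speed-up is the ratio of the two timings reported above, which comes out at about 5.75. The helper name `speedup` is only for illustration.

```python
def speedup(baseline_s, improved_s):
    """Return how many times faster the improved run is than the baseline."""
    return baseline_s / improved_s


original = 466.86  # seconds, flow+mask, original implementation
modified = 81.21   # seconds, flow+mask, modified implementation
print(f"speed-up: {speedup(original, modified):.2f}x")
```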

atarkowska commented 3 years ago

Hi, thanks for sharing, could you provide more details about your solution?

loomcode commented 3 years ago

Sure thing. Beware that I've only run this on 2D nuclear images. First I installed cellpose in a virtual env. Then I modified the models script in cellpose by adding this function to the CellposeModel class:

    def eval_gpu(self, imgs, batch_size=8, channels=None, normalize=True, invert=False,
             rescale=None, diameter=None, do_3D=False, anisotropy=None, net_avg=True,
             augment=False, tile=True, tile_overlap=0.1,
             resample=False, interp=True, flow_threshold=0.4, cellprob_threshold=0.0, compute_masks=True,
             min_size=15, stitch_threshold=0.0, progress=None):

        x, nolist = convert_images(imgs.copy(), channels, do_3D, normalize, invert)

        nimg = len(x)
        self.batch_size = batch_size
        if rescale is None:
            if diameter is not None:
                if not isinstance(diameter, (list, np.ndarray)):
                    diameter = diameter * np.ones(nimg)
                rescale = self.diam_mean / diameter
            else:
                rescale = np.ones(nimg)
        elif isinstance(rescale, float):
            rescale = rescale * np.ones(nimg)

        iterator = trange(nimg) if nimg > 1 else range(nimg)

        if isinstance(self.pretrained_model, list) and not net_avg:
            self.net.load_model(self.pretrained_model[0], cpu=(not self.gpu))
            if not self.torch:
                self.net.collect_params().grad_req = 'null'

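        for i in iterator:
            img = x[i].copy()
            shape = img.shape  # original shape, needed later to resize the mask
            # rescale image for flow computation
            img = transforms.resize_image(img, rsz=rescale[i])
            y, style = self._run_nets(img, net_avg=net_avg,
                                      augment=augment, tile=tile,
                                      tile_overlap=tile_overlap)

        # NOTE: only the results for the last image are returned, along with
        # rescale[0], so call this with one image at a time.
        return y, style, img, shape, rescale[0]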

which is just the part of the eval function which generates the flows. Next I created a copy of the dynamics script and removed all references to Torch or gpu usage. I named the script no_torch and saved it in the cellpose path.

Finally, I made a class to call cellpose and run the processing:

from cellpose import models, io_pose
import os
from tqdm import tqdm
import numpy as np
from multiprocessing import Pool, cpu_count
from cellpose import no_torch, utils, transforms
from cv2 import imwrite as imwcv2, INTER_NEAREST

#worker needs to be outside of class to avoid pickle error
def eval_cpu(ys, shapes, flow_thresholds, rescales):
    cellprob = ys[:, :, -1]
    dP = ys[:, :, :2].transpose((2, 0, 1))
    niter = 1 / rescales * 200
    p = no_torch.follow_flows(-1 * dP * (cellprob > 0.0) / 5., niter=niter)
    maski = no_torch.get_masks(p, iscell=(cellprob > 0.0), flows=dP, threshold=flow_thresholds)
    maski = utils.fill_holes_and_remove_small_masks(maski)
    maski = transforms.resize_image(maski, shapes[0], shapes[1], interpolation=INTER_NEAREST)
    return maski

class CellPoseRunner():

    def __init__(self, run_dir, masks_dir):
        self.run_dir = run_dir
        self.masks_dir = masks_dir
        self.rescales = []
        self.img_shapes = []
        self.img_names_list = []
        self.flows = []

    def get_flows(self, cpmodel_path, nucleus_avg_diameter=None, gpu=True):
        mxnet = False
        device, gpu = models.assign_device((not mxnet), gpu)

        # init the model
        model = models.CellposeModel(gpu=gpu, device=device, pretrained_model=cpmodel_path, torch=(not mxnet))

        # calculate flows for each image along with necessary downstream info
        image_names = io_pose.get_image_files(self.run_dir, '_masks', None)
        for image_name in tqdm(image_names):
            image = io_pose.imread(image_name)
            y, style, img, shape, rescale = model.eval_gpu(image, batch_size=8, rescale=None, diameter=nucleus_avg_diameter)
            self.rescales.append(rescale)
            self.img_shapes.append(shape)
            self.flows.append(y)
            self.img_names_list.append(image_name)

    def get_masks(self, flow_t=0.0):
        assert len(self.rescales) > 0  # require that flows were computed first

        print("Building masks from flows..")
        flow_ts = [flow_t] * len(self.flows)
        # one argument tuple per image; zip avoids building a ragged numpy array
        run_data = list(zip(self.flows, self.img_shapes, flow_ts, self.rescales))

        # the context manager shuts the pool down once all masks are computed
        with Pool(processes=max(1, cpu_count() - 1)) as p:
            res = p.starmap(eval_cpu, run_data)

        for i, d in enumerate(self.img_names_list):
            # keep the original file name, including names with several dots
            imwcv2(os.path.join(self.masks_dir, os.path.basename(d)), res[i].astype(np.uint16))

Here, eval_cpu is just a reduced copy of the part of models.eval that creates the masks, and get_flows is basically the same as the code in cellpose main used to do the primary processing. If you have trouble importing io_pose, that's because I renamed the cellpose io script so that I wouldn't run into conflicts with Python's io library. Just init the class, then call get_flows() followed by get_masks(). The class constructor parameters are the path to where the images are stored and the desired output directory, respectively. I hope this helps!

sumankhan19 commented 3 years ago

Hi, I tried to follow your instructions to do the same, but sadly I am not an advanced programmer. It ended up in a disaster and a barrage of errors.

Could I request that you kindly share your code as a whole?

loomcode commented 3 years ago

Hey sumankhan19,

This is the entirety of the code excepting the two scripts that I mentioned modifying. There's a good possibility that there's a versioning issue. If you run pip3 show cellpose from your virtual python environment you can get the cellpose version number. To be explicit, here's what I'm using:

cellpose version: 0.6.1
linux version: Linux Mint 19.3/Ubuntu 18.04

Also, even though this might work on cell segmentation I only tested it on nuclei segmentation.
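
The same version check can be done from inside Python with the standard library (Python 3.8 or newer). A small sketch; the helper name `installed_version` is made up for this example.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print("cellpose:", installed_version("cellpose"))
```

Run it with the interpreter of the virtual environment in question, otherwise it reports whatever is installed for a different Python.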

If you can post an email address, I can add you to a Slack channel where we can discuss further.

sumankhan19 commented 3 years ago

suman.khan@weizmann.ac.il

kindly add me! Thanks!

carsen-stringer commented 3 years ago

Using a parallel pool has the disadvantage that it will slow down processing on Windows if you are not restricting processing to a single thread for computation (or at least this used to be the case). Is your no_torch implementation restricted to a single core?

Also, in what case is the post-processing most of the time? Is it because you have flow_threshold>0? The torch-based implementation is GPU accelerated -- have you tried using the GPU for this?

I can see advantages for running all the images on the GPU at once -- and actually there is an option to tile across batches if you have images of the same size for 3D -- this could be extended for 2D but I'm not sure if it would be useful if your images are very big. And then for the dynamics you could run multiple images simultaneously through the grid_sample layers on the GPU (right now I'm using a batch size of 1).
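
On the single-thread point raised above: one common way to keep each pool worker on a single thread is to cap the thread counts of the numerical libraries through environment variables. A sketch, assuming NumPy is backed by OpenMP, MKL, or OpenBLAS; the helper name `limit_threads` is made up, and the variables generally need to be set before the numerical libraries are first imported to take effect.

```python
import os

THREAD_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")


def limit_threads(n=1):
    """Cap the thread count that common numerical backends will use."""
    for var in THREAD_VARS:
        os.environ[var] = str(n)


# Call this at the very top of the script, before importing numpy or
# cellpose; worker processes started later inherit the environment.
limit_threads(1)
```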

loomcode commented 3 years ago

I remember there being a performance discrepancy between Windows and Linux for Python multiprocessing, but I'm not sure if that continues to be the case or not. I don't use Windows when I have a choice, and what I've implemented here has only been run in Ubuntu 18. In this case, no_torch is run in parallel and appears to be thread safe. To further reduce post-processing time I have been running with flow_threshold=0. I have tried to use the GPU for post-processing and, as expected, it is much faster than the non-GPU version, but when compared to multiprocessing it is much slower for my case.

To illustrate why, let's say that I'm processing 30 images on a 40 core machine. When run without GPU acceleration, assuming each image takes 60 seconds for post-processing, that's 30 minutes of run time for post-processing serially. If GPU acceleration reduces the serial processing of images down to 10 seconds each, then the result is far better at 5 minutes run time. On the other hand, if post-processing is run on all 30 images concurrently without GPU acceleration, then in the optimal case the total run time is 1 minute.

Whether the GPU accelerated version or the parallel CPU version is faster depends on how many images you are processing and the particulars of your system (number of cores and GPU). In my case, where I'm consistently processing more than 10 images at a time and I happen to have a moderately capable GPU plus a 40 thread server, the parallel processing CPU case is usually a few times faster, but mileage may vary.
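
The arithmetic in the example above can be written as a small idealized model. It ignores scheduling overhead, memory limits, and the time spent on the network itself, and the function names are only for illustration.

```python
import math


def serial_time(n_images, seconds_per_image):
    """Total time when images are post-processed one after another."""
    return n_images * seconds_per_image


def parallel_time(n_images, seconds_per_image, workers):
    """Idealized total time with `workers` images processed at once."""
    return math.ceil(n_images / workers) * seconds_per_image


n = 30
print(serial_time(n, 60) / 60)        # CPU, serial: 30.0 minutes
print(serial_time(n, 10) / 60)        # GPU, serial: 5.0 minutes
print(parallel_time(n, 60, 40) / 60)  # CPU, 40 workers: 1.0 minute
```

Once the number of images exceeds the number of workers, the parallel estimate grows in steps of one per-image time, which is where the break-even against the GPU version depends on the machine.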

carsen-stringer commented 3 years ago

Are your diameters less than 30? That could explain the 466s timing. I found some speed bottlenecks, can you please try the latest version of the code from github and report your timings?

For the GPU-accelerated interpolation you need to be running the torch version. Are you running that?

My current timings with 38 images of size 2048x2048 and diameter=30:

With GPU: flow+mask computation takes 99.33 s with flow_threshold=0.4 and 51.60 s with flow_threshold=0.0.

With CPU and interp=False: flow+mask computation takes 437.30 s with flow_threshold=0.4 and 190.01 s with flow_threshold=0.0.

If this doesn't help enough I will consider splitting the steps as you've recommended.

loomcode commented 3 years ago

We're dealing with low resolution images so our nucleus diameters average about 8 with a mean cell count of ~6500 per frame. I just upgraded to a new machine with a sweet sweet RTX 3090 and I'm struggling to get it to work with Torch. Once that happens I'll try the system on the new release.

loomcode commented 3 years ago

I just tested on the new machine using the following:

interp=False
flow_threshold=0
--fast_mode
--nclasses=1
--diameter=8
--no_npy

Running on version 0.6.1 of cellpose, on a stack of 20 example images (512x512), I get a total run time of 149.6 seconds.

Running using the multiprocessing approach, run time is: ~30 seconds. Because I'm using a 64 thread machine, I expect to see more gains with larger image stacks.

carsen-stringer commented 1 month ago

We will not be using pool. I'd recommend using the GPU to get speed-ups, and we will look into running multiple dynamics flows simultaneously on the GPU.