MouseLand / cellpose

a generalist algorithm for cellular segmentation with human-in-the-loop capabilities
https://www.cellpose.org/
BSD 3-Clause "New" or "Revised" License

Issue #1038: Index error bugfix using tensor dimension checks and reshaping of tensor dimensions to prevent out-of-range errors, plus additional dynamics.py performance optimizations #1049

Open derekthirstrup opened 2 weeks ago

derekthirstrup commented 2 weeks ago

Resolves the following issue: an IndexError due to a tensor dimension mismatch. It also appears to minimize the extra single-pixel seeds often encountered in masks created with Cellpose versions prior to the 3.1.0 dynamics.py bug fix.

Deprecation Warning Resolved:

By replacing .T with .mT and .permute(1, 0), the code adheres to PyTorch's updated tensor manipulation methods, eliminating the deprecation warning and preventing future errors when PyTorch removes support for .T on higher-dimensional tensors.

IndexError Resolved:

The added dimension checks and reshaping ensure that tensors have the expected number of dimensions before any operations that assume a specific shape. This prevents the IndexError caused by attempting to access non-existent dimensions (both changes are sketched after this list).

Enhanced Robustness:

Handling cases with no seeds and adding error logging make the code more robust, allowing it to gracefully manage unexpected input scenarios without crashing.

Simplified Functionality:

By returning only p from follow_flows, the code eliminates unnecessary data handling and reduces the risk of mismanaging tuple unpacking, thereby streamlining the flow of data through the functions.
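For illustration only, here is a minimal sketch of the two kinds of change described above (the tensor names are hypothetical placeholders, not the actual variables in dynamics.py):

import torch

p = torch.randn(2, 256, 256)     # placeholder for a >2D tensor such as a flow field

# deprecated on tensors with more than 2 dimensions: p.T
p_t = p.mT                       # .mT transposes only the last two dimensions
q = torch.randn(256, 2)
q_t = q.permute(1, 0)            # explicit axis order for the 2D case

# dimension check + reshape before operations that assume a specific shape
if p.ndim == 3:
    p = p.unsqueeze(0)           # add the missing leading axis instead of hitting an IndexError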

dynamics-optimized.py is a drop-in replacement for dynamics.py with the further optimizations described below.

Replaced Custom Max Pooling: Leveraged PyTorch's optimized pooling functions (F.max_pool2d, F.max_pool3d) instead of custom implementations to enhance performance and reduce memory usage.
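For reference, the built-in calls this refers to follow the stride-1, same-padding pattern below (shapes and kernel size are placeholders); the memory behavior of this approach is compared against a custom implementation in the maintainer's comment further down:

import torch
import torch.nn.functional as F

h2d = torch.randn(1, 256, 256)                                   # placeholder 2D map
hmax2d = F.max_pool2d(h2d, kernel_size=5, stride=1, padding=2)   # padding = kernel_size // 2

h3d = torch.randn(1, 32, 256, 256)                               # placeholder 3D volume
hmax3d = F.max_pool3d(h3d, kernel_size=5, stride=1, padding=2)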

Enhanced Numba Parallelization: Enabled parallel processing in Numba-accelerated functions by setting parallel=True, allowing multi-core CPU utilization for faster computations.
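As a rough sketch of what parallel=True enables (a generic example, not the actual dynamics.py kernel), iterations of a prange loop are distributed across CPU cores:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def step_points(p, dP, niter):
    # advance each point along a 2D flow field; each point is independent,
    # so prange can split the loop over i across threads
    for _ in range(niter):
        for i in prange(p.shape[0]):
            y = min(max(int(p[i, 0]), 0), dP.shape[1] - 1)
            x = min(max(int(p[i, 1]), 0), dP.shape[2] - 1)
            p[i, 0] += dP[0, y, x]
            p[i, 1] += dP[1, y, x]
    return p

p = (np.random.rand(1000, 2) * 63).astype(np.float32)
dP = (np.random.randn(2, 64, 64) * 0.1).astype(np.float32)
step_points(p, dP, 10)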

In-Place Tensor Operations: Utilized in-place operations wherever possible to minimize memory overhead and accelerate tensor manipulations.
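In PyTorch this means preferring underscore-suffixed methods and out= arguments, for example (generic illustrations, not the exact lines changed):

import torch

x = torch.randn(4, 256, 256)
y = torch.randn(4, 256, 256)

x.add_(y)                        # modifies x in place, no new tensor allocated
x.clamp_(min=-1.0, max=1.0)      # in-place clamp
torch.maximum(x, y, out=x)       # write the result into an existing tensor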

Reduced CPU-GPU Transfers: Maintained data on the GPU during processing to avoid the latency associated with frequent data transfers between CPU and GPU.
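In other words, intermediate results stay on the GPU and only the final result is moved back, roughly like this (a generic sketch):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flow = torch.randn(2, 512, 512, device=device)   # created directly on the device

mag = flow.square().sum(dim=0).sqrt()            # computed on the device
mask = mag > 0.5                                 # still on the device

result = mask.cpu().numpy()                      # single transfer at the very end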

Preallocated and Reused Tensors: Allocated large tensors outside of iterative loops and reused them to improve cache performance and reduce memory allocations.
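This is the same idea as the out= argument of the max_pool1d helper in the comment below: allocate the buffer once, then overwrite it on each iteration. A generic sketch:

import torch

h = torch.randn(1, 64, 256, 256)
out = torch.empty_like(h)              # allocated once, before the loop

for _ in range(5):
    torch.clamp(h, min=0.0, out=out)   # result written into the preallocated buffer
    h, out = out, h                    # swap buffers instead of allocating new tensors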

carsen-stringer commented 1 week ago

Thanks so much @derekthirstrup! I found that the built-in max_pool used substantially more memory; is that not what you found?

carsen-stringer commented 1 week ago

Here's a minimal test we did; we found that for some reason the built-in function makes 3 additional copies of the data:

import torch 
import time
device=torch.device("cuda")
h = torch.randn((1, 429, 960, 960), device=device)
print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9} gb")

def max_pool1d(h, kernel_size=5, axis=1, out=None):
    """ memory efficient max_pool thanks to Mark Kittisopikul 

    for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3

    """
    if out is None:
        out = h.clone()
    else:
        out.copy_(h)

    nd = h.shape[axis]    
    k0 = kernel_size // 2
    for d in range(-k0, k0+1):
        if axis==1:
            mv = out[:, max(-d,0):min(nd-d,nd)]
            hv = h[:, max(d,0):min(nd+d,nd)]
        elif axis==2:
            mv = out[:, :, max(-d,0):min(nd-d,nd)]
            hv = h[:,  :, max(d,0):min(nd+d,nd)]
        elif axis==3:
            mv = out[:, :, :, max(-d,0):min(nd-d,nd)]
            hv = h[:, :,  :, max(d,0):min(nd+d,nd)]
        torch.maximum(mv, hv, out=mv)
    return out

def max_pool_nd(h, kernel_size=5):
    """ memory efficient max_pool in 2d or 3d """
    ndim = h.ndim - 1
    hmax = max_pool1d(h, kernel_size=kernel_size, axis=1)
    hmax2 = max_pool1d(hmax, kernel_size=kernel_size, axis=2)
    if ndim==2:
        del hmax
        return hmax2
    else:
        hmax = max_pool1d(hmax2, kernel_size=kernel_size, axis=3, out=hmax)
        del hmax2 
        return hmax

kernel_size = 5

tic = time.time()
hmax_mark = max_pool_nd(h, kernel_size=kernel_size)
print(f"time {time.time()-tic:.5f} sec")
print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
del hmax_mark
torch.cuda.empty_cache()

from torch.nn.functional import max_pool3d  # import only what the test uses, avoiding shadowing the max_pool1d helper above
tic = time.time()
hmax = max_pool3d(h, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
print(f"time {time.time()-tic:.5f} sec")
print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9} gb")

output:

mem used: 1.581 gb, max mem used: 1.5814656 gb
time 0.03628 sec
mem used: 3.163 gb, max mem used: 4.744 gb
time 0.00302 sec
mem used: 3.163 gb, max mem used: 6.3258624 gb