Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Memory efficient sliding window inference #6427

Open razorx89 opened 1 year ago

razorx89 commented 1 year ago

Is your feature request related to a problem? Please describe.
Large input volumes have to be processed with a sliding window algorithm, otherwise OOMs happen quickly. Two properties can cause an OOM: the image size and the number of predicted classes. MONAI's sliding_window_inference allocates an FP32 probability aggregation buffer of size BxCxDxHxW and an FP32 weight aggregation buffer of size BxDxHxW. In most cases, we are only interested in the final prediction anyway (the class with the highest probability).
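The buffer sizes follow directly from these shapes. A back-of-the-envelope sketch in plain Python, assuming FP32 (4 bytes per element), MiB = 2^20 bytes, and ignoring model activations and input crops:

```python
MIB = 1024 ** 2


def aggregation_buffer_mib(batch, num_classes, spatial, bytes_per_elem=4):
    """Sizes in MiB of the FP32 probability (B x C x spatial) and weight (B x spatial) buffers."""
    voxels = 1
    for s in spatial:
        voxels *= s
    prob = batch * num_classes * voxels * bytes_per_elem / MIB
    weight = batch * voxels * bytes_per_elem / MIB
    return prob, weight


prob, weight = aggregation_buffer_mib(1, 100, (384, 384, 256))
# uint8 label map of the same spatial size, for comparison
label_mib = 384 * 384 * 256 * 1 / MIB
print(f"probabilities: {prob:.0f} MiB, weights: {weight:.0f} MiB, uint8 labels: {label_mib:.0f} MiB")
```

For the example below this gives 14400 MiB for the probabilities and 144 MiB for any single-channel FP32 volume of that size (14400 + 144 = 14544, the "Cur Usage" value in the first memory summary). A uint8 label map would need only 36 MiB.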

The following example shows that even without running a real model, peak memory usage is very high. This is still a conservative image size compared with, e.g., high-resolution isotropic whole-body CTs:

```python
import torch
from monai.inferers import sliding_window_inference

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
# Dummy "model": random scores for num_classes channels (the 1-channel input broadcasts)
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_window_inference(data, (128, 128, 128), 1, model)

print(torch.cuda.memory_summary())
```
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from large pool |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from large pool |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   35958 MB |   35958 MB |   35958 MB |       0 B  |
|       from large pool |   35956 MB |   35956 MB |   35956 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |   10853 MB |   12646 MB |   12646 MB |
|       from large pool |       0 B  |   10852 MB |   12638 MB |   12638 MB |
|       from small pool |       0 B  |       1 MB |       8 MB |       8 MB |
|---------------------------------------------------------------------------|
| Allocations           |       2    |      10    |     221    |     219    |
|       from large pool |       2    |       9    |     205    |     203    |
|       from small pool |       0    |       3    |      16    |      16    |
|---------------------------------------------------------------------------|
| Active allocs         |       2    |      10    |     221    |     219    |
|       from large pool |       2    |       9    |     205    |     203    |
|       from small pool |       0    |       3    |      16    |      16    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      12    |      12    |      12    |       0    |
|       from large pool |      11    |      11    |      11    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       5    |      58    |      58    |
|       from large pool |       0    |       4    |      53    |      53    |
|       from small pool |       0    |       2    |       5    |       5    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

If the image size or class count increases further (e.g. 384x384x384), OOMs occur even on powerful cards such as the A6000 with 48 GB, and the aggregation has to be moved to CPU memory. Since all computations for adding probability crops to the aggregation buffer are then performed on the CPU instead of the GPU, this is very slow.

Describe the solution you'd like
I implemented a sliding_window_inference_with_reduction method, which essentially performs two nested sliding windows. The outer loop iterates over a single spatial dimension and performs the reduction: it takes the probabilities from the inner loop, applies e.g. argmax, and stores the result in an output buffer of size BxDxHxW with an integer datatype (e.g. uint8). The inner loop performs a 2.5D sliding window inference over a slab of data. Since the outer loop can also iterate with some overlap, some of the probabilities from the inner loop are reused to initialize the buffer for the next inner loop iteration.
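The slab buffer logic can be checked on a toy problem. The following is a plain-Python sketch, not the author's code: a small 2-D "image" with the outer loop over the last axis (x), an inner sliding window over y inside each slab, constant blending weights (as with BlendMode.CONSTANT), and a made-up deterministic `score` function standing in for the model. It compares the slab-wise reduction against a full-size aggregation buffer followed by argmax:

```python
def window_starts(length, roi, step):
    """Window offsets along one axis; the last window is flush with the end."""
    starts = list(range(0, length - roi + 1, step))
    if starts[-1] != length - roi:
        starts.append(length - roi)
    return starts


def score(y, x, c, y0, x0):
    """Made-up class score that depends on the window origin, like a real model."""
    return float((5 * y + 3 * x + 7 * c + y0 + 2 * x0) % 13)


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def steps_for(roi, overlap):
    return tuple(max(1, int(r * (1 - overlap))) for r in roi)


def full_buffer(h, w, num_classes, roi, overlap):
    """Reference: aggregate into an H x W x C buffer, then argmax."""
    (ry, rx), (sy, sx) = roi, steps_for(roi, overlap)
    prob = [[[0.0] * num_classes for _ in range(w)] for _ in range(h)]
    weight = [[0.0] * w for _ in range(h)]
    for x0 in window_starts(w, rx, sx):
        for y0 in window_starts(h, ry, sy):
            for y in range(y0, y0 + ry):
                for x in range(x0, x0 + rx):
                    weight[y][x] += 1.0
                    for c in range(num_classes):
                        prob[y][x][c] += score(y, x, c, y0, x0)
    return [[argmax([p / weight[y][x] for p in prob[y][x]]) for x in range(w)]
            for y in range(h)]


def slab_buffer(h, w, num_classes, roi, overlap):
    """Outer loop over x with an H x roi_x x C slab buffer; argmax per finished column."""
    (ry, rx), (sy, sx) = roi, steps_for(roi, overlap)
    out = [[None] * w for _ in range(h)]
    prob = [[[0.0] * num_classes for _ in range(rx)] for _ in range(h)]
    weight = [[0.0] * rx for _ in range(h)]
    xs = window_starts(w, rx, sx)
    for i, x0 in enumerate(xs):
        if i > 0:
            # shift the buffer by the actual step and zero the newly exposed columns
            shift = x0 - xs[i - 1]
            for y in range(h):
                prob[y] = prob[y][shift:] + [[0.0] * num_classes for _ in range(shift)]
                weight[y] = weight[y][shift:] + [0.0] * shift
        # inner loop: sliding window along y inside the current slab
        for y0 in window_starts(h, ry, sy):
            for y in range(y0, y0 + ry):
                for lx in range(rx):
                    weight[y][lx] += 1.0
                    for c in range(num_classes):
                        prob[y][lx][c] += score(y, x0 + lx, c, y0, x0)
        # the first sx columns are final; if the end-aligned last slab overlaps
        # them, it overwrites them with fully aggregated values later
        copy = rx if i == len(xs) - 1 else sx
        for y in range(h):
            for lx in range(copy):
                out[y][x0 + lx] = argmax([p / weight[y][lx] for p in prob[y][lx]])
    return out


for h, w, roi, overlap in [(6, 10, (4, 4), 0.25), (5, 9, (5, 4), 0.25), (8, 8, (4, 8), 0.5)]:
    assert slab_buffer(h, w, 3, roi, overlap) == full_buffer(h, w, 3, roi, overlap)
print("slab-wise reduction matches full-buffer reduction")
```

Only the leading `step` columns of each slab are reduced and written out; the remaining `roi - step` columns stay in the buffer and seed the next slab.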

```python
import torch

from inferer import sliding_window_inference_with_reduction

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
# Same dummy "model" as above: random scores for num_classes channels
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_window_inference_with_reduction(data, (128, 128, 128), 1, model)

print(torch.cuda.memory_summary())
```
inferer.py

```python
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import (
    compute_importance_map,
    dense_patch_slices,
    get_valid_patch_size,
)
from monai.inferers.utils import _get_scan_interval
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    fall_back_tuple,
    look_up_option,
)
from tqdm import tqdm


def sliding_window_inference_with_reduction(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., torch.Tensor],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Optional[Union[torch.device, str]] = None,
    device: Optional[Union[torch.device, str]] = None,
    reduction_fn: Callable[..., torch.Tensor] = torch.argmax,
    reduction_dim: int = 1,
    output_dtype: torch.dtype = torch.uint8,
    progress: bool = False,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor:
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    batch_size, _, *orig_image_size = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    # Pad the input symmetrically where it is smaller than the ROI
    roi_size_safe: Tuple[int, ...] = fall_back_tuple(roi_size, orig_image_size)
    image_size = tuple(
        max(orig_image_size[i], roi_size_safe[i]) for i in range(num_spatial_dims)
    )
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size_safe[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    if max(pad_size) > 0:
        inputs = F.pad(
            inputs,
            pad=pad_size,
            mode=look_up_option(padding_mode, PytorchPadMode),
            value=cval,
        )

    patch_size = get_valid_patch_size(image_size, roi_size_safe)
    importance_map = compute_importance_map(
        patch_size=patch_size,
        mode=mode,
        sigma_scale=sigma_scale,
        device=sw_device,
    )
    importance_map = torch.clamp(
        importance_map,
        min=max(importance_map[importance_map != 0].min().item(), 1e-3),
    )
    importance_map = convert_data_type(
        importance_map, torch.Tensor, sw_device, compute_dtype
    )[0]

    # Identify outer and inner dimensions for the sliding window and aggregation
    outer_dim = 2  # TODO Find heuristic for this

    # Allocate buffers: the reduced output covers the whole image, while the
    # probability and weight buffers only span one ROI along the outer dimension
    output = torch.empty(
        tuple(x for i, x in enumerate(inputs.shape) if i != reduction_dim),
        dtype=output_dtype,
        device=device,
    )
    slab_probabilities: Optional[torch.Tensor] = None  # allocated once the class count is known
    slab_weights = torch.zeros(
        tuple(
            1  # one weight channel, broadcast over all predicted classes
            if i == 1
            else roi_size_safe[outer_dim]
            if i == outer_dim + 2  # account for batch and channel dimensions
            else inputs.shape[i]
            for i in range(inputs.ndim)
        ),
        dtype=compute_dtype,
        device=sw_device,
    )

    # Iterate over outer dimension and aggregate a full slab of reduced predictions
    outer_step_size = int(roi_size_safe[outer_dim] * (1 - overlap))
    outer_indices = list(
        range(
            0,
            inputs.shape[outer_dim + 2] - roi_size_safe[outer_dim] + 1,
            outer_step_size,
        )
    )
    if outer_indices[-1] != image_size[outer_dim] - roi_size_safe[outer_dim]:
        outer_indices.append(image_size[outer_dim] - roi_size_safe[outer_dim])

    last_outer_dim_idx = -1
    for outer_idx, outer_dim_idx in enumerate(
        tqdm(outer_indices, leave=True, position=0) if progress else outer_indices
    ):
        # Move old probabilities and weights based on the actual step size of this slab
        if outer_idx > 0:
            assert slab_probabilities is not None
            actual_step_size = outer_dim_idx - last_outer_dim_idx
            assert 0 < actual_step_size <= outer_step_size
            new_slices = tuple(
                slice(None)
                if i != outer_dim + 2  # account for batch and channel dimensions
                else slice(None, -actual_step_size)
                for i in range(slab_probabilities.ndim)
            )
            old_slices = tuple(
                slice(None)
                if i != outer_dim + 2  # account for batch and channel dimensions
                else slice(actual_step_size, None)
                for i in range(slab_probabilities.ndim)
            )
            null_slices = tuple(
                slice(None)
                if i != outer_dim + 2  # account for batch and channel dimensions
                else slice(-actual_step_size, None)
                for i in range(slab_probabilities.ndim)
            )
            # NOTE: source and destination views overlap along the outer dimension;
            # such in-place copies may not be race-free on the GPU. Cloning the
            # source is the safe, but more memory hungry, alternative.
            slab_probabilities[new_slices] = slab_probabilities[old_slices]
            slab_weights[new_slices] = slab_weights[old_slices]
            slab_probabilities[null_slices] = 0.0
            slab_weights[null_slices] = 0.0

        # Take the slab of the (already padded) input images
        slab_slices = tuple(
            slice(None)
            if i != outer_dim + 2  # account for batch and channel dimensions
            else slice(outer_dim_idx, outer_dim_idx + roi_size_safe[outer_dim])
            for i in range(inputs.ndim)
        )
        slab_input = inputs[slab_slices]

        # Compute crop locations in slab
        scan_interval = _get_scan_interval(
            slab_input.shape[2:], roi_size_safe, num_spatial_dims, overlap
        )
        slices = dense_patch_slices(slab_input.shape[2:], roi_size_safe, scan_interval)
        num_win = len(slices)
        total_slices = num_win * batch_size

        # Perform sliding window inference on slab
        slice_indices = list(range(0, total_slices, sw_batch_size))
        for slice_idx in (
            tqdm(slice_indices, leave=False, position=1) if progress else slice_indices
        ):
            # Get crops from slices
            slice_range = range(slice_idx, min(slice_idx + sw_batch_size, total_slices))
            unravel_slice = [
                [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)]
                + list(slices[idx % num_win])
                for idx in slice_range
            ]
            window_data = torch.cat(
                [
                    convert_data_type(slab_input[win_slice], torch.Tensor)[0]
                    for win_slice in unravel_slice
                ]
            ).to(sw_device)

            # Compute probabilities and aggregate
            probabilities = predictor(window_data, *args, **kwargs)
            if slab_probabilities is None:
                output_classes = probabilities.shape[1]
                slab_probabilities = torch.zeros(
                    (batch_size, output_classes)
                    + tuple(
                        image_size[i] if i != outer_dim else patch_size[outer_dim]
                        for i in range(len(image_size))
                    ),
                    dtype=compute_dtype,
                    device=sw_device,
                )
            probabilities *= importance_map.unsqueeze(0).unsqueeze(0)
            for win_idx, win_slice in enumerate(unravel_slice):
                slab_probabilities[win_slice] += probabilities[win_idx : win_idx + 1]
                slab_weights[win_slice] += importance_map

        # Apply reduction operation and move partial output to output buffer
        assert slab_probabilities is not None
        assert slab_weights is not None
        copy_size = (
            roi_size_safe[outer_dim]
            if outer_idx == len(outer_indices) - 1
            else outer_step_size
        )
        reduction_slices = tuple(
            slice(None)
            if i != outer_dim + 2  # account for batch and channel dimensions
            else slice(None, copy_size)
            for i in range(slab_probabilities.ndim)
        )
        predictions = reduction_fn(
            slab_probabilities[reduction_slices] / slab_weights[reduction_slices],
            dim=reduction_dim,
        )
        output[
            tuple(
                slice(None)
                if i != outer_dim + 1  # account for batch dimension only
                else slice(outer_dim_idx, outer_dim_idx + copy_size)
                for i in range(output.ndim)
            )
        ] = predictions
        last_outer_dim_idx = outer_dim_idx

    # Crop to the original image size, undoing the symmetric padding;
    # pad_size holds (before, after) pairs starting from the last spatial dim
    crop_slices = tuple(
        slice(
            pad_size[2 * (num_spatial_dims - 1 - d)],
            pad_size[2 * (num_spatial_dims - 1 - d)] + orig_image_size[d],
        )
        for d in range(num_spatial_dims)
    )
    output = output[(Ellipsis,) + crop_slices]

    if isinstance(inputs, MetaTensor):
        # keep the reduced dtype instead of converting back to the input dtype
        return convert_to_dst_type(output, inputs, dtype=output_dtype, device=device)[0]
    return output
```
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from large pool |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from small pool |       0 KB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from large pool |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from small pool |       0 KB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   22486 MB |   22486 MB |   22486 MB |       0 B  |
|       from large pool |   22484 MB |   22484 MB |   22484 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |   12288 KB |    5308 MB |   75658 MB |   75646 MB |
|       from large pool |   12288 KB |    5308 MB |   75654 MB |   75642 MB |
|       from small pool |       0 KB |       1 MB |       4 MB |       4 MB |
|---------------------------------------------------------------------------|
| Allocations           |       2    |      11    |     168    |     166    |
|       from large pool |       2    |      11    |     162    |     160    |
|       from small pool |       0    |       3    |       6    |       6    |
|---------------------------------------------------------------------------|
| Active allocs         |       2    |      11    |     168    |     166    |
|       from large pool |       2    |      11    |     162    |     160    |
|       from small pool |       0    |       3    |       6    |       6    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      11    |      11    |      11    |       0    |
|       from large pool |      10    |      10    |      10    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       1    |       6    |      49    |      48    |
|       from large pool |       1    |       6    |      46    |      45    |
|       from small pool |       0    |       2    |       3    |       3    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

In this example, peak memory usage is roughly halved (33504 MB vs. 16008 MB). Furthermore, the outer loop currently runs over the third spatial dimension, so increasing the size of this dimension leaves peak memory usage constant. For example, a 384x384x384 input causes an OOM on an A6000 with 48 GB using sliding_window_inference, but peaks at only 16 GB with the optimized sliding window algorithm.
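The peak stays flat because the slab buffer spans only one ROI along the outer dimension. A rough sketch of the FP32 probability buffer sizes, assuming 100 classes, a 384x384 in-plane size and an ROI depth of 128 (activations, crops and the uint8 output are ignored):

```python
MIB = 1024 ** 2
BYTES_FP32 = 4


def full_prob_mib(num_classes, d, h, w):
    """Full (B=1) probability buffer: C x D x H x W."""
    return num_classes * d * h * w * BYTES_FP32 / MIB


def slab_prob_mib(num_classes, d, h, roi_w):
    """Slab buffer: only one ROI deep along the outer (third spatial) axis."""
    return num_classes * d * h * roi_w * BYTES_FP32 / MIB


for w in (256, 384, 512):
    print(f"W={w}: full {full_prob_mib(100, 384, 384, w):.0f} MiB, "
          f"slab {slab_prob_mib(100, 384, 384, 128):.0f} MiB")
```

The full buffer grows linearly with the third dimension, while the slab buffer does not depend on it at all.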

Current limitations before being able to integrate it into MONAI:

Describe alternatives you've considered
Running sliding_window_inference with device='cpu', but this is of course significantly slower, since all computations on the aggregation buffer are performed on the CPU instead of the GPU.

Another approach could be to apply the same two-step sliding window approach, but perform all the aggregation on the GPU and move the fully aggregated probabilities of each slab to the huge aggregation buffer in CPU memory. Then only one large copy operation per slab remains, instead of CPU-based computations for adding and normalizing probabilities.

Discussion
What do you think about this approach, and would you like to integrate it into MONAI? In my opinion, especially for inference or full-image validation during training, where we only need the class with the highest probability, it would be a much more efficient implementation. The remaining memory could then be used for complex model computations instead of just data storage. It could perhaps be enhanced further to support multiple reductions, e.g. argmax plus a floating-point uncertainty measure.

wyli commented 1 year ago

thanks for the insights, we recently added a similar idea of buffering with the buffer_dim and buffer_steps parameters: https://github.com/Project-MONAI/MONAI/blob/9c9777751ab4f96e059a6597b9aa7ac6e7ca3b92/monai/inferers/utils.py#L121-L127

https://github.com/Project-MONAI/MONAI/discussions/6157#discussioncomment-5491346

It's available in MONAI 1.2.0rc5. We haven't explored the reduction_fn idea yet. cc @myron

myron commented 1 year ago

We also added SlidingWindowInfererAdapt class to automatically manage memory without OOM, which you can use as a replacement for SlidingWindowInferer.

razorx89 commented 1 year ago

That is good to know, thanks. Still, a more memory-efficient algorithm that returns only the predicted class index would help increase inference speed. It seems like SlidingWindowInfererAdapt just tries different settings until it no longer gets an OOM. That further increases inference times, especially if the function is only called once (e.g. a predict.py script for a single image).

myron commented 1 year ago

@razorx89 thank you for a great example/issue and the code with evaluation. SlidingWindowInfererAdapt() simplifies memory management by attempting to run optimally within the GPU budget, and it uses try/except to do so.
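The try/except idea can be sketched generically: try configurations in order and fall back to the next one when an out-of-memory error is raised. This is not the SlidingWindowInfererAdapt implementation. `FakeOOM` stands in for a CUDA out-of-memory error (torch.cuda.OutOfMemoryError in recent PyTorch versions) so the sketch runs without a GPU, and `fake_inference` and its keyword names are hypothetical:

```python
class FakeOOM(Exception):
    """Stand-in for a CUDA out-of-memory error so this sketch runs without a GPU."""


def run_with_fallback(run, configs, oom_errors=(FakeOOM,)):
    """Try each configuration in order and return the first result that does not OOM.

    Also returns the list of attempts, to make the cost of failed tries visible.
    """
    attempts = []
    for cfg in configs:
        try:
            result = run(**cfg)
        except oom_errors:
            attempts.append((cfg, "oom"))
            continue
        attempts.append((cfg, "ok"))
        return result, attempts
    raise RuntimeError("every configuration ran out of memory")


# Hypothetical inference function: only the buffered configuration fits in memory
def fake_inference(device, buffer_steps=None):
    if device == "cuda" and buffer_steps is None:
        raise FakeOOM("full-size aggregation buffer does not fit")
    return f"done on {device} with buffer_steps={buffer_steps}"


result, attempts = run_with_fallback(
    fake_inference,
    [{"device": "cuda"}, {"device": "cuda", "buffer_steps": 1}, {"device": "cpu"}],
)
print(result)
print(attempts)
```

Every failed attempt does some work before the fallback runs, which is the overhead razorx89 points out for one-off predictions.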

You can, however, run SlidingWindowInferer() directly in "buffered" mode, which is somewhat similar to your suggestion (if you know ahead of time that GPU memory is limited).

(on dev branch of monai)

```python
import torch

from monai.inferers import SlidingWindowInferer
sliding_inferer = SlidingWindowInferer(roi_size=[128, 128, 128], overlap=0.25, buffer_steps=1, device='cpu')

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_inferer(data, model)

print(torch.cuda.memory_summary())
```
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from large pool | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from small pool |      0 KiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from large pool | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from small pool |      0 KiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   9766 MiB |   9766 MiB |   9766 MiB |      0 B   |
|       from large pool |   9764 MiB |   9764 MiB |   9764 MiB |      0 B   |
|       from small pool |      2 MiB |      2 MiB |      2 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory |      0 B   |  14335 KiB |  14337 KiB |  14337 KiB |
|       from large pool |      0 B   |  12288 KiB |  12288 KiB |  12288 KiB |
|       from small pool |      0 B   |   2047 KiB |   2049 KiB |   2049 KiB |
|---------------------------------------------------------------------------|
| Allocations           |       1    |       6    |     152    |     151    |
|       from large pool |       1    |       6    |     149    |     148    |
|       from small pool |       0    |       3    |       3    |       3    |
|---------------------------------------------------------------------------|
| Active allocs         |       1    |       6    |     152    |     151    |
|       from large pool |       1    |       6    |     149    |     148    |
|       from small pool |       0    |       3    |       3    |       3    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       7    |       7    |       7    |       0    |
|       from large pool |       6    |       6    |       6    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       2    |       2    |       2    |
|       from large pool |       0    |       1    |       1    |       1    |
|       from small pool |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

This runs inference and intermediate stitching on the GPU and has a lower peak memory than your example (you can set buffer_steps=2 to reach a similar peak memory). In this "buffered" mode, the peak memory shouldn't increase even if your input image is larger along the z-axis (as you mentioned). PS: for some reason I couldn't run your code (some errors), so if you can, please time it against the example above (with buffer_steps=1 and buffer_steps=2) so we can see whether the runtime is much different.
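A minimal timing harness for such a comparison, assuming wall-clock averages over a fixed number of runs; `time_it` is a hypothetical helper, not a MONAI function. Because CUDA kernels are launched asynchronously, `sync` should be torch.cuda.synchronize when timing GPU code:

```python
import statistics
import time


def time_it(fn, repeats=10, warmup=1, sync=None):
    """Return the mean wall-clock time of fn() in seconds over `repeats` runs.

    Pass sync=torch.cuda.synchronize for GPU code: without it the timer can stop
    before the queued CUDA work has finished.
    """
    for _ in range(warmup):
        fn()  # warm-up, e.g. CUDA context creation and allocator caching
    durations = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations)


# Usage with the example above (needs a GPU and MONAI):
#   t = time_it(lambda: sliding_inferer(data, model), sync=torch.cuda.synchronize)
```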

Notice that here we get the probability (float) output, so we can do a) ensembling and b) resampling back to the original resolution (if we trained at a resampled resolution). If only the argmax result is available, we (mostly) lose these abilities. So your approach seems to be focused on a specific use case: single-model inference at a fixed resolution. If it's much faster than SlidingWindowInferer, we can consider a PR internally. You're also always welcome to submit a PR.

razorx89 commented 1 year ago

for some reason I couldn't run your code (some errors), so if you can please "time it" vs example above (with buffer_steps=1 and buffer_steps=2), we can see if the runtime is much different.

Yeah, there were some changes to the utility functions, and overlap is now handled as a tuple (one value per spatial dimension).
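One way to adapt the v1.1-based implementation to a per-dimension overlap is a small normalization step. `normalize_overlap` and `outer_step_size` are hypothetical helpers, not MONAI functions; the outer slab loop would then use only the overlap of `outer_dim`:

```python
def normalize_overlap(overlap, num_spatial_dims):
    """Return one overlap value per spatial dimension, each in [0, 1)."""
    if isinstance(overlap, (int, float)):
        values = (float(overlap),) * num_spatial_dims
    else:
        values = tuple(float(o) for o in overlap)
        if len(values) != num_spatial_dims:
            raise ValueError(f"expected {num_spatial_dims} overlap values, got {len(values)}")
    if any(not 0 <= o < 1 for o in values):
        raise ValueError("each overlap value must be >= 0 and < 1.")
    return values


def outer_step_size(roi_size, overlap, outer_dim):
    """Step of the outer (slab) loop, using only the overlap of the outer dimension."""
    return int(roi_size[outer_dim] * (1 - overlap[outer_dim]))


print(outer_step_size((128, 128, 128), normalize_overlap(0.25, 3), 2))
print(outer_step_size((128, 128, 128), normalize_overlap((0.25, 0.25, 0.5), 3), 2))
```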

Here are the timings for the above example (average over 10 iterations, version 1.2.dev2318). There seem to be some additional improvements in sliding_window_inference between v1.1 and dev; my implementation is based on the v1.1 version.

v1.1 - default:     9.420s
dev - default:      7.531s
dev - buffer=1:    10.071s
dev - buffer=2:     9.984s
dev - output=cpu: 203.805s
dev - reduction:    9.111s

However, regarding memory consumption, in my experiments the buffered implementation has a higher peak memory usage:

dev - default:   16.695 GiB
dev - buffer=1:  23.727 GiB
dev - buffer=2:  27.438 GiB
dev - reduction: 15.633 GiB

Edit: I forgot device="cpu" for the buffered mode, which is why its peak memory usage was higher. Below are also the results of my implementation with device="cpu".

timings:

dev - buffer=1 cpu:  120.505s
dev - buffer=2 cpu:  109.838s
dev - reduction cpu:  11.553s

peak memory usage:

dev - buffer=1 cpu:   9.523 GiB
dev - buffer=2 cpu:  14.797 GiB
dev - reduction cpu: 15.598 GiB

notice here we get the probability (float) output, and we can do a) ensembling b) resampling to invert to the original resolution (if we trained at resampled resolution).

a) can still be done by wrapping all models into one model, assuming that the crop size is the same for all models in the ensemble.
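A sketch of a): a hypothetical wrapper that averages the outputs of several predictors, so the averaged scores, rather than per-model argmax results, go into the reduction. It only uses `+` and `/`, so it works on tensors as well as on the plain floats used in the test. It assumes every model accepts the same crop size, and whether to average logits or softmax probabilities is left to the wrapped predictors:

```python
def make_ensemble_predictor(predictors):
    """Average the outputs of several predictors that take the same input crop."""
    if not predictors:
        raise ValueError("need at least one predictor")

    def ensemble(x, *args, **kwargs):
        total = None
        for predict in predictors:
            out = predict(x, *args, **kwargs)
            total = out if total is None else total + out  # not in-place
        return total / len(predictors)

    return ensemble


# e.g. sliding_window_inference_with_reduction(
#          data, (128, 128, 128), 1, make_ensemble_predictor([model_a, model_b]))
```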

b) True, but in most MONAI tutorials this is not the case. In most examples, the postprocessing pipeline applies an AsDiscrete transform and resamples afterwards, which is basically the same scenario.

https://github.com/Project-MONAI/tutorials/blob/c014b03c0425eddbf2beed5490dc246543ddd2b4/modules/dynunet_pipeline/inferrer.py#L75 https://github.com/Project-MONAI/tutorials/blob/c014b03c0425eddbf2beed5490dc246543ddd2b4/modules/dynunet_pipeline/inferrer.py#L130-L142

These examples only handle resampling in preprocessing and also apply only torch.argmax at the end without postprocessing transforms: https://github.com/Project-MONAI/tutorials/tree/main/3d_segmentation

Also, resizing huge probability maps (more than 100 classes) back to the original resolution may take very long or even cause OOMs, so resampling artifacts may be acceptable.
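Point b) relies on label maps being resampled with nearest-neighbour interpolation. In that case, taking the argmax before or after resampling yields the same labels, because each output sample simply copies one input sample. A 1-D plain-Python sketch (`nearest_resample` is illustrative, not a MONAI transform):

```python
def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def nearest_resample(seq, new_len):
    """1-D nearest-neighbour resampling: each output sample copies one input sample."""
    n = len(seq)
    return [seq[min(n - 1, int((i + 0.5) * n / new_len))] for i in range(new_len)]


# per-position class probabilities (3 classes) along one axis
probs = [
    [0.7, 0.2, 0.1],
    [0.3, 0.4, 0.3],
    [0.1, 0.1, 0.8],
    [0.5, 0.4, 0.1],
    [0.2, 0.6, 0.2],
]

for new_len in (3, 7, 12):
    labels_then_resample = nearest_resample([argmax(p) for p in probs], new_len)
    resample_then_labels = [argmax(p) for p in nearest_resample(probs, new_len)]
    assert labels_then_resample == resample_then_labels
    print(new_len, labels_then_resample)
```

With linear or trilinear interpolation of the probabilities, this equality does not hold in general, so the order only matters when probabilities are interpolated before the argmax.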

myron commented 1 year ago

Thank you for the response. I can see it's useful for your case and some other specific cases. I will let others comment.