Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

DynUNet crashes with DataParallel and DeepSupervision #6442

Open razorx89 opened 1 year ago

razorx89 commented 1 year ago

Describe the bug

DynUNet crashes in a torch.nn.DataParallel scenario, because a mutable list is used to collect the supervision heads:
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L212-L219
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L51

This does not work with multiple GPUs, because the tensors in the list end up on different CUDA devices. The code crashes when stacking the tensors in the list at:
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L271-L275
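
The sharing can be modeled without GPUs. The sketch below is a simplified, pure-Python stand-in, not MONAI or PyTorch code: `SkipLayer`, `Net` and `replicate` are hypothetical names, device placement is a string tag, and `replicate` assumes that replication copies each module's attribute dictionary shallowly, so a plain list attribute is shared by reference rather than copied (consistent with the identical `id(self.heads)` values reported later in this thread).

```python
import copy


class SkipLayer:
    """Toy stand-in for a skip layer that writes its supervision head into a list it was given."""

    def __init__(self, index, heads):
        self.index = index
        self.heads = heads  # reference to the parent's list, not a copy

    def forward(self, device, value):
        # Roughly mimics `self.heads[self.index - 1] = ...`, tagging the value with its device
        self.heads[self.index - 1] = (device, value)


class Net:
    def __init__(self, deep_supr_num):
        self.heads = [None] * deep_supr_num
        self.layers = [SkipLayer(i + 1, self.heads) for i in range(deep_supr_num)]


def replicate(module):
    # Hypothetical helper: shallow copies keep the `heads` attribute pointing at one list
    replica = copy.copy(module)
    replica.layers = [copy.copy(layer) for layer in module.layers]
    return replica


net = Net(deep_supr_num=2)
replica_0, replica_1 = replicate(net), replicate(net)
print(replica_0.heads is replica_1.heads)  # True: both replicas see one list

replica_0.layers[0].forward("cuda:0", "head from replica 0")
replica_1.layers[1].forward("cuda:1", "head from replica 1")
print([device for device, _ in replica_0.heads])  # ['cuda:0', 'cuda:1']
```

Because the list is not a parameter, buffer or submodule, nothing in this model gives each replica its own copy; the mixed device tags correspond to the tensors that `torch.stack` later rejects.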

To Reproduce

Run torch.nn.DataParallel(DynUNet(..., deep_supervision=True), device_ids=[0, 1])

Expected behavior

DynUNet's forward pass should be thread-safe. I know that DistributedDataParallel is superior and would avoid the problem; however, DataParallel should still work if the results are passed along as the blocks' return values instead of being stored in a "global" mutable list.
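
A minimal sketch of that suggestion, assuming nothing about MONAI's internals: each toy layer returns its output together with the heads collected below it, so no state outlives a forward call and concurrent callers cannot see each other's heads. `ToySkipLayer` and its arithmetic are hypothetical placeholders for the real down/up-sampling blocks, and the sketch ignores the TorchScript constraint discussed later in the thread.

```python
import threading


class ToySkipLayer:
    """Hypothetical skip layer that returns its head instead of storing it in shared state."""

    def __init__(self, index, next_layer=None, has_head=False):
        self.index = index
        self.next_layer = next_layer
        self.has_head = has_head

    def forward(self, x):
        down = x + 1  # toy "downsample"
        if self.next_layer is not None:
            deeper, heads = self.next_layer.forward(down)
        else:
            deeper, heads = down, []
        up = deeper * 2  # toy "upsample"
        if self.has_head:
            # Heads travel back up via the return value, ordered shallow to deep
            heads = [("head", self.index, up)] + heads
        return up, heads


# Three levels; like DynUNet, the outermost level (index 0) has no head
net = ToySkipLayer(0, ToySkipLayer(1, ToySkipLayer(2, has_head=True), has_head=True))
results = {}


def worker(name, x):
    results[name] = net.forward(x)


threads = [threading.Thread(target=worker, args=(n, x)) for n, x in [("a", 0), ("b", 10)]]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results["a"])  # (24, [('head', 1, 12), ('head', 2, 6)])
print(results["b"])  # (104, [('head', 1, 52), ('head', 2, 26)])
```

The trade-off is that every layer's return type changes to a pair, which is a larger refactor than replacing the list with another container.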

Environment

================================
Printing MONAI config...
================================
MONAI version: 1.1.0
Numpy version: 1.24.2
Pytorch version: 1.14.0a0+410ce96
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /usr/local/lib/python3.8/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.2
scikit-image version: 0.19.3
Pillow version: 9.4.0
Tensorboard version: 2.12.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.15.0a0
tqdm version: 4.64.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.4
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 20.04.5 LTS
Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.29
Processor: x86_64
Machine: x86_64
Python version: 3.8.10
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 48
Num logical CPUs: 48
Num usable CPUs: 48
CPU usage (%): [4.9, 5.4, 4.4, 4.4, 4.9, 4.4, 4.4, 4.4, 5.3, 4.9, 4.9, 4.9, 4.9, 5.3, 4.9, 5.3, 4.9, 4.4, 6.3, 4.9, 4.9, 4.4, 4.4, 5.3, 4.9, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 5.3, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 99.5]
CPU freq. (MHz): 1646
Load avg. in last 1, 5, 15 mins (%): [0.2, 1.2, 6.7]
Disk usage (%): 60.9
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.8
Available memory (GB): 991.8
Used memory (GB): 9.7

================================
Printing GPU config...
================================
Num GPUs: 2
Has CUDA: True
CUDA version: 11.8
cuDNN enabled: True
cuDNN version: 8700
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'compute_90']
GPU 0 Name: NVIDIA RTX A6000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 47.5
GPU 0 CUDA capability (maj.min): 8.6
GPU 1 Name: NVIDIA RTX A6000
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 84
GPU 1 Total memory (GB): 47.5
GPU 1 CUDA capability (maj.min): 8.6

Nic-Ma commented 1 year ago

Hi @yiheng-wang-nv,

Is DynUNet thread-safe?

Thanks.

yiheng-wang-nv commented 1 year ago

Hi @razorx89, could you provide code that reproduces the crash?

I did a simple test with the following code and did not encounter an error:

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=1,
)

net = torch.nn.DataParallel(net, device_ids=[0, 1])

razorx89 commented 1 year ago

Here you go, it crashes on the second batch:

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(16, 1, 64, 64, 64)
x = x.cuda()

for i in range(10):
    print("Batch", i)
    net(x)

Batch 0
Batch 1
Traceback (most recent call last):
  File "dynunet_bug.py", line 27, in <module>
    net(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 601, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/monai/networks/nets/dynunet.py", line 273, in forward
    return torch.stack(out_all, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_cat)

yiheng-wang-nv commented 1 year ago

Hi @razorx89, thanks, I reproduced the issue. I submitted a PR in #6444 and tested it with the code you posted. Could you please help review that PR?

razorx89 commented 1 year ago

Thanks, it does not crash anymore. However, I don't think the returned supervision heads are correct: in my experiments the model does not learn at all (I cannot share an example of this). Revisiting the line linked above, https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L51, add a print statement at that location:

print(x.device, id(self.heads), self.index)

cuda:0 140709361877056 1
cuda:1 140709361877056 1

You will see that both replicas write to the same list instance. Thus, both replicas return a tensor with the same content (give or take race conditions), regardless of each replica's input.
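
The effect can be reproduced deterministically in plain Python. In this sketch (a model, not PyTorch code), two threads play the role of replicas: each writes its own head into the single shared slot, a `threading.Barrier` forces both writes to finish before either thread reads, and then each thread reads the slot back. Only one write survives, so at least one replica returns a head computed from the other replica's input. The barrier is an assumption added to make the interleaving reproducible; in real runs the timing varies.

```python
import threading

shared_heads = [None]           # one slot, shared by both "replicas"
barrier = threading.Barrier(2)  # forces both writes to happen before any read
returned = {}


def replica(name, batch_input):
    shared_heads[0] = f"head({batch_input})"  # write this replica's head
    barrier.wait()                            # wait until the other replica has written too
    returned[name] = shared_heads[0]          # read back whatever is now in the slot


threads = [
    threading.Thread(target=replica, args=("replica 0", "x0")),
    threading.Thread(target=replica, args=("replica 1", "x1")),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Both replicas read the same value: whichever write landed last
print(returned["replica 0"] == returned["replica 1"])  # True
```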

razorx89 commented 1 year ago

Any updates on this, @yiheng-wang-nv? I would love to see this issue reopened, since it is still not working correctly.

yiheng-wang-nv commented 1 year ago

Hi @ericspod, could you please give some suggestions here? In the multi-threaded (DataParallel) case, the current way of interpolating and concatenating the outputs at different scales (https://github.com/Project-MONAI/MONAI/blob/c2a9a31beb22e1d3321016be7e68e0875cf2a8ad/monai/networks/nets/dynunet.py#L272) does not seem to work. I remember we made this kind of change in order to support TorchScript.
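
For reference, the training-mode output stacks the main output with the interpolated heads along a new dimension 1, giving shape `(B, 1 + deep_supr_num, C, *spatial)`, and `torch.stack` requires every tensor in the list to have the same shape and, as the traceback above shows, the same device. The bookkeeping can be sketched in plain Python; `stacked_shape` is a hypothetical helper that mirrors those two checks on `(device, shape)` pairs, and the head sizes used below are illustrative.

```python
def stacked_shape(out, heads):
    """Model of the shape of torch.stack([out, *interpolated_heads], dim=1).

    `out` and each head are (device, shape) pairs; shapes are (B, C, *spatial).
    """
    out_device, out_shape = out
    batch, channels, spatial = out_shape[0], out_shape[1], out_shape[2:]
    for device, shape in heads:
        # interpolate() resizes only the spatial dims, so batch and channels must already match
        resized = (shape[0], shape[1], *spatial)
        if device != out_device:
            raise RuntimeError(f"device mismatch: {device} vs {out_device}")
        if resized != out_shape:
            raise RuntimeError(f"shape mismatch: {resized} vs {out_shape}")
    return (batch, 1 + len(heads), channels, *spatial)


out = ("cuda:0", (8, 1, 64, 64, 64))
heads = [("cuda:0", (8, 1, 32, 32, 32)), ("cuda:0", (8, 1, 16, 16, 16))]
print(stacked_shape(out, heads))  # (8, 3, 1, 64, 64, 64)
```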

yiheng-wang-nv commented 1 year ago

Here is another failing example: with a batch size of 9 on 2 GPUs, torch.nn.DataParallel scatters the input into two tensors with batch sizes 5 and 4, and then the error occurs.

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(9, 1, 64, 64, 64)
x = x.cuda()

for i in range(10):
    print("Batch", i)
    net(x)
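
The split can be reproduced with a small helper, assuming DataParallel's default scatter divides the batch the way `Tensor.chunk` does along dim 0: each chunk gets `ceil(B / n_devices)` samples and the last one gets the remainder. `chunk_sizes` is a hypothetical name. With B = 16 both replicas get 8 samples, so a head left by the other replica has a compatible shape and the failure shows up as a device mismatch; with B = 9 the replicas get 5 and 4 samples, so a foreign head also has the wrong batch size.

```python
def chunk_sizes(batch_size, num_devices):
    """Sizes produced by splitting `batch_size` samples the way Tensor.chunk does."""
    per_chunk = -(-batch_size // num_devices)  # ceil division
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(per_chunk, remaining))
        remaining -= sizes[-1]
    return sizes


print(chunk_sizes(16, 2))  # [8, 8]
print(chunk_sizes(9, 2))   # [5, 4]
```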

yiheng-wang-nv commented 1 year ago

Hi @wyli @Nic-Ma @ericspod, I'm not sure this issue can be fixed soon. What do you think about documenting the limitation (of DataParallel support) in the docstrings first?

wyli commented 1 year ago

It seems that making the heads device-specific works fine (https://github.com/Project-MONAI/MONAI/pull/6484), but it's a bit hacky.
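
One way to read "device-specific heads" is a dictionary keyed by device, so each replica writes and reads only its own list. The sketch below is a hypothetical illustration of that idea in plain Python, not the code in the linked PR; `DeviceHeads` is an invented name and devices are plain strings.

```python
import threading


class DeviceHeads:
    """Hypothetical container: one list of heads per device key."""

    def __init__(self, deep_supr_num):
        self.deep_supr_num = deep_supr_num
        self.per_device = {}

    def slots(self, device):
        # Returns this device's existing list, or installs and returns a fresh one
        return self.per_device.setdefault(device, [None] * self.deep_supr_num)


heads = DeviceHeads(deep_supr_num=2)
results = {}


def replica(device, batch_input):
    slots = heads.slots(device)
    for i in range(len(slots)):
        slots[i] = f"head{i}({batch_input})"
    results[device] = list(slots)


threads = [
    threading.Thread(target=replica, args=(d, x))
    for d, x in [("cuda:0", "x0"), ("cuda:1", "x1")]
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results["cuda:0"])  # ['head0(x0)', 'head1(x0)']
print(results["cuda:1"])  # ['head0(x1)', 'head1(x1)']
```

The isolation holds only while no two concurrent calls share a device key.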

ericspod commented 1 year ago

I'm not sure device-specific heads are going to be enough if DataParallel runs the replicas in multiple threads. There is a race condition when two or more threads access and modify the self.heads dictionary. In normal Python we'd use something like threading.local to keep thread-specific data, or use locks to synchronize access to a shared object, but I don't know whether either is compatible with TorchScript. Perhaps it's best to just state that DataParallel is not supported, since DistributedDataParallel can be used instead.
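
In plain Python, the two options mentioned here look roughly as follows. This is a sketch of the general standard-library technique (`threading.local` for per-thread state, `threading.Lock` to serialize access), not a proposed MONAI change, and it says nothing about TorchScript support.

```python
import threading

# Option 1: per-thread storage. Each thread sees its own `heads` attribute.
_local = threading.local()


def forward_thread_local(batch_input, deep_supr_num=2):
    _local.heads = [None] * deep_supr_num  # fresh list for this thread's forward pass
    for i in range(deep_supr_num):
        _local.heads[i] = f"head{i}({batch_input})"
    return list(_local.heads)


# Option 2: one shared list, but the whole write-then-read sequence holds a lock.
_shared_heads = [None, None]
_lock = threading.Lock()


def forward_locked(batch_input):
    with _lock:
        for i in range(len(_shared_heads)):
            _shared_heads[i] = f"head{i}({batch_input})"
        return list(_shared_heads)


def run_concurrently(fn, inputs):
    results = [None] * len(inputs)

    def worker(idx, x):
        results[idx] = fn(x)

    threads = [threading.Thread(target=worker, args=(i, x)) for i, x in enumerate(inputs)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_concurrently(forward_thread_local, ["x0", "x1"]))
# [['head0(x0)', 'head1(x0)'], ['head0(x1)', 'head1(x1)']]
print(run_concurrently(forward_locked, ["x0", "x1"]))
# same result, but the two calls run one after the other inside the lock
```

The lock version is correct but serializes the critical section; in a real network that section would span from the first head being written to the final stack, which is most of the forward pass, so the replicas would no longer run in parallel.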

wyli commented 1 year ago

Yes, new threads are created at each forward pass (https://github.com/pytorch/pytorch/blob/0bf9722a3af5c00125ccb557a3618f11e5413236/torch/nn/parallel/parallel_apply.py#L73), so it's not easy to find a generic solution that is also compatible with TorchScript.

I guess the main use case of DataParallel is to quickly try large batch sizes when multiple GPUs are available. To properly leverage the GPUs and accelerate training, DistributedDataParallel is necessary.
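
The threading pattern referred to here can be modeled as: on every call, start one new thread per replica, run each replica on its input chunk, join, and collect outputs by position. `parallel_apply_model` is a simplified stand-in written for illustration; PyTorch's real implementation additionally sets up the CUDA device and grad/autocast state for each worker and wraps exceptions so they can be re-raised in the caller (the `output.reraise()` frame in the traceback above). Because the threads are new on every call, any per-thread state such as a `threading.local` lives only for that single forward pass.

```python
import threading


def parallel_apply_model(replicas, inputs):
    """Simplified model of DataParallel's per-call threading: one new thread per replica."""
    results = {}
    lock = threading.Lock()

    def worker(i, replica, x):
        try:
            output = replica(x)
        except Exception as exc:  # keep the error so the caller can re-raise it
            output = exc
        with lock:
            results[i] = output

    threads = [
        threading.Thread(target=worker, args=(i, m, x))
        for i, (m, x) in enumerate(zip(replicas, inputs))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    outputs = []
    for i in range(len(inputs)):
        if isinstance(results[i], Exception):
            raise results[i]  # surface a replica's failure in the calling thread
        outputs.append(results[i])
    return outputs


def double(x):
    return 2 * x


print(parallel_apply_model([double, double], [3, 5]))  # [6, 10]
```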

razorx89 commented 1 year ago

My two cents on why I am using DataParallel: I am working on a single-node, multi-GPU system. Every n-th epoch during training, I run an evaluation on the validation set, using a sliding-window inferer on full-size CT images. During training, I distribute cropped images from multiple CTs across multiple GPUs and aggregate the losses before the optimization step. During validation, I compute metrics on one CT at a time using a sliding-window inferer with a DataParallel model, so the sw_batch_size crops from the sliding-window algorithm are distributed across the GPUs. I cannot batch the CT images, since they all have a different number of slices (depth), unless I pad them, which may increase memory usage and/or runtime (e.g. when mixing whole-body with abdominal CTs). As far as I know, this pattern is not easy to implement with DistributedDataParallel.

wyli commented 1 year ago

Thanks, perhaps these changes work for your use case: https://github.com/Project-MONAI/MONAI/pull/6484

chezhia commented 2 weeks ago

Hi @razorx89, I ran into the same issue of deep supervision not working in DataParallel mode. The solution I came up with is to change these lines in the DynUNet definition (monai/networks/nets/dynunet.py), and it seems to work for my test case. Can you try this:

def forward(self, x):
    # Replaces DynUNet.forward; uses `torch` and `interpolate`, already imported in dynunet.py
    out = self.skip_layers(x)
    out = self.output_block(out)
    if self.training and self.deep_supervision:
        out_all = [out]  # 'out' is on this replica's device
        for feature_map in self.heads:
            # Interpolate each head to the spatial size of 'out' and move it to out's device
            interpolated_map = interpolate(feature_map, out.shape[2:]).to(out.device)
            out_all.append(interpolated_map)
        return torch.stack(out_all, dim=1)  # all tensors are now on out.device, so no device mismatch
    return out