razorx89 opened this issue 1 year ago
Hi @yiheng-wang-nv,
Is DynUNet thread-safe?
Thanks.
Hi @razorx89, could you provide a detailed code snippet that reproduces the crash?
I did a simple test with the following code and did not encounter the error:
```python
import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=1,
)
net = torch.nn.DataParallel(net, device_ids=[0, 1])
```
Here you go, it crashes on the second batch:
```python
import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(16, 1, 64, 64, 64)
x = x.cuda()
for i in range(10):
    print("Batch", i)
    net(x)
```
```
Batch 0
Batch 1
Traceback (most recent call last):
  File "dynunet_bug.py", line 27, in <module>
    net(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 601, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/monai/networks/nets/dynunet.py", line 273, in forward
    return torch.stack(out_all, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_cat)
```
Hi @razorx89, thanks, I reproduced the issue. I submitted a PR in #6444 and tested it with the code you posted. Could you please help review that PR?
Thanks, it does not crash anymore. However, I don't think that the returned supervision heads are correct. In my experiments it does not learn at all (I cannot provide an example of this). Revisiting the posted code line above: https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L51 Add a print statement at this location:

```python
print(x.device, id(self.heads), self.index)
```

```
cuda:0 140709361877056 1
cuda:1 140709361877056 1
```

You will see that both replicas write to the same list instance. Thus, both replicas return a tensor with the same content (plus/minus race conditions), regardless of the replica's input.
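The sharing can be illustrated without GPUs. As far as I can tell, `Module._replicate_for_data_parallel` is the internal hook `DataParallel`'s replicate step goes through; it shallow-copies the module's `__dict__`, so a plain Python list attribute ends up shared between all replicas (a minimal sketch against PyTorch internals, which may change):

```python
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [None, None]  # mutable attribute, like DynUNet's supervision list

demo = Demo()
# DataParallel's replicate step shallow-copies each module's __dict__ through
# this internal hook, so plain Python attributes end up shared:
replica = demo._replicate_for_data_parallel()
print(demo.heads is replica.heads)  # True: both replicas mutate the same list
```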
Any updates on this, @yiheng-wang-nv? I would love to see this issue reopened, since it is still not working correctly.
Hi @ericspod, could you please help give some suggestions here? In the multi-thread (`DataParallel`) case, it seems the interpolation-based way of concatenating outputs at different scales (https://github.com/Project-MONAI/MONAI/blob/c2a9a31beb22e1d3321016be7e68e0875cf2a8ad/monai/networks/nets/dynunet.py#L272) cannot work. I remember we made these changes in order to support TorchScript.
Adding another error example: when the batch size is set to 9 and 2 GPUs are used, `torch.nn.DataParallel` will scatter the input tensor into one tensor with batch size 4 and one with batch size 5, and then the error happens:
```python
import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(9, 1, 64, 64, 64)
x = x.cuda()
for i in range(10):
    print("Batch", i)
    net(x)
```
Hi @wyli @Nic-Ma @ericspod, I'm not sure this issue can be fixed soon. What do you think of mentioning the limitation (of `DataParallel` support) in the docstrings first?
It seems that making device-specific heads works fine (https://github.com/Project-MONAI/MONAI/pull/6484), but it's a bit hacky…
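For readers following the thread, here is a rough sketch of what "device-specific heads" could look like; this is a guess at the idea, not the actual diff in #6484, and `DeviceHeads`/`put` are hypothetical names:

```python
import torch

class DeviceHeads:
    """Hypothetical sketch: store supervision feature maps keyed by device, so
    DataParallel replicas (one device each) write to separate slots."""

    def __init__(self, num_heads: int):
        self.num_heads = num_heads
        self.maps: dict = {}  # str(device) -> list of feature maps

    def put(self, index: int, feature_map: torch.Tensor) -> None:
        # each replica runs on exactly one device, so two replicas never
        # write into the same per-device list
        key = str(feature_map.device)
        if key not in self.maps:
            self.maps[key] = [torch.empty(0)] * self.num_heads
        self.maps[key][index] = feature_map
```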
I'm not sure device-specific heads are going to be enough if `DataParallel` is doing things with multiple threads. It's a race condition when two or more threads are accessing and modifying the `self.heads` dictionary. In normal Python we'd use something like `threading.local` to have thread-specific data, or use locks to synchronize access to a shared object, but I don't know about TorchScript compatibility for that. Perhaps it's best to just say that `DataParallel` is not compatible, since `DistributedDataParallel` can be used in its place.
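To illustrate the `threading.local` idea in plain Python (a sketch only; `LocalHeads` is a hypothetical name, and thread-local attributes like this are presumably not TorchScript-scriptable):

```python
import threading
import torch

class LocalHeads(threading.local):
    """threading.local re-runs __init__ in every thread that touches the
    object, so each DataParallel worker thread would see its own `maps`
    list instead of clobbering a shared one."""

    def __init__(self, num_heads: int):
        self.maps = [torch.zeros(1)] * num_heads

heads = LocalHeads(num_heads=2)

def worker():
    heads.maps[0] = torch.ones(1)  # modifies this thread's copy only

t = threading.Thread(target=worker)
t.start()
t.join()
print(heads.maps[0])  # tensor([0.]): the main thread's copy is untouched
```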
Yes, at each forward pass new threads are created: https://github.com/pytorch/pytorch/blob/0bf9722a3af5c00125ccb557a3618f11e5413236/torch/nn/parallel/parallel_apply.py#L73. It's not easy to find a generic solution that is also compatible with TorchScript.
I guess the main use case of `DataParallel` is to quickly try large batch sizes when there are multiple GPUs; to properly leverage the GPUs and accelerate training, `DistributedDataParallel` is necessary.
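For completeness, a minimal single-node `DistributedDataParallel` version of the reproduction above (one process per GPU; the master port and batch size are arbitrary illustrative choices):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from monai.networks.nets import DynUNet

def worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    net = DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        kernel_size=[[3, 3, 3]] * 4,
        strides=[[1, 1, 1]] + [[2, 2, 2]] * 3,
        upsample_kernel_size=[[2, 2, 2]] * 3,
        deep_supervision=True,
        deep_supr_num=2,
    ).cuda(rank)
    # each process holds its own module instance, so self.heads is never shared
    net = DDP(net, device_ids=[rank])
    x = torch.randn(8, 1, 64, 64, 64, device=f"cuda:{rank}")
    net(x)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```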
Maybe my two cents on why I am using `DataParallel`: I am working on a single-node multi-GPU system. During training I run an evaluation on the validation set every n-th epoch, using a sliding window inferer on full-size CT images. During training I distribute cropped images from multiple CTs across multiple GPUs and aggregate the losses before the optimization step. During validation I compute metrics on a single CT at a time using a sliding window inferer with a `DataParallel` model, so `sw_batch_size` crops from the sliding window algorithm get distributed across the GPUs. I cannot batch the CT images, since they all have a different number of slices (depth), or I would have to use padding, which may increase memory usage and/or runtime (e.g. when mixing whole-body with abdomen CTs). As far as I know, this pattern is not easily implementable using `DistributedDataParallel`.
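For context, that validation pattern looks roughly like this (`roi_size`, `sw_batch_size`, and the input shape are illustrative values; `net` is the `DataParallel`-wrapped model from the earlier snippets):

```python
import torch
from monai.inferers import SlidingWindowInferer

# in eval mode DynUNet returns only the main output, so the deep
# supervision list is not exercised here
inferer = SlidingWindowInferer(roi_size=(64, 64, 64), sw_batch_size=8, overlap=0.25)
net.eval()
with torch.no_grad():
    ct = torch.randn(1, 1, 192, 192, 297).cuda()  # one full CT, arbitrary depth
    pred = inferer(ct, net)  # each batch of sw_batch_size crops is scattered across the GPUs
```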
thanks, perhaps these changes work for your use case https://github.com/Project-MONAI/MONAI/pull/6484
Hi @razorx89, I ran into the same issue of deep supervision not working in `DataParallel` mode. The solution I came up with is to change these lines in the DynUNet definition (monai/networks/nets/dynunet.py), and it seems to work for my test case. Can you try this:
```python
def forward(self, x):
    out = self.skip_layers(x)
    out = self.output_block(out)
    if self.training and self.deep_supervision:
        out_all = [out]  # 'out' should be on 'cuda:0' by default
        for feature_map in self.heads:
            # interpolate the feature map to the size of 'out' and move it to
            # the same device to ensure device consistency
            interpolated_map = interpolate(feature_map, out.shape[2:]).to(out.device)
            out_all.append(interpolated_map)
        return torch.stack(out_all, dim=1)  # this should not cause device mismatch errors
    return out
```
**Describe the bug**
DynUNet crashes in a `torch.nn.DataParallel` scenario, since a mutable list is used to collect the supervision heads:
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L212-L219
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L51
This does not work for multiple GPUs in this scenario, because we end up with tensors in the list sitting on different CUDA devices. The code crashes when stacking the tensors in the list at:
https://github.com/Project-MONAI/MONAI/blob/5f344cc4c0dc884e1a8273a9073346dc1703f85d/monai/networks/nets/dynunet.py#L271-L275

**To Reproduce**
Run `torch.nn.DataParallel(DynUNet(..., deep_supervision=True), device_ids=[0, 1])`.

**Expected behavior**
The DynUNet forward pass should be thread-safe. I know that `DistributedDataParallel` is superior and would solve the problem; however, it should still work by correctly passing results through block return values instead of using a "global" mutable list.

**Environment**