open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0

Error in CudnnConvolutionBackward on Conv3d #980

Closed baibaidj closed 2 years ago

baibaidj commented 3 years ago

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The unexpected results still exist in the latest version.

Describe the Issue
I'm training a 3D segmentation network for organs (UPerNet 3D with Conv3d), and after a few training iterations a CUDA error ("an illegal memory access was encountered") is raised. I expect training to continue without this error.

Reproduction

  1. What command, code, or script did you run? Training on a single GPU with the following config:
    
    # data settings
    patch_size = (256, 256, 160)
    data = dict(samples_per_gpu=3, workers_per_gpu=9)

```python
# model settings
conv_cfg = dict(type='Conv3d')
norm_cfg = dict(type='BN3d', requires_grad=True)  # Sync
base_channels = 24  # bs2, chn24 16G; bs2, chn48 failed due to illegal memory access
fpn_chn = int(512 * base_channels / 96)
model = dict(
    type='EncoderDecoderMonai',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer3d',
        in_chans=1,
        embed_dim=base_channels,
        depths=[2, 2, 6, 2],  # 2,4,12,4: bs2, 21G
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='UPerNeck3D',
        in_channels=[base_channels * 2**i for i in range(4)],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=fpn_chn,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        align_corners=False),
    decode_head=dict(
        type='FCNHead3D',
        in_channels=[fpn_chn] * 4,
        in_index=(0, 1, 2, 3),
        channels=fpn_chn,
        input_transform='resize_concat',
        kernel_size=3,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=num_classes,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        align_corners=False,
        verbose=False,
        loss_decode=dict(
            type='ComboLossMed',
            loss_weight=(1.0, 0.6),
            num_classes=num_classes,
            class_weight=(0.8, 1.1, 1.0, 1.0),
            verbose=False)),
    auxiliary_head=dict(
        type='FCNHead3D',
        in_channels=fpn_chn,
        in_index=0,
        channels=fpn_chn // 2,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=num_classes,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='ComboLossMed',
            loss_weight=(1.0 * 0.4, 0.6 * 0.4),
            num_classes=num_classes,
            class_weight=(0.8, 1.1, 1.0, 1.0),
            verbose=False)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(
        mode='slide', roi_size=patch_size, sw_batch_size=2,
        blend_mode='gaussian', overlap=0.5, sigma_scale=0.125,  # 'gaussian' or 'constant'
        padding_mode='constant'))

# optimizer
# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.001,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(
    policy='poly', warmup='linear', warmup_iters=1500, warmup_ratio=1e-6,
    power=1.0, min_lr=0.0, by_epoch=False)
optimizer_config = dict(
    type='Fp16OptimizerHook', loss_scale=512., distributed=False,
    grad_clip=dict(max_norm=8, norm_type=2))
```
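Since `base_channels = 48` fails while 24 runs, a back-of-envelope check of activation element counts against cuDNN's 32-bit indexing limit can help locate the threshold. This sketch just plugs in the numbers from the config above; it ignores cuDNN workspace and gradient buffers, which add to the totals:

```python
# Rough element counts for the input-resolution feature maps, compared
# against the 2**31 - 1 limit that 32-bit indexing inside cuDNN kernels
# can overflow. Values taken from the config above.
patch_size = (256, 256, 160)
samples_per_gpu = 3
INT32_MAX = 2**31 - 1

for channels in (24, 48):  # base_channels: works vs. fails
    elements = samples_per_gpu * channels
    for dim in patch_size:
        elements *= dim
    print(f'channels={channels}: {elements:,} elements '
          f'({elements / INT32_MAX:.2f} of the int32 limit)')
```

Neither count exceeds the limit on its own, but doubling channels roughly doubles every intermediate buffer, so workspace or gradient tensors may be the ones that overflow.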


2. Did you make any modifications on the code? Did you understand what you have modified?
I've adapted the 2d versions of the networks, including backbone, neck, head, and loss, to 3d.

**Environment**
2021-04-23 18:21:20,969 - mmseg - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.8.5 (default, Sep  4 2020, 07:30:14) [GCC 7.3.0]
CUDA available: True
GPU 0: Tesla V100-PCIE-32GB
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.1.TC455_06.29069683_0
GCC: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
PyTorch: 1.7.1+cu110
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.0
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80
  - CuDNN 8.0.5
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

TorchVision: 0.8.2+cu110
OpenCV: 4.5.1
MMCV: 1.2.7
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.0
MMSegmentation: 0.11.0+
------------------------------------------------------------

**Error traceback**
If applicable, paste the error traceback here.

```none
2021-04-23 18:21:30,605 - mmseg - INFO - workflow: [('train', 1)], max: 400 epochs
2021-04-23 18:22:31,951 - mmseg - INFO - Iter [8/40400] lr: 4.667e-07, eta: 3 days, 13:56:06, time: 7.659, data_time: 2.180, memory: 17613, decode.loss_seg: 1.6604, decode.acc_seg: 29.2102, aux.loss_seg: 0.6742, aux.acc_seg: 10.6402, loss: 2.3346, grad_norm: 3.2914
[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnConvolutionBackward. Traceback of forward call that caused the error:
  File "tools/train.py", line 176, in <module>
    main()
  File "tools/train.py", line 164, in main
    train_segmentor(
  File "/home/whos/git/mmseg4med/mmseg/apis/train.py", line 120, in train_segmentor
    runner.run(data_loaders, cfg.workflow)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 67, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/whos/git/mmseg4med/mmseg/models/segmentors/base.py", line 152, in train_step
    losses = self(**data_batch)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 110, in new_func
    output = old_func(*new_args, **new_kwargs)
  File "/home/whos/git/mmseg4med/mmseg/models/segmentors/base.py", line 122, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/home/whos/git/mmseg4med/mmseg/models/segmentors/encoder_decoder_monai.py", line 102, in forward_train
    loss_decode = self._decode_head_forward_train(x, img_metas,
  File "/home/whos/git/mmseg4med/mmseg/models/segmentors/encoder_decoder.py", line 100, in _decode_head_forward_train
    loss_decode = self.decode_head.forward_train(x, img_metas,
  File "/home/whos/git/mmseg4med/mmseg/models/decode_heads/fcn_head_3d.py", line 146, in forward_train
    seg_logits = self.forward(inputs)
  File "/home/whos/git/mmseg4med/mmseg/models/decode_heads/fcn_head_3d.py", line 93, in forward
    feat_map = self.convs(x) if self.num_convs > 0 else x
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/cnn/bricks/conv_module.py", line 193, in forward
    x = self.conv(x)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/cnn/bricks/wrappers.py", line 79, in forward
    return super().forward(x)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 572, in forward
    return F.conv3d(input, self.weight, self.bias, self.stride,
 (function _print_stack)
Traceback (most recent call last):
  File "tools/train.py", line 176, in <module>
    main()
  File "tools/train.py", line 164, in main
    train_segmentor(
  File "/home/whos/git/mmseg4med/mmseg/apis/train.py", line 120, in train_segmentor
    runner.run(data_loaders, cfg.workflow)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.call_hook('after_train_iter')
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/base_runner.py", line 308, in call_hook
    getattr(hook, fn_name)(self)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/mmcv/runner/hooks/optimizer.py", line 130, in after_train_iter
    scaled_loss.backward()
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/whos/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: CUDA error: an illegal memory access was encountered

```

**Bug fix** Not yet.
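One triage step worth noting (not in the original report): CUDA kernel launches are asynchronous, so the Python line named in the traceback is often not the kernel that actually faulted. Forcing synchronous launches makes the error surface at its real call site:

```python
import os

# Must be set before the first CUDA call in the process (e.g. at the
# very top of tools/train.py), so each kernel launch blocks and a
# faulting kernel is reported where it was launched, not at a later op.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
```

Equivalently, export `CUDA_LAUNCH_BLOCKING=1` in the shell before launching training. Expect a noticeable slowdown; use it only while debugging.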

zhouzaida commented 2 years ago

related issue: https://github.com/pytorch/pytorch/issues/48945
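The linked PyTorch issue concerns illegal memory accesses in cuDNN's 3D convolution kernels on large inputs. One workaround discussed there is to bypass cuDNN and fall back to PyTorch's native convolution kernels, at a speed cost. A sketch (not an mmcv-endorsed fix; `forward_without_cudnn` is a hypothetical helper):

```python
import torch

# Global fallback: with cuDNN disabled, Conv3d uses PyTorch's native
# kernels instead of the cuDNN ones implicated in the linked issue.
torch.backends.cudnn.enabled = False

def forward_without_cudnn(module, x):
    """Narrower alternative: disable cuDNN only around the failing module,
    via the context-manager form of the same backend flags."""
    with torch.backends.cudnn.flags(enabled=False):
        return module(x)
```

The context-manager form is preferable when only one layer (here, the 3D convs) is affected, since the rest of the model keeps cuDNN's speed.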

kei971031 commented 1 year ago

RuntimeError: CUDA error: an illegal memory access was encountered

In my experience, this error when converting Conv2d to Conv3d is another form of OOM: it appears when trying to allocate a tensor larger than the GPU's total memory. Test by reducing the batch size or the number of channels in the model.

Also make sure `torch.backends.cudnn.enabled` is enabled (see: conv3d memory issue).
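The bisection suggested above can be scripted. `conv3d_fits` is a hypothetical helper (not from this thread) that probes whether a single Conv3d forward + backward of a given size runs without error, so you can shrink batch size or channels until the failure disappears:

```python
import torch

def conv3d_fits(batch, channels, spatial=(256, 256, 160), device='cuda'):
    """Return True if one Conv3d forward+backward of this size succeeds."""
    try:
        x = torch.randn(batch, channels, *spatial,
                        device=device, requires_grad=True)
        conv = torch.nn.Conv3d(channels, channels,
                               kernel_size=3, padding=1).to(device)
        conv(x).sum().backward()
        if device == 'cuda':
            torch.cuda.synchronize()  # surface asynchronous CUDA errors here
        return True
    except RuntimeError:  # OOM or illegal memory access
        return False
    finally:
        if device == 'cuda':
            torch.cuda.empty_cache()
```

For example, halve `batch` or `channels` and re-probe until the call returns True; the last failing size bounds the model you can actually train.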