AIR-THU / DAIR-V2X

Apache License 2.0

Error occurred when reimplementing MVXNet on the sv3d-inf dataset #9

Closed — shb9793 closed this issue 2 years ago

shb9793 commented 2 years ago

When I ran the following command in the mmdetection3d environment:

python tools/train.py /HOME/scz3687/run/openmmlab-0.17.1/dair-v2x/configs/sv3d-inf/mvxnet/trainval_config.py --work-dir /HOME/scz3687/run/openmmlab-0.17.1/dair-v2x/work_dirs/sv3d_inf_mvxnet

The following error was encountered:

2022-07-27 10:18:51,994 - mmdet - INFO - workflow: [('train', 1)], max: 40 epochs
/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
/data/run01/scz3687/openmmlab-0.17.1/mmdetection3d/mmdet3d/models/fusion_layers/coord_transform.py:34: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  if 'pcd_rotation' in img_meta else torch.eye(
2022-07-27 10:19:13,395 - mmdet - INFO - Epoch [1][50/10084]    lr: 4.323e-04, eta: 1 day, 18:54:39, time: 0.383, data_time: 0.073, memory: 3474, loss_cls: 0.9940, loss_bbox: 3.7010, loss_dir: 0.1496, loss: 4.8445, grad_norm: 202.8270
2022-07-27 10:19:25,671 - mmdet - INFO - Epoch [1][100/10084]   lr: 5.673e-04, eta: 1 day, 11:12:17, time: 0.246, data_time: 0.002, memory: 3474, loss_cls: 0.8204, loss_bbox: 1.6195, loss_dir: 0.1448, loss: 2.5846, grad_norm: 31.3928
2022-07-27 10:19:37,977 - mmdet - INFO - Epoch [1][150/10084]   lr: 7.023e-04, eta: 1 day, 8:39:17, time: 0.246, data_time: 0.002, memory: 3474, loss_cls: 0.8210, loss_bbox: 1.4106, loss_dir: 0.1391, loss: 2.3707, grad_norm: 25.7827
2022-07-27 10:19:50,285 - mmdet - INFO - Epoch [1][200/10084]   lr: 8.373e-04, eta: 1 day, 7:22:47, time: 0.246, data_time: 0.003, memory: 3482, loss_cls: 0.8008, loss_bbox: 1.4836, loss_dir: 0.1374, loss: 2.4218, grad_norm: 24.7188
2022-07-27 10:20:02,640 - mmdet - INFO - Epoch [1][250/10084]   lr: 9.723e-04, eta: 1 day, 6:38:03, time: 0.247, data_time: 0.003, memory: 3482, loss_cls: 0.7371, loss_bbox: 1.5554, loss_dir: 0.1419, loss: 2.4344, grad_norm: 21.9774
2022-07-27 10:20:14,775 - mmdet - INFO - Epoch [1][300/10084]   lr: 1.107e-03, eta: 1 day, 6:03:15, time: 0.243, data_time: 0.002, memory: 3482, loss_cls: 0.6303, loss_bbox: 1.4387, loss_dir: 0.1448, loss: 2.2138, grad_norm: 21.0055
/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py:31: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
  return clip_grad.clip_grad_norm_(params, **self.grad_clip)
2022-07-27 10:20:27,419 - mmdet - INFO - Epoch [1][350/10084]   lr: 1.242e-03, eta: 1 day, 5:48:06, time: 0.253, data_time: 0.003, memory: 3503, loss_cls: nan, loss_bbox: nan, loss_dir: nan, loss: nan, grad_norm: nan
Traceback (most recent call last):
  File "tools/train.py", line 225, in <module>
    main()
  File "tools/train.py", line 221, in main
    meta=meta)
  File "/data/run01/scz3687/openmmlab-0.17.1/mmdetection3d/mmdet3d/apis/train.py", line 35, in train_model
    meta=meta)
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmdet/apis/train.py", line 170, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.call_hook('after_train_iter')
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py", line 35, in after_train_iter
    runner.outputs['loss'].backward()
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
  File "/HOME/scz3687/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/autograd/function.py", line 87, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
  File "/data/run01/scz3687/openmmlab-0.17.1/mmdetection3d/mmdet3d/ops/voxel/scatter_points.py", line 46, in backward
    voxel_points_count, ctx.reduce_type)
RuntimeError: CUDA error: an illegal memory access was encountered
shb9793 commented 2 years ago

I have tried changing lr=0.003 (the default in the config) to lr=0.000375 (scaling from 8 GPUs down to 1 GPU), but it failed in the same way as before.
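The scaling applied above follows the common linear learning-rate rule (a general heuristic for changing GPU count, not something prescribed by this repo). A minimal sketch of the arithmetic, assuming samples_per_gpu stays the same:

```python
def scale_lr(base_lr, base_num_gpus, num_gpus):
    """Linear LR scaling heuristic: scale the base learning rate by the
    ratio of GPU counts, i.e. by the ratio of effective batch sizes
    (assuming samples_per_gpu is unchanged)."""
    return base_lr * num_gpus / base_num_gpus

# The config's default lr of 0.003 (tuned for 8 GPUs) scaled down to 1 GPU:
print(scale_lr(0.003, base_num_gpus=8, num_gpus=1))  # -> 0.000375
```

Note that this rule only adjusts for batch size; it does not address NaN losses, which usually point to a data, checkpoint, or numerical issue rather than a mis-scaled learning rate.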

I noticed that the output log contained the following information:

2022-07-27 10:18:51,419 - mmdet - INFO - load checkpoint from https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth
2022-07-27 10:18:51,420 - mmdet - INFO - Use load_from_http loader
2022-07-27 10:18:51,989 - mmdet - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: img_rpn_head.rpn_conv.weight, img_rpn_head.rpn_conv.bias, img_rpn_head.rpn_cls.weight, img_rpn_head.rpn_cls.bias, img_rpn_head.rpn_reg.weight, img_rpn_head.rpn_reg.bias, img_bbox_head.fc_cls.weight, img_bbox_head.fc_cls.bias, img_bbox_head.fc_reg.weight, img_bbox_head.fc_reg.bias, img_bbox_head.shared_fcs.0.weight, img_bbox_head.shared_fcs.0.bias, img_bbox_head.shared_fcs.1.weight, img_bbox_head.shared_fcs.1.bias

missing keys in source state_dict: pts_voxel_encoder.vfe_layers.0.0.weight, pts_voxel_encoder.vfe_layers.0.1.weight, pts_voxel_encoder.vfe_layers.0.1.bias, pts_voxel_encoder.vfe_layers.0.1.running_mean, pts_voxel_encoder.vfe_layers.0.1.running_var, pts_voxel_encoder.vfe_layers.1.0.weight, pts_voxel_encoder.vfe_layers.1.1.weight, pts_voxel_encoder.vfe_layers.1.1.bias, pts_voxel_encoder.vfe_layers.1.1.running_mean, pts_voxel_encoder.vfe_layers.1.1.running_var, pts_voxel_encoder.fusion_layer.lateral_convs.0.conv.weight, pts_voxel_encoder.fusion_layer.lateral_convs.0.conv.bias, pts_voxel_encoder.fusion_layer.lateral_convs.1.conv.weight, pts_voxel_encoder.fusion_layer.lateral_convs.1.conv.bias, pts_voxel_encoder.fusion_layer.lateral_convs.2.conv.weight, pts_voxel_encoder.fusion_layer.lateral_convs.2.conv.bias, pts_voxel_encoder.fusion_layer.lateral_convs.3.conv.weight, pts_voxel_encoder.fusion_layer.lateral_convs.3.conv.bias, pts_voxel_encoder.fusion_layer.lateral_convs.4.conv.weight, pts_voxel_encoder.fusion_layer.lateral_convs.4.conv.bias, pts_voxel_encoder.fusion_layer.img_transform.0.weight, pts_voxel_encoder.fusion_layer.img_transform.0.bias, pts_voxel_encoder.fusion_layer.img_transform.1.weight, pts_voxel_encoder.fusion_layer.img_transform.1.bias, pts_voxel_encoder.fusion_layer.img_transform.1.running_mean, pts_voxel_encoder.fusion_layer.img_transform.1.running_var, pts_voxel_encoder.fusion_layer.pts_transform.0.weight, pts_voxel_encoder.fusion_layer.pts_transform.0.bias, pts_voxel_encoder.fusion_layer.pts_transform.1.weight, pts_voxel_encoder.fusion_layer.pts_transform.1.bias, pts_voxel_encoder.fusion_layer.pts_transform.1.running_mean, pts_voxel_encoder.fusion_layer.pts_transform.1.running_var, pts_middle_encoder.conv_input.0.weight, pts_middle_encoder.conv_input.1.weight, pts_middle_encoder.conv_input.1.bias, pts_middle_encoder.conv_input.1.running_mean, pts_middle_encoder.conv_input.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer1.0.0.weight, 
pts_middle_encoder.encoder_layers.encoder_layer1.0.1.weight, pts_middle_encoder.encoder_layers.encoder_layer1.0.1.bias, pts_middle_encoder.encoder_layers.encoder_layer1.0.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer1.0.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer2.0.0.weight, pts_middle_encoder.encoder_layers.encoder_layer2.0.1.weight, pts_middle_encoder.encoder_layers.encoder_layer2.0.1.bias, pts_middle_encoder.encoder_layers.encoder_layer2.0.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer2.0.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer2.1.0.weight, pts_middle_encoder.encoder_layers.encoder_layer2.1.1.weight, pts_middle_encoder.encoder_layers.encoder_layer2.1.1.bias, pts_middle_encoder.encoder_layers.encoder_layer2.1.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer2.1.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer2.2.0.weight, pts_middle_encoder.encoder_layers.encoder_layer2.2.1.weight, pts_middle_encoder.encoder_layers.encoder_layer2.2.1.bias, pts_middle_encoder.encoder_layers.encoder_layer2.2.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer2.2.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer3.0.0.weight, pts_middle_encoder.encoder_layers.encoder_layer3.0.1.weight, pts_middle_encoder.encoder_layers.encoder_layer3.0.1.bias, pts_middle_encoder.encoder_layers.encoder_layer3.0.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer3.0.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer3.1.0.weight, pts_middle_encoder.encoder_layers.encoder_layer3.1.1.weight, pts_middle_encoder.encoder_layers.encoder_layer3.1.1.bias, pts_middle_encoder.encoder_layers.encoder_layer3.1.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer3.1.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer3.2.0.weight, pts_middle_encoder.encoder_layers.encoder_layer3.2.1.weight, 
pts_middle_encoder.encoder_layers.encoder_layer3.2.1.bias, pts_middle_encoder.encoder_layers.encoder_layer3.2.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer3.2.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer4.0.0.weight, pts_middle_encoder.encoder_layers.encoder_layer4.0.1.weight, pts_middle_encoder.encoder_layers.encoder_layer4.0.1.bias, pts_middle_encoder.encoder_layers.encoder_layer4.0.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer4.0.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer4.1.0.weight, pts_middle_encoder.encoder_layers.encoder_layer4.1.1.weight, pts_middle_encoder.encoder_layers.encoder_layer4.1.1.bias, pts_middle_encoder.encoder_layers.encoder_layer4.1.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer4.1.1.running_var, pts_middle_encoder.encoder_layers.encoder_layer4.2.0.weight, pts_middle_encoder.encoder_layers.encoder_layer4.2.1.weight, pts_middle_encoder.encoder_layers.encoder_layer4.2.1.bias, pts_middle_encoder.encoder_layers.encoder_layer4.2.1.running_mean, pts_middle_encoder.encoder_layers.encoder_layer4.2.1.running_var, pts_middle_encoder.conv_out.0.weight, pts_middle_encoder.conv_out.1.weight, pts_middle_encoder.conv_out.1.bias, pts_middle_encoder.conv_out.1.running_mean, pts_middle_encoder.conv_out.1.running_var, pts_backbone.blocks.0.0.weight, pts_backbone.blocks.0.1.weight, pts_backbone.blocks.0.1.bias, pts_backbone.blocks.0.1.running_mean, pts_backbone.blocks.0.1.running_var, pts_backbone.blocks.0.3.weight, pts_backbone.blocks.0.4.weight, pts_backbone.blocks.0.4.bias, pts_backbone.blocks.0.4.running_mean, pts_backbone.blocks.0.4.running_var, pts_backbone.blocks.0.6.weight, pts_backbone.blocks.0.7.weight, pts_backbone.blocks.0.7.bias, pts_backbone.blocks.0.7.running_mean, pts_backbone.blocks.0.7.running_var, pts_backbone.blocks.0.9.weight, pts_backbone.blocks.0.10.weight, pts_backbone.blocks.0.10.bias, pts_backbone.blocks.0.10.running_mean, 
pts_backbone.blocks.0.10.running_var, pts_backbone.blocks.0.12.weight, pts_backbone.blocks.0.13.weight, pts_backbone.blocks.0.13.bias, pts_backbone.blocks.0.13.running_mean, pts_backbone.blocks.0.13.running_var, pts_backbone.blocks.0.15.weight, pts_backbone.blocks.0.16.weight, pts_backbone.blocks.0.16.bias, pts_backbone.blocks.0.16.running_mean, pts_backbone.blocks.0.16.running_var, pts_backbone.blocks.1.0.weight, pts_backbone.blocks.1.1.weight, pts_backbone.blocks.1.1.bias, pts_backbone.blocks.1.1.running_mean, pts_backbone.blocks.1.1.running_var, pts_backbone.blocks.1.3.weight, pts_backbone.blocks.1.4.weight, pts_backbone.blocks.1.4.bias, pts_backbone.blocks.1.4.running_mean, pts_backbone.blocks.1.4.running_var, pts_backbone.blocks.1.6.weight, pts_backbone.blocks.1.7.weight, pts_backbone.blocks.1.7.bias, pts_backbone.blocks.1.7.running_mean, pts_backbone.blocks.1.7.running_var, pts_backbone.blocks.1.9.weight, pts_backbone.blocks.1.10.weight, pts_backbone.blocks.1.10.bias, pts_backbone.blocks.1.10.running_mean, pts_backbone.blocks.1.10.running_var, pts_backbone.blocks.1.12.weight, pts_backbone.blocks.1.13.weight, pts_backbone.blocks.1.13.bias, pts_backbone.blocks.1.13.running_mean, pts_backbone.blocks.1.13.running_var, pts_backbone.blocks.1.15.weight, pts_backbone.blocks.1.16.weight, pts_backbone.blocks.1.16.bias, pts_backbone.blocks.1.16.running_mean, pts_backbone.blocks.1.16.running_var, pts_neck.deblocks.0.0.weight, pts_neck.deblocks.0.1.weight, pts_neck.deblocks.0.1.bias, pts_neck.deblocks.0.1.running_mean, pts_neck.deblocks.0.1.running_var, pts_neck.deblocks.1.0.weight, pts_neck.deblocks.1.1.weight, pts_neck.deblocks.1.1.bias, pts_neck.deblocks.1.1.running_mean, pts_neck.deblocks.1.1.running_var, pts_bbox_head.conv_cls.weight, pts_bbox_head.conv_cls.bias, pts_bbox_head.conv_reg.weight, pts_bbox_head.conv_reg.bias, pts_bbox_head.conv_dir_cls.weight, pts_bbox_head.conv_dir_cls.bias
shb9793 commented 2 years ago

Perhaps the original MVXNet checkpoint pretrained on KITTI is not suitable for the DAIR-V2X-I dataset.

So I commented out the following line in DAIR-V2X/configs/sv3d-inf/mvxnet/trainval_config.py: https://github.com/AIR-THU/DAIR-V2X/blob/129fad7030c3ec95a31eee03919e2c6ce3b47fdc/configs/sv3d-inf/mvxnet/trainval_config.py#L28

And added load_from = None below that line.

With this change, the training process runs smoothly.
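For reference, the workaround described above amounts to a small edit in the config file (a sketch; the exact line number may differ between versions of the repo):

```python
# configs/sv3d-inf/mvxnet/trainval_config.py (fragment)

# Original: initialize from the KITTI-pretrained MVXNet checkpoint, whose
# state dict does not match this model (hence the warning in the log above):
# load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth'

# Workaround: skip the mismatched checkpoint and train from random init.
load_from = None
```

In mmdetection/mmcv-style configs, load_from controls which checkpoint the runner loads before training; setting it to None avoids initializing from weights whose keys largely fail to match the model.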