open-mmlab / mmrotate

OpenMMLab Rotated Object Detection Toolbox and Benchmark
https://mmrotate.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Error in backward() when using psc. #919

Open gbdjxgp opened 1 year ago

gbdjxgp commented 1 year ago

Prerequisite

Task

I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.

Branch

1.x branch https://github.com/open-mmlab/mmrotate/tree/1.x

Environment

sys.platform: win32
Python: 3.8.17 (default, Jul 5 2023, 20:44:21) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA GeForce RTX 3090
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.6
NVCC: Cuda compilation tools, release 11.6, V11.6.124
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.29.30151 for x64
GCC: n/a
PyTorch: 2.0.1
PyTorch compiling details: PyTorch built with:

TorchVision: 0.15.2
OpenCV: 4.8.0
MMEngine: 0.8.1
MMRotate: 1.0.0rc1+

Reproduces the problem - code sample

Dataset: fair_test.zip (a simple DOTA-format dataset whose training set contains only one image).

  1. Use the model config configs/psc/rotated-retinanet-rbox-le90_r50_fpn_psc-dual_amp-1x_dota.py.
  2. Add shield_reg_angle = True to bbox_head.
  3. Set batch_size=2 and drop_last=False in train_dataloader.
  4. Make sure that each training batch contains fewer than 2 images (with a single training image and drop_last=False, every batch holds just one image).

Reproduces the problem - command or script

python tools\train.py configs/psc/rotated-retinanet-rbox-le90_r50_fpn_psc-dual_amp-1x_dota.py

Reproduces the problem - error message

Traceback (most recent call last):
  File "C:\code\mmrotate\tools\train.py", line 122, in <module>
    main()
  File "C:\code\mmrotate\tools\train.py", line 118, in main
    runner.train()
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\runner\runner.py", line 1735, in train
    model = self.train_loop.run()  # type: ignore
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\runner\loops.py", line 96, in run
    self.run_epoch()
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\runner\loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\runner\loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\model\base_model\base_model.py", line 116, in train_step
    optim_wrapper.update_params(parsed_losses)
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\optim\optimizer\optimizer_wrapper.py", line 200, in update_params
    self.backward(loss)
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\mmengine\optim\optimizer\amp_optimizer_wrapper.py", line 125, in backward
    self.loss_scaler.scale(loss).backward(**kwargs)
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "C:\Users\GBDJ\.conda\envs\mmrotate\lib\site-packages\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2304, 5]] is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
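Following the hint at the end of the error message, anomaly detection can be switched on before training starts, for example near the top of main() in tools\train.py (that placement is a suggestion, not part of the report). With it enabled, the failing backward() also prints the forward-pass traceback of the operation whose saved tensor was modified, which should help confirm where in the head the in-place write happens. It slows training noticeably, so it is only worth enabling while debugging.

```python
import torch

# Record a forward-pass stack trace for every autograd node, so that a failing
# backward() also reports which forward operation it belongs to.
# Debugging only: this adds significant overhead to every iteration.
torch.autograd.set_detect_anomaly(True)
```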

Additional information

It seems that the statement bbox_weights[:, -1] = 0. on line 168 of angle_branch_retina_head.py causes this error: https://github.com/open-mmlab/mmrotate/blob/350099480693d9a11a60d00ab828fb2aee2d12c5/mmrotate/models/dense_heads/angle_branch_retina_head.py#L167 With shield_reg_angle = False, training works normally.
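A minimal pure-PyTorch sketch of one mechanism that would explain this, under stated assumptions (it does not use mmrotate code): it assumes the per-level bbox_weights are slices of a single stacked (batch, anchors, 5) tensor. With one image per batch, reshape(-1, 5) on such a slice returns a view, so every level shares one version counter, and the in-place write for a later level invalidates the weights an earlier level's loss saved for backward. With two or more images the slice is non-contiguous, reshape copies, and the error goes away, which would match the batch-size dependence in step 4. The reported [2304, 5] shape is also consistent with a single FPN level: 3 ratios × 3 scales per octave = 9 anchors per location on a 16×16 map (stride 64, assuming a 1024×1024 input) gives 2304 anchors. The level sizes and function names below are hypothetical, and clone_weights=True illustrates one possible out-of-place workaround, not a confirmed upstream fix.

```python
import torch


def split_into_levels(stacked, num_anchors_per_level):
    """Slice a (batch, total_anchors, 5) tensor into per-level pieces.

    Each piece is a view of `stacked`, so all pieces share one version counter.
    """
    levels, start = [], 0
    for n in num_anchors_per_level:
        levels.append(stacked[:, start:start + n])
        start += n
    return levels


def toy_regression_loss(batch_size, clone_weights=False):
    """Per-level weighted L1 loss that zeroes the angle weight in place."""
    num_anchors_per_level = [6, 4, 2]  # hypothetical tiny FPN levels
    weights_all = torch.ones(batch_size, sum(num_anchors_per_level), 5)
    total = torch.zeros(())
    for level_weights in split_into_levels(weights_all, num_anchors_per_level):
        n = level_weights.size(1)
        pred = torch.randn(batch_size * n, 5, requires_grad=True)
        target = torch.zeros(batch_size * n, 5)
        # reshape returns a view when batch_size == 1 and a copy otherwise
        w = level_weights.reshape(-1, 5)
        if clone_weights:
            w = w.clone()  # private storage and version counter
        w[:, -1] = 0.  # in-place write, like bbox_weights[:, -1] = 0.
        # the multiplication saves w for backward at its current version
        total = total + ((pred - target).abs() * w).sum()
    total.backward()
    return total.item()
```

Under these assumptions, toy_regression_loss(batch_size=1) raises the same "modified by an inplace operation" RuntimeError, while batch_size=2 and clone_weights=True both run. If the real head behaves the same way, an untested local patch would be to insert bbox_weights = bbox_weights.clone() just before line 168.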

gbdjxgp commented 1 year ago

This error still exists in the newest version of mmrotate. Here is the full output from training the model. @yuyi1005 @liuyanyi

"D:\Program Files\Anaconda3\envs\ZDX\python.exe" F:/ZDX/mmrotate/tools/train.py configs/psc/rotated-retinanet-rbox-le90_r50_fpn_psc-dual_amp-1x_dota.py
08/31 16:58:45 - mmengine - INFO - 
------------------------------------------------------------
System environment:
    sys.platform: win32
    Python: 3.8.16 (default, Jun 12 2023, 21:00:42) [MSC v.1916 64 bit (AMD64)]
    CUDA available: True
    numpy_random_seed: 645004817
    GPU 0: NVIDIA GeForce RTX 2080 Ti
    CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3
    NVCC: Cuda compilation tools, release 11.3, V11.3.58
    GCC: n/a
    PyTorch: 1.12.1+cu116
    PyTorch compiling details: PyTorch built with:
  - C++ Version: 199711
  - MSVC 192829337
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 2019
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.6
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.3.2  (built against CUDA 11.5)
  - Magma 2.5.4
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.6, CUDNN_VERSION=8.3.2, CXX_COMPILER=C:/actions-runner/_work/pytorch/pytorch/builder/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -IC:/actions-runner/_work/pytorch/pytorch/builder/windows/mkl/include -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF, 

    TorchVision: 0.13.1+cu116
    OpenCV: 4.8.0
    MMEngine: 0.8.4

Runtime environment:
    cudnn_benchmark: False
    mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
    dist_cfg: {'backend': 'nccl'}
    seed: 645004817
    Distributed launcher: none
    Distributed training: False
    GPU number: 1
------------------------------------------------------------

08/31 16:58:45 - mmengine - INFO - Config:
angle_version = 'le90'
backend_args = None
data_root = 'data\\fair_test\\'
dataset_type = 'FAIR1MDataset'
default_hooks = dict(
    checkpoint=dict(interval=1, type='CheckpointHook'),
    logger=dict(interval=50, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='mmdet.DetVisualizationHook'))
default_scope = 'mmrotate'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'none'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=True, type='LogProcessor', window_size=50)
model = dict(
    backbone=dict(
        depth=50,
        frozen_stages=1,
        init_cfg=dict(checkpoint='torchvision://resnet50', type='Pretrained'),
        norm_cfg=dict(requires_grad=True, type='BN'),
        norm_eval=True,
        num_stages=4,
        out_indices=(
            0,
            1,
            2,
            3,
        ),
        style='pytorch',
        type='mmdet.ResNet'),
    bbox_head=dict(
        anchor_generator=dict(
            angle_version=None,
            octave_base_scale=4,
            ratios=[
                1.0,
                0.5,
                2.0,
            ],
            scales_per_octave=3,
            strides=[
                8,
                16,
                32,
                64,
                128,
            ],
            type='FakeRotatedAnchorGenerator'),
        angle_coder=dict(
            angle_version='le90', dual_freq=True, num_step=3, type='PSCCoder'),
        bbox_coder=dict(
            angle_version='le90',
            edge_swap=True,
            norm_factor=None,
            proj_xy=True,
            target_means=(
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ),
            target_stds=(
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
            ),
            type='DeltaXYWHTRBBoxCoder'),
        feat_channels=256,
        in_channels=256,
        loss_angle=dict(loss_weight=0.2, type='mmdet.L1Loss'),
        loss_bbox=dict(loss_weight=0.5, type='mmdet.L1Loss'),
        loss_cls=dict(
            alpha=0.25,
            gamma=2.0,
            loss_weight=1.0,
            type='mmdet.FocalLoss',
            use_sigmoid=True),
        num_classes=15,
        shield_reg_angle=True,
        stacked_convs=4,
        type='AngleBranchRetinaHead',
        use_normalized_angle_feat=True),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        boxtype2tensor=False,
        mean=[
            123.675,
            116.28,
            103.53,
        ],
        pad_size_divisor=32,
        std=[
            58.395,
            57.12,
            57.375,
        ],
        type='mmdet.DetDataPreprocessor'),
    neck=dict(
        add_extra_convs='on_input',
        in_channels=[
            256,
            512,
            1024,
            2048,
        ],
        num_outs=5,
        out_channels=256,
        start_level=1,
        type='mmdet.FPN'),
    test_cfg=dict(
        max_per_img=2000,
        min_bbox_size=0,
        nms=dict(iou_threshold=0.1, type='nms_rotated'),
        nms_pre=2000,
        score_thr=0.05),
    train_cfg=dict(
        allowed_border=-1,
        assigner=dict(
            ignore_iof_thr=-1,
            iou_calculator=dict(type='RBboxOverlaps2D'),
            min_pos_iou=0,
            neg_iou_thr=0.4,
            pos_iou_thr=0.5,
            type='mmdet.MaxIoUAssigner'),
        debug=False,
        pos_weight=-1,
        sampler=dict(type='mmdet.PseudoSampler')),
    type='mmdet.RetinaNet')
optim_wrapper = dict(
    clip_grad=dict(max_norm=35, norm_type=2),
    optimizer=dict(lr=0.0025, momentum=0.9, type='SGD', weight_decay=0.0001),
    type='AmpOptimWrapper')
param_scheduler = [
    dict(
        begin=0,
        by_epoch=False,
        end=500,
        start_factor=0.3333333333333333,
        type='LinearLR'),
    dict(
        begin=0,
        by_epoch=True,
        end=12,
        gamma=0.1,
        milestones=[
            8,
            11,
        ],
        type='MultiStepLR'),
]
resume = False
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='validation/labelTxt/',
        data_prefix=dict(img_path='validation/images/'),
        data_root='data\\fair_test\\',
        pipeline=[
            dict(backend_args=None, type='mmdet.LoadImageFromFile'),
            dict(keep_ratio=True, scale=(
                1024,
                1024,
            ), type='mmdet.Resize'),
            dict(
                box_type='qbox', type='mmdet.LoadAnnotations', with_bbox=True),
            dict(
                box_type_mapping=dict(gt_bboxes='rbox'),
                type='ConvertBoxType'),
            dict(
                meta_keys=(
                    'img_id',
                    'img_path',
                    'ori_shape',
                    'img_shape',
                    'scale_factor',
                ),
                type='mmdet.PackDetInputs'),
        ],
        test_mode=True,
        type='FAIR1MDataset'),
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(metric='mAP', type='DOTAMetric')
test_pipeline = [
    dict(backend_args=None, type='mmdet.LoadImageFromFile'),
    dict(keep_ratio=True, scale=(
        1024,
        1024,
    ), type='mmdet.Resize'),
    dict(
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
        ),
        type='mmdet.PackDetInputs'),
]
train_cfg = dict(max_epochs=12, type='EpochBasedTrainLoop', val_interval=1)
train_dataloader = dict(
    batch_sampler=None,
    batch_size=1,
    dataset=dict(
        ann_file='train/labelTxt/',
        data_prefix=dict(img_path='train/images/'),
        data_root='data\\fair_test\\',
        filter_cfg=dict(filter_empty_gt=True),
        pipeline=[
            dict(backend_args=None, type='mmdet.LoadImageFromFile'),
            dict(
                box_type='qbox', type='mmdet.LoadAnnotations', with_bbox=True),
            dict(
                box_type_mapping=dict(gt_bboxes='rbox'),
                type='ConvertBoxType'),
            dict(keep_ratio=True, scale=(
                1024,
                1024,
            ), type='mmdet.Resize'),
            dict(
                direction=[
                    'horizontal',
                    'vertical',
                    'diagonal',
                ],
                prob=0.75,
                type='mmdet.RandomFlip'),
            dict(type='mmdet.PackDetInputs'),
        ],
        type='FAIR1MDataset'),
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [
    dict(backend_args=None, type='mmdet.LoadImageFromFile'),
    dict(box_type='qbox', type='mmdet.LoadAnnotations', with_bbox=True),
    dict(box_type_mapping=dict(gt_bboxes='rbox'), type='ConvertBoxType'),
    dict(keep_ratio=True, scale=(
        1024,
        1024,
    ), type='mmdet.Resize'),
    dict(
        direction=[
            'horizontal',
            'vertical',
            'diagonal',
        ],
        prob=0.75,
        type='mmdet.RandomFlip'),
    dict(type='mmdet.PackDetInputs'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='validation/labelTxt/',
        data_prefix=dict(img_path='validation/images/'),
        data_root='data\\fair_test\\',
        pipeline=[
            dict(backend_args=None, type='mmdet.LoadImageFromFile'),
            dict(keep_ratio=True, scale=(
                1024,
                1024,
            ), type='mmdet.Resize'),
            dict(
                box_type='qbox', type='mmdet.LoadAnnotations', with_bbox=True),
            dict(
                box_type_mapping=dict(gt_bboxes='rbox'),
                type='ConvertBoxType'),
            dict(
                meta_keys=(
                    'img_id',
                    'img_path',
                    'ori_shape',
                    'img_shape',
                    'scale_factor',
                ),
                type='mmdet.PackDetInputs'),
        ],
        test_mode=True,
        type='FAIR1MDataset'),
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(metric='mAP', type='DOTAMetric')
val_pipeline = [
    dict(backend_args=None, type='mmdet.LoadImageFromFile'),
    dict(keep_ratio=True, scale=(
        1024,
        1024,
    ), type='mmdet.Resize'),
    dict(box_type='qbox', type='mmdet.LoadAnnotations', with_bbox=True),
    dict(box_type_mapping=dict(gt_bboxes='rbox'), type='ConvertBoxType'),
    dict(
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
        ),
        type='mmdet.PackDetInputs'),
]
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    name='visualizer',
    type='RotLocalVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
    ])
work_dir = './work_dirs\\rotated-retinanet-rbox-le90_r50_fpn_psc-dual_amp-1x_dota'

f:\zdx\mmdetection\mmdet\models\dense_heads\anchor_head.py:108: UserWarning: DeprecationWarning: `num_anchors` is deprecated, for consistency or also use `num_base_priors` instead
  warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
08/31 16:58:46 - mmengine - INFO - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
08/31 16:58:46 - mmengine - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) RuntimeInfoHook                    
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
before_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DistSamplerSeedHook                
 -------------------- 
before_train_iter:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
 -------------------- 
after_train_iter:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(BELOW_NORMAL) LoggerHook                         
(LOW         ) ParamSchedulerHook                 
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
after_train_epoch:
(NORMAL      ) IterTimerHook                      
(LOW         ) ParamSchedulerHook                 
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_val:
(VERY_HIGH   ) RuntimeInfoHook                    
 -------------------- 
before_val_epoch:
(NORMAL      ) IterTimerHook                      
 -------------------- 
before_val_iter:
(NORMAL      ) IterTimerHook                      
 -------------------- 
after_val_iter:
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DetVisualizationHook               
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
after_val_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(BELOW_NORMAL) LoggerHook                         
(LOW         ) ParamSchedulerHook                 
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
after_val:
(VERY_HIGH   ) RuntimeInfoHook                    
 -------------------- 
after_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_test:
(VERY_HIGH   ) RuntimeInfoHook                    
 -------------------- 
before_test_epoch:
(NORMAL      ) IterTimerHook                      
 -------------------- 
before_test_iter:
(NORMAL      ) IterTimerHook                      
 -------------------- 
after_test_iter:
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DetVisualizationHook               
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
after_test_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
after_test:
(VERY_HIGH   ) RuntimeInfoHook                    
 -------------------- 
after_run:
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
08/31 16:58:46 - mmengine - WARNING - Failed to search registry with scope "mmrotate" in the "optim_wrapper" registry tree. As a workaround, the current "optim_wrapper" registry in "mmengine" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmrotate" is a correct scope, or whether the registry is initialized.
08/31 16:58:47 - mmengine - INFO - load model from: torchvision://resnet50
08/31 16:58:47 - mmengine - INFO - Loads checkpoint by torchvision backend from path: torchvision://resnet50
08/31 16:58:47 - mmengine - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

08/31 16:58:47 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
08/31 16:58:47 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
08/31 16:58:47 - mmengine - INFO - Checkpoints will be saved to F:\ZDX\mmrotate\work_dirs\rotated-retinanet-rbox-le90_r50_fpn_psc-dual_amp-1x_dota.
F:\ZDX\mmrotate\mmrotate\structures\bbox\rotated_boxes.py:192: UserWarning: The `clip` function does nothing in `RotatedBoxes`.
  warnings.warn('The `clip` function does nothing in `RotatedBoxes`.')
Traceback (most recent call last):
  File "F:/ZDX/mmrotate/tools/train.py", line 125, in <module>
    main()
  File "F:/ZDX/mmrotate/tools/train.py", line 121, in main
    runner.train()
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\runner\runner.py", line 1745, in train
    model = self.train_loop.run()  # type: ignore
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\runner\loops.py", line 96, in run
    self.run_epoch()
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\runner\loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\runner\loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\model\base_model\base_model.py", line 116, in train_step
    optim_wrapper.update_params(parsed_losses)
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\optim\optimizer\optimizer_wrapper.py", line 200, in update_params
    self.backward(loss)
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\mmengine\optim\optimizer\amp_optimizer_wrapper.py", line 125, in backward
    self.loss_scaler.scale(loss).backward(**kwargs)
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\Program Files\Anaconda3\envs\ZDX\lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2304, 5]] is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Process finished with exit code 1