XiaohangZhan / deocclusion

Code for our CVPR 2020 work.
Apache License 2.0

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #27

Closed: umyta closed this issue 4 years ago

umyta commented 4 years ago

Hi, could you give me some advice on this error? The details of the experiment are as follows:

  1. Dataset: COCOA
  2. Environment: Python 3.7.9, PyTorch 1.6.0
  3. Downloaded pretrains/partialconv.pth from here

I followed the instructions to run training. PCNet-M trains fine, and I did convert the partialconv.pth model to accept 4-channel inputs. When I run "sh experiments/COCOA/pcnet_c/train.sh", I get the following error:

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
main.py:14: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  config = yaml.load(f)
main.py:14: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  config = yaml.load(f)
main.py:14: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  config = yaml.load(f)
main.py:14: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  config = yaml.load(f)
=> loading checkpoint 'pretrains/partialconv_input_ch4.pth'
=> loading checkpoint 'pretrains/partialconv_input_ch4.pth'
=> loading checkpoint 'pretrains/partialconv_input_ch4.pth'
=> loading checkpoint 'pretrains/partialconv_input_ch4.pth'
[2020-09-22 15:53:59,916] Validation Iter: [0]  Time 0.443 (2.212)      Data 0.015 (1.491)      hole: 0.06159 (0.05562)  valid: 0.05347 (0.05307)        prc: 2.072 (2.004)      style: 0.01656 (0.01629)        tv: 0.2303 (0.2479)      dis: 0 (0)      adv: 0 (0)
Traceback (most recent call last):
  File "main.py", line 48, in <module>
    main(args)
  File "main.py", line 30, in main
    trainer.run()
  File ".../deocclusion/trainer.py", line 125, in run
    self.train()
  File ".../deocclusion/trainer.py", line 147, in train
    loss_dict = self.model.step()
  File ".../deocclusion/models/partial_completion_content_cgan.py", line 153, in step
    gen_loss.backward()
  File ".../anaconda3/envs/python37/lib/python3.7/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../anaconda3/envs/python37/lib/python3.7/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Traceback (most recent call last):
  File ".../anaconda3/envs/python37/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File ".../anaconda3/envs/python37/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File ".../anaconda3/envs/python37/lib/python3.7/site-packages/torch/distributed/launch.py", line 261, in <module>
    main()
  File ".../anaconda3/envs/python37/lib/python3.7/site-packages/torch/distributed/launch.py", line 257, in main
    cmd=cmd)
subprocess.CalledProcessError: Command '['.../anaconda3/envs/python37/bin/python', '-u', 'main.py', '--local_rank=3', '--config', 'experiments/COCOA/pcnet_c/config.yaml', '--launcher', 'pytorch', '--load-pretrain', 'pretrains/partialconv_input_ch4.pth']' returned non-zero exit status 1.

Has anyone run into this error before? Any help would be much appreciated. Thanks!
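
The traceback points at gen_loss.backward() inside step(). One explanation consistent with the fix proposed later in this thread: the discriminator's optimizer step presumably runs before gen_loss.backward(), and in newer PyTorch releases (the linked forum thread mentions 1.5) optimizer.step() updates parameters in place and bumps the version counter that autograd checks, so a backward pass that still needs the pre-update discriminator weights fails. The toy sketch below reproduces the same RuntimeError on CPU; G, D, their sizes, the losses, and original_update_order are hypothetical stand-ins, not the repo's code.

```python
import torch
import torch.nn as nn


def original_update_order():
    """Toy GAN step that steps D *before* backpropagating the generator loss.

    G and D are hypothetical stand-ins for the repo's networks.
    Returns the RuntimeError message if backward fails, otherwise None.
    """
    torch.manual_seed(0)
    G = nn.Linear(4, 4)  # stand-in generator
    D = nn.Linear(4, 1)  # stand-in discriminator
    optim = torch.optim.SGD(G.parameters(), lr=0.1)
    optimD = torch.optim.SGD(D.parameters(), lr=0.1)

    x = torch.randn(8, 4)
    fake = G(x)
    dis_loss = D(fake.detach()).mean()  # D loss: no gradient flows into G
    gen_loss = -D(fake).mean()          # G loss: autograd saves D.weight for backward

    # Assumed original order: the discriminator is updated first ...
    optimD.zero_grad()
    dis_loss.backward()
    optimD.step()  # in-place update bumps D.weight's version counter

    # ... so this backward finds D.weight newer than the version it saved
    optim.zero_grad()
    try:
        gen_loss.backward()
    except RuntimeError as err:
        return str(err)
    optim.step()
    return None


if __name__ == "__main__":
    print(original_update_order())
```

On a recent PyTorch, the printed message should start with "one of the variables needed for gradient computation has been modified by an inplace operation", matching the log above apart from the tensor shape.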

XiaohangZhan commented 4 years ago

I have not run into it before. You could try downgrading your PyTorch version.

XiaohangZhan commented 4 years ago

The code was verified on PyTorch 0.4.1 and PyTorch 1.1.
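
Since the code was verified only on PyTorch 0.4.1 and 1.1, a startup check can flag other versions before a long distributed run. A minimal sketch follows; is_verified_torch_version and warn_if_unverified are hypothetical helpers, not part of the repo, and matching only the major and minor version numbers is an assumption about how strict the check should be.

```python
import warnings

# (major, minor) pairs reported as verified in this thread
VERIFIED_TORCH_VERSIONS = {(0, 4), (1, 1)}


def is_verified_torch_version(version: str) -> bool:
    """Return True if `version` (e.g. torch.__version__) is 0.4.x or 1.1.x."""
    # Drop local build tags such as "+cu100" before parsing
    release = version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) in VERIFIED_TORCH_VERSIONS


def warn_if_unverified(version: str) -> None:
    """Warn instead of failing, so other versions can still be tried."""
    if not is_verified_torch_version(version):
        warnings.warn(
            f"PyTorch {version} has not been verified with this code "
            "(verified: 0.4.1 and 1.1).",
            RuntimeWarning,
        )
```

For example, calling warn_if_unverified(torch.__version__) near the top of main.py would emit a warning on PyTorch 1.6.0, the version in the original report.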

umyta commented 4 years ago

OK, let me try that. Thanks.

zhenghan408 commented 4 years ago

> OK, let me try that. Thanks.

Have you solved it? I have the same problem.

zhenghan408 commented 4 years ago

> OK, let me try that. Thanks.

I used PyTorch 1.1, but the same error occurred.

umyta commented 4 years ago

@XiaohangZhan PyTorch 1.1 works for me. Thanks! @zhenghan408, I used Python 3.7 and installed it with: conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch

bluestyle97 commented 4 years ago

I solved this problem following this thread: https://discuss.pytorch.org/t/solved-pytorch1-5-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/90256

Modify lines 145-154 of models/partial_completion_content_cgan.py, rearranging them as follows:

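# update: clear the gradients of both networks first
self.optimD.zero_grad()
self.optim.zero_grad()

# run every backward pass before any optimizer modifies parameters in place
dis_loss.backward()
gen_loss.backward()

# synchronize gradients across the distributed processes
utils.average_gradients(self.netD)
utils.average_gradients(self.model)

# apply the updates only after all gradients have been computed
self.optimD.step()
self.optim.step()

XiaohangZhan commented 4 years ago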

Thanks. Could you check whether the training outputs stay similar before and after the modification, and that performance is not hurt? If so, I will update the code accordingly.
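
When comparing runs before and after the change, one difference is worth expecting. Assuming gen_loss includes an adversarial term computed through self.netD (the adv entry in the log suggests one), gen_loss.backward() also writes gradients into the discriminator's parameters. In the rearranged order those gradients are added to dis_loss's gradients before self.optimD.step(), whereas in the presumed original order the discriminator had already been stepped, so they never reached its update. The toy sketch below (toy_step, G, D, and the non-saturating losses are hypothetical) returns the discriminator weight gradient just before its step, together with the gradient of dis_loss alone. The "discard" variant runs gen_loss.backward() first and clears the discriminator's gradients before dis_loss.backward(), which also avoids the in-place error while keeping the discriminator update driven by dis_loss only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_step(order: str):
    """One toy GAN update with every backward() run before any step().

    order="rearranged": zero both, dis_loss.backward(), gen_loss.backward().
    order="discard":    gen_loss.backward(), clear D's grads, dis_loss.backward().
    Returns (D.weight.grad just before optimD.step(), grad of dis_loss alone).
    """
    torch.manual_seed(0)
    G = nn.Linear(4, 4)  # hypothetical generator
    D = nn.Linear(4, 1)  # hypothetical discriminator
    optim = torch.optim.SGD(G.parameters(), lr=0.1)
    optimD = torch.optim.SGD(D.parameters(), lr=0.1)

    real, noise = torch.randn(8, 4), torch.randn(8, 4)
    fake = G(noise)
    dis_loss = F.softplus(-D(real)).mean() + F.softplus(D(fake.detach())).mean()
    gen_loss = F.softplus(-D(fake)).mean()  # also depends on D's parameters

    # Reference: gradient of dis_loss alone w.r.t. D.weight (graph kept)
    (dis_only,) = torch.autograd.grad(dis_loss, D.weight, retain_graph=True)

    optimD.zero_grad()
    optim.zero_grad()
    if order == "rearranged":
        dis_loss.backward()
        gen_loss.backward()  # adds gen_loss's gradient on top of D's .grad
    elif order == "discard":
        gen_loss.backward()
        optimD.zero_grad()  # drop gen_loss's contribution to D's gradients
        dis_loss.backward()
    else:
        raise ValueError(f"unknown order: {order}")

    d_grad = D.weight.grad.clone()
    optimD.step()  # safe: no graph needs the old parameters any more
    optim.step()
    return d_grad, dis_only


if __name__ == "__main__":
    for order in ("rearranged", "discard"):
        d_grad, dis_only = toy_step(order)
        print(order, torch.allclose(d_grad, dis_only))
```

In this toy, the "rearranged" gradient should differ from the dis_loss-only reference while the "discard" gradient should match it. In the repo's step(), the utils.average_gradients calls would presumably stay between the last backward pass and the two steps.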