aviralkumar2907 / BEAR

Code for Stabilizing Off-Policy RL via Bootstrapping Error Reduction

In-place operations in algos.py #7

Open rrags opened 4 years ago

rrags commented 4 years ago

I keep getting the error below, which seems to come from some in-place changes to the variable `a` in `sample_multiple`:

```
[W python_anomaly_mode.cpp:60] Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "/home/7331215/wrappers/run_optimizer.py", line 211, in <module>
    main(sys.argv[1:])
  File "/home/7331215/wrappers/run_optimizer.py", line 152, in main
    RewPred.generate_knobs()
  File "/home/7331215//wrappers/../../rewardpredictor/rewardpredictor_base.py", line 431, in generate_knobs
    self.generate_knobs_BEAR()
  File "/home/7331215//wrappers/../../rewardpredictor/rewardpredictor_base.py", line 532, in generate_knobs_BEAR
    pol_vals = policy.train(replay_buffer, iterations = int(5e3))
  File "/home/7331215/wrappers/../../rl/Algos/BEAR/algos.py", line 440, in train
    actor_actions, raw_actor_actions = self.actor.sample_multiple(state, num_samples)# num)
  File "/home/7331215/../../rl/Algos/BEAR/algos.py", line 76, in sample_multiple
    log_std_a = self.log_std(a.clone())
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/modules/linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/functional.py", line 1674, in linear
    ret = torch.addmm(bias, input, weight.t())
 (function print_stack)
Traceback (most recent call last):
  File "/home/7331215/wrappers/run_optimizer.py", line 211, in <module>
    main(sys.argv[1:])
  File "/home/7331215/wrappers/run_optimizer.py", line 152, in main
    RewPred.generate_knobs()
  File "/home/7331215/wrappers/../../rewardpredictor/rewardpredictor_base.py", line 431, in generate_knobs
    self.generate_knobs_BEAR()
  File "/home/7331215/wrappers/../../rewardpredictor/rewardpredictor_base.py", line 532, in generate_knobs_BEAR
    pol_vals = policy.train(replay_buffer, iterations = int(5e3))
  File "/home/7331215//wrappers/../../rl/Algos/BEAR/algos.py", line 508, in train
    (-lagrange_loss).backward()
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [300, 32]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

Any guidance on how to fix this? I have edited main.py to adapt it to my specific task, but I haven't touched algos.py except to try to debug this error.
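For reference, here is a minimal, standalone PyTorch sketch (not taken from the BEAR code; the network and losses are made up) that reproduces this class of error: an optimizer step mutates a layer's weights in place between a forward pass and a later backward pass that still needs those weights.

```python
import torch
import torch.nn as nn

torch.autograd.set_detect_anomaly(True)  # prints the forward-call trace, as in the log above

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
opt = torch.optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(3, 4)

# loss_a is kept around to be backpropagated later; autograd has saved the
# (transposed) Linear weights it will need for that backward pass.
loss_a = net(x).pow(2).mean()

# An unrelated update modifies those same weights in place ...
loss_b = net(x).mean()
opt.zero_grad()
loss_b.backward()
opt.step()  # in-place parameter update -> the saved weights' version counters change

# ... so the deferred backward now fails the saved-tensor version check with
# "one of the variables needed for gradient computation has been modified by an
# inplace operation", pointing at the output of TBackward (i.e. weight.t()).
loss_a.backward()
```

The traceback above looks consistent with this pattern: `lagrange_loss` is built from the actor's output, but `(-lagrange_loss).backward()` only runs after the actor's parameters have been updated in place, which also matches the fix discussed below.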

aviralkumar2907 commented 4 years ago

I am not sure whether this is specific to this code or comes from the optimizer and PyTorch. I do not know how to fix it, but there is very likely a solution out there that relates to PyTorch more generally.

If it helps, you might want to run our new version of BEAR code (and install a new conda env): https://github.com/rail-berkeley/d4rl_evaluations

guyuwuyu commented 3 years ago

This issue comes from the PyTorch version. In the latest version of PyTorch it can be corrected by moving `self.actor_optimizer.step()` in algos.py so that it runs after `lagrange_loss = (-critic_qs + self._lambda * (np.sqrt((1 - self.delta_conf)/self.delta_conf)) * (std_q) + self.log_lagrange2.exp() * (mmd_loss - thresh)).mean()`.
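To make the suggested ordering concrete, here is a hedged sketch of that part of `train()` (a code fragment, not the verbatim BEAR implementation: `actor_loss` and the `self.lagrange2_opt` optimizer name are placeholders). The point is that every backward pass that still needs the actor's current weights runs before `self.actor_optimizer.step()` mutates them in place.

```python
# Sketch of the reordering, assuming the loss names from the snippet above
# and a hypothetical `self.lagrange2_opt` optimizer for the dual variable.
self.actor_optimizer.zero_grad()
actor_loss.backward(retain_graph=True)      # keep the graph; the lagrange term reuses it

lagrange_loss = (-critic_qs +
                 self._lambda * np.sqrt((1 - self.delta_conf) / self.delta_conf) * std_q +
                 self.log_lagrange2.exp() * (mmd_loss - thresh)).mean()

self.lagrange2_opt.zero_grad()
(-lagrange_loss).backward()                 # still reads the *unmodified* actor weights
self.lagrange2_opt.step()

self.actor_optimizer.step()                 # in-place update of the actor happens last
```

Note that `(-lagrange_loss).backward()` also deposits gradients in the actor and critic parameters via `mmd_loss` and `critic_qs`, which is part of what motivates the detach variant in the next comment.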

jjd1123 commented 1 year ago

I think this problem is caused by the modification of the actor network's parameters: `self.actor_optimizer.step()` changes them in place. So to solve it, we should recompute `mmd_loss` and `critic_qs` after `self.actor_optimizer.step()`. Alternatively, we can just use `mmd_loss.detach()` and `critic_qs.detach()` when computing `lagrange_loss`, because this step only needs to update `self.log_lagrange2`. Or we can update `self.log_lagrange2` before the critic and actor.
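A hedged sketch of the detach variant (again a fragment, with `self.lagrange2_opt` as a placeholder optimizer name; `std_q` is also detached here on the assumption that it depends on the actor as well, which goes slightly beyond what the comment suggests):

```python
# Treat the actor-dependent quantities as constants; only the dual variable
# self.log_lagrange2 receives a gradient from this loss.
lagrange_loss = (-critic_qs.detach() +
                 self._lambda * np.sqrt((1 - self.delta_conf) / self.delta_conf) * std_q.detach() +
                 self.log_lagrange2.exp() * (mmd_loss.detach() - thresh)).mean()

self.lagrange2_opt.zero_grad()
(-lagrange_loss).backward()   # no longer walks the actor/critic graph
self.lagrange2_opt.step()
```

With everything except `self.log_lagrange2.exp()` detached, this backward pass never touches the actor's saved tensors, so its ordering relative to `self.actor_optimizer.step()` stops mattering.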