p-christ / Deep-Reinforcement-Learning-Algorithms-with-PyTorch

PyTorch implementations of deep reinforcement learning algorithms and environments

Sac Discrete Error #56

Closed sshillo closed 4 years ago

sshillo commented 4 years ago

Hi, I'm trying to run SAC Discrete and I keep getting the following error:

Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "results/Cart_Pole.py", line 144, in <module>
    trainer.run_games_for_agents()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Trainer.py", line 79, in run_games_for_agents
    self.run_games_for_agent(agent_number + 1, agent_class)
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Trainer.py", line 117, in run_games_for_agent
    game_scores, rolling_scores, time_taken = agent.run_n_episodes()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Base_Agent.py", line 189, in run_n_episodes
    self.step()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC.py", line 87, in step
    self.learn()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC.py", line 147, in learn
    policy_loss, log_pi = self.calculate_actor_loss(state_batch)
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC_Discrete.py", line 87, in calculate_actor_loss
    qf2_pi = self.critic_local_2(state_batch)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/nn_builder/pytorch/NN.py", line 119, in forward
    out = self.process_output_layers(x)
  File "/usr/local/lib/python3.6/dist-packages/nn_builder/pytorch/NN.py", line 163, in process_output_layers
    temp_output = output_layer(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1610, in linear
    ret = torch.addmm(bias, input, weight.t())
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:60)
Traceback (most recent call last):
  File "results/Cart_Pole.py", line 144, in <module>
    trainer.run_games_for_agents()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Trainer.py", line 79, in run_games_for_agents
    self.run_games_for_agent(agent_number + 1, agent_class)
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Trainer.py", line 117, in run_games_for_agent
    game_scores, rolling_scores, time_taken = agent.run_n_episodes()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Base_Agent.py", line 189, in run_n_episodes
    self.step()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC.py", line 87, in step
    self.learn()
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC.py", line 150, in learn
    self.update_all_parameters(qf1_loss, qf2_loss, policy_loss, alpha_loss)
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC.py", line 192, in update_all_parameters
    self.hyperparameters["Actor"]["gradient_clipping_norm"])
  File "/home/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/Base_Agent.py", line 283, in take_optimisation_step
    loss.backward(retain_graph=retain_graph) #this calculates the gradients
  File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 2]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Any thoughts?

chenyu97 commented 4 years ago

The same problem!

yuchencena commented 4 years ago

The same problem :)

AlexTo commented 4 years ago

I think you may need to do this in calculate_actor_loss:

with torch.no_grad():
    qf1_pi = self.critic_local(state_batch)
    qf2_pi = self.critic_local_2(state_batch)

Because we already backpropagate the losses through critic_local and critic_local_2 in calculate_critic_losses, so policy_loss raises the exception here? Is that it?
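
For context, here is a minimal sketch of how that change could sit inside calculate_actor_loss. The critic and actor attribute names are taken from the traceback above; the exact policy call and the loss expression are illustrative assumptions about the discrete-SAC objective, not the repository's exact code.

import torch

def calculate_actor_loss(self, state_batch):
    # Illustrative policy forward pass: assumes the actor returns action
    # probabilities and their logs for the discrete action set.
    action_probabilities, log_action_probabilities = self.actor_local(state_batch)

    # Evaluate both critics without tracking gradients; their own losses were
    # already backpropagated in calculate_critic_losses, and the actor loss
    # does not need gradients flowing through the critic parameters.
    with torch.no_grad():
        qf1_pi = self.critic_local(state_batch)
        qf2_pi = self.critic_local_2(state_batch)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)

    # Soft policy objective: expectation over the discrete action
    # distribution, averaged over the batch.
    inside_term = self.alpha * log_action_probabilities - min_qf_pi
    policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
    log_pi = (action_probabilities * log_action_probabilities).sum(dim=1)
    return policy_loss, log_pi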

toshikwa commented 4 years ago

Hi.

The problem is due to an update of PyTorch (versions greater than 1.4.0 cause this problem). I made pull request #60 to fix this issue. Cloning from my branch or executing pip install torch==1.3.1 should solve your issue.
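
For anyone curious why newer PyTorch versions complain, here is a small self-contained snippet (not the repository's code) that reproduces the same failure mode: a forward pass is recorded, the weights it used are then updated in place by an optimizer step, and a second loss built from that same forward pass is backpropagated afterwards.

import torch
import torch.nn as nn

# Two-layer net: backprop into the first layer's weights needs the second
# layer's saved weights, which is where the version check bites.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

state_batch = torch.randn(64, 4)
out = net(state_batch)                 # forward pass saves the current weights
critic_like_loss = out.pow(2).mean()   # stands in for the critic loss
actor_like_loss = out.sum()            # stands in for the actor loss

critic_like_loss.backward(retain_graph=True)
optimizer.step()                       # in-place parameter update bumps the version counters

# On PyTorch above 1.4.0 this second backward raises the same
# "modified by an inplace operation" RuntimeError, because the weights saved
# during the forward pass are now at a newer version; per the comment above,
# torch==1.3.1 did not flag this pattern.
actor_like_loss.backward()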

Thanks :)