Closed qa6300525 closed 2 years ago
具体文件 parl/algorithms/torch/ddpg.py:67
具体代码
def _critic_learn(self, obs, action, reward, next_obs, terminal):
    """Run one optimization step on the critic and return its loss.

    The regression target is built from the target model as
    ``reward + (1 - terminal) * gamma * Q_target(next_obs, pi_target(next_obs))``
    and the critic is fitted to it with a mean-squared-error loss.

    Args:
        obs: Observations passed to ``self.model.value``.
        action: Actions passed to ``self.model.value`` together with ``obs``.
        reward: Rewards added to the discounted target value.
        next_obs: Next observations evaluated by ``self.target_model``.
        terminal: Numeric done mask; where it equals 1 the bootstrapped
            term is multiplied by zero and the target reduces to ``reward``.

    Returns:
        The critic loss tensor produced by ``F.mse_loss`` (computed before
        the optimizer step).

    NOTE(review): tensor shapes are not visible here; ``F.mse_loss`` expects
    ``current_Q`` and ``target_Q`` to have matching shapes -- confirm in the
    caller.
    """
    # Compute the target Q value from the target actor/critic pair.
    next_action = self.target_model.policy(next_obs)
    target_Q = self.target_model.value(next_obs, next_action)
    # detach() cuts the bootstrapped term out of the autograd graph, so the
    # backward pass below cannot propagate gradients into target_model.
    target_Q = reward + ((1. - terminal) * self.gamma * target_Q).detach()

    # Get current Q estimate
    current_Q = self.model.value(obs, action)

    # Compute critic loss
    critic_loss = F.mse_loss(current_Q, target_Q)

    # Optimize the critic: clear stale gradients, backpropagate, then step.
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()
    return critic_loss
detach() 的作用相当于冻结目标网络：它返回一个脱离计算图的张量，因此 critic_loss.backward() 时梯度不会经由 target_Q 传播到 target_model。
具体文件 parl/algorithms/torch/ddpg.py:67
具体代码