PaddlePaddle / PARL

A high-performance distributed training framework for Reinforcement Learning
https://parl.readthedocs.io/
Apache License 2.0

The torch version of DDPG does not freeze the target network when updating the critic network, while the paddle version does not have this problem. #866

Closed qa6300525 closed 2 years ago

qa6300525 commented 2 years ago

File: parl/algorithms/torch/ddpg.py:67

Code:

def _critic_learn(self, obs, action, reward, next_obs, terminal):
    # Compute the target Q value
    target_Q = self.target_model.value(next_obs,
                                       self.target_model.policy(next_obs))
    target_Q = reward + ((1. - terminal) * self.gamma * target_Q).detach()

    # Get current Q estimate
    current_Q = self.model.value(obs, action)

    # Compute critic loss
    critic_loss = F.mse_loss(current_Q, target_Q)

    # Optimize the critic
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()
    return critic_loss
TomorrowIsAnOtherDay commented 2 years ago

detach() is the API that freezes the network: it blocks gradient propagation into the target network.
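The point above can be illustrated with a minimal standalone sketch (not PARL code; the scalar "parameters" below are hypothetical stand-ins for network weights). Calling `.detach()` on the target-network output cuts that tensor out of the autograd graph, so `backward()` on the critic loss produces gradients only for the current network:

```python
import torch

# Stand-in for a target-network weight and a current-network weight.
target_param = torch.tensor([2.0], requires_grad=True)
current_param = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([3.0])

# "Target network" output, then detach: this severs the graph edge
# back to target_param, mimicking the .detach() call in _critic_learn.
target_q = (target_param * x).detach()

# "Current network" output stays in the graph.
current_q = current_param * x

loss = (current_q - target_q).pow(2).mean()
loss.backward()

print(current_param.grad)  # non-None: the critic receives gradients
print(target_param.grad)   # None: the target network is effectively frozen
```

So even without an explicit `requires_grad_(False)` on the target model, the `.detach()` in `_critic_learn` already prevents the critic update from backpropagating into the target network.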