keiohta / tf2rl

TensorFlow2 Reinforcement Learning

Possible error in critic update in SAC-AE algorithm #162

Open · Cerphilly opened this issue 2 years ago

Cerphilly commented 2 years ago

In SAC-AE algorithm, critic1 and 2 are updated as the following:

target_q = tf.stop_gradient(
    rewards + not_dones * self.discount * (min_next_target_q - self.alpha * next_logps))

obs_features = self._encoder(obses, stop_q_grad=self._stop_q_grad)
current_q1 = self.qf1(obs_features, actions)
current_q2 = self.qf2(obs_features, actions)
td_loss_q1 = tf.reduce_mean((target_q - current_q1) ** 2)
td_loss_q2 = tf.reduce_mean((target_q - current_q2) ** 2)  # Eq.(6)

q1_grad = tape.gradient(td_loss_q1, self._encoder.trainable_variables + self.qf1.trainable_variables)
self.qf1_optimizer.apply_gradients(
    zip(q1_grad, self._encoder.trainable_variables + self.qf1.trainable_variables))
q2_grad = tape.gradient(td_loss_q2, self._encoder.trainable_variables + self.qf2.trainable_variables)
self.qf2_optimizer.apply_gradients(
    zip(q2_grad, self._encoder.trainable_variables + self.qf2.trainable_variables))

However, since the encoder is optimized together with qf1 before the qf2 + encoder optimization, td_loss_q2 and q2_grad end up inconsistent with the already-updated encoder parameters they are applied to. Thus I believe q2_grad has to be calculated before optimizing qf1 and the encoder.
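
To illustrate the mechanics with a toy example (plain TensorFlow, not tf2rl code): a persistent GradientTape always differentiates through the values recorded during the forward pass, so a gradient taken after an in-place variable update is still the gradient at the old parameters, while apply_gradients would then apply it to the already-updated ones:

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape(persistent=True) as tape:
    y1 = x ** 2   # dy1/dx = 2 * x    -> 6  at x = 3
    y2 = x ** 3   # dy2/dx = 3 * x**2 -> 27 at x = 3

g1 = tape.gradient(y1, x)   # 6.0
x.assign_sub(0.1 * g1)      # "optimizer step": x is now 2.4
g2 = tape.gradient(y2, x)   # still 27.0, i.e. the gradient at the old x = 3.0,
                            # not 3 * 2.4**2 = 17.28 at the updated x
del tape                    # release the persistent tape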

keiohta commented 2 years ago

Hi @Cerphilly, thanks for pointing this out! I agree that this might be an error (though I'm not sure about its impact). So, would you suggest something like the following?

q_grad = tape.gradient(td_loss_q1 + td_loss_q2, self._encoder.trainable_variables + self.qf1.trainable_variables + self.qf2.trainable_variables)
self.qf_optimizer.apply_gradients(
    zip(q_grad, self._encoder.trainable_variables + self.qf1.trainable_variables + self.qf2.trainable_variables))

The above code simply sums the two TD losses and takes the gradient of the sum, so the encoder and both critics are updated in a single optimizer step.
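
In context, the whole critic update would then look roughly like this (a sketch only: the single self.qf_optimizer replacing the two per-critic optimizers is the one new piece, and target_q is computed as in the current code):

with tf.GradientTape() as tape:
    # target_q computed beforehand as in the current code (wrapped in tf.stop_gradient)
    obs_features = self._encoder(obses, stop_q_grad=self._stop_q_grad)
    current_q1 = self.qf1(obs_features, actions)
    current_q2 = self.qf2(obs_features, actions)
    td_loss_q1 = tf.reduce_mean((target_q - current_q1) ** 2)
    td_loss_q2 = tf.reduce_mean((target_q - current_q2) ** 2)
    td_loss = td_loss_q1 + td_loss_q2  # one scalar loss for encoder + both critics

# One gradient computation and one optimizer step, so every gradient is applied
# to exactly the parameters it was computed for.
q_vars = (self._encoder.trainable_variables
          + self.qf1.trainable_variables
          + self.qf2.trainable_variables)
q_grad = tape.gradient(td_loss, q_vars)
self.qf_optimizer.apply_gradients(zip(q_grad, q_vars))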

Cerphilly commented 2 years ago

Thanks for the quick response! I changed my code as follows:

            # Soft Bellman target; stop_gradient ensures no gradient flows through it
            target_q = tf.stop_gradient(r + self.gamma * (1 - d) * (target_min_aq - self.alpha.numpy() * ns_logpi))
            # persistent=True because two separate gradient calls are taken from this tape
            with tf.GradientTape(persistent=True) as tape1:
                critic1_loss = tf.reduce_mean(tf.square(self.critic1(self.encoder(s), a) - target_q))
                critic2_loss = tf.reduce_mean(tf.square(self.critic2(self.encoder(s), a) - target_q))

            # Both gradients are computed from the same recorded forward pass
            # before any optimizer step is applied
            critic1_gradients = tape1.gradient(critic1_loss,
                                               self.encoder.trainable_variables + self.critic1.trainable_variables)

            critic2_gradients = tape1.gradient(critic2_loss,
                                               self.encoder.trainable_variables + self.critic2.trainable_variables)

            self.critic1_optimizer.apply_gradients(
                zip(critic1_gradients, self.encoder.trainable_variables + self.critic1.trainable_variables))

            self.critic2_optimizer.apply_gradients(
                zip(critic2_gradients, self.encoder.trainable_variables + self.critic2.trainable_variables))

and it seemed to achieve higher performance in RAD. But your suggestion also seems to work well without error.