jianda-chen / RAP_distance


Critic target network in RAP Loss Implementation #6

Open Zengxia-Guo opened 1 month ago

Zengxia-Guo commented 1 month ago

Hi,

Thank you for sharing your valuable work! I have been reading your paper and reviewing the code implementation, and I noticed a potential inconsistency. Specifically, the paper uses the target encoder network $\phi_{\bar{\omega}}$ in the definition of the RAP loss.

However, the update_encoder method in the code uses critic.encoder rather than critic_target.encoder. Here is the relevant snippet:

def update_encoder(self, obs, action, behavioural_log_pi, reward, next_obs, L, step):
    h = self.critic.encoder(obs)
    # Sample random states across episodes at random
    batch_size = obs.size(0)
    perm = np.random.permutation(batch_size)
    h2 = h[perm]

    with torch.no_grad():
        pred_next_latent_mu1, pred_next_latent_sigma1 = self.transition_model(torch.cat([h, action], dim=1))

I would like to confirm the following points:

According to the paper, the RAP loss should be based on the target encoder network $\phi_{\bar{\omega}}$. Should the code be updated to use critic_target.encoder instead of critic.encoder?

Why do we use a target encoder network and what are its benefits?

Thank you very much for your help!

Zengxia-Guo commented 1 month ago

For example,

def update_encoder(self, obs, action, behavioural_log_pi, reward, next_obs, L, step):
    h = self.critic.encoder(obs)
    target_h = self.critic_target.encoder(obs)

    # Sample random states across episodes at random
    batch_size = obs.size(0)
    perm = np.random.permutation(batch_size)
    h2 = h[perm]

    with torch.no_grad():
        pred_next_latent_mu1, pred_next_latent_sigma1 = self.transition_model(torch.cat([target_h, action], dim=1))

jianda-chen commented 1 month ago

Please follow the source code and use the online h. The $\phi_{\bar{\omega}}$ in loss (15) denotes a stop-gradient only; the target encoder is used only to compute target Q-values.
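
To make the distinction concrete, here is a minimal PyTorch sketch. The helper name illustrate_stop_gradient is hypothetical; critic, critic_target, and obs follow the naming in the snippets above, and this is only one reading of "stop gradient", not the repository's exact code:

import torch

def illustrate_stop_gradient(critic, critic_target, obs):
    # Online encoder: gradients from any loss built on this output
    # flow back into the encoder weights.
    h_online = critic.encoder(obs)

    # Stop-gradient on the online encoder: same weights, but a target
    # built from this tensor does not backpropagate into the encoder
    # (wrapping the downstream computation in torch.no_grad(), as in the
    # original snippet, has the same gradient-blocking effect).
    h_stopgrad = critic.encoder(obs).detach()

    # Target encoder: a separately maintained (typically soft-updated)
    # copy of the weights; per the reply above, it is used only when
    # computing target Q-values in the critic update.
    with torch.no_grad():
        h_target = critic_target.encoder(obs)

    return h_online, h_stopgrad, h_target

Under this reading, the torch.no_grad() block in the original update_encoder already realizes the stop-gradient meaning of $\phi_{\bar{\omega}}$, so no switch to critic_target.encoder is needed.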