young-geng / CQL

Conservative Q Learning on top of SAC

Question about the CQL-temperature #3

Closed · fuyw closed this issue 2 years ago

fuyw commented 2 years ago

I am confused about the CQL temperature:

cql_cat_q1 = jnp.concatenate([
      jnp.squeeze(cql_random_q1) - random_density,
      jnp.squeeze(cql_q1) - cql_logp,
], axis=1)
cql_cat_q2 = jnp.concatenate([
      jnp.squeeze(cql_random_q2) - random_density,
      jnp.squeeze(cql_q2) - cql_logp,
], axis=1)
# the log-density correction is divided by cql_temp along with the Q-values
cql_qf1_ood = jax.scipy.special.logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1) * self.config.cql_temp
cql_qf2_ood = jax.scipy.special.logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1) * self.config.cql_temp

Shouldn't it be:

cql_cat_q1 = jnp.concatenate([
      jnp.squeeze(cql_random_q1) / self.config.cql_temp - random_density,
      jnp.squeeze(cql_q1) / self.config.cql_temp - cql_logp,
], axis=1)
cql_cat_q2 = jnp.concatenate([
      jnp.squeeze(cql_random_q2) / self.config.cql_temp - random_density,
      jnp.squeeze(cql_q2) / self.config.cql_temp - cql_logp,
], axis=1)
# only the Q-values are divided by cql_temp; the correction stays untempered
cql_qf1_ood = jax.scipy.special.logsumexp(cql_cat_q1, axis=1)
cql_qf2_ood = jax.scipy.special.logsumexp(cql_cat_q2, axis=1)

[attached screenshot of the corresponding objective from the CQL paper]
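For reference, the estimator at issue (presumably what the attached screenshot shows) is the importance-sampled soft-maximum from CQL(H). Writing τ for cql_temp and μ for the proposal density the actions are sampled from:

\tau \log \int \exp\!\big(Q(s,a)/\tau\big)\, da
\;\approx\; \tau \log \frac{1}{N} \sum_{i=1}^{N} \exp\!\Big(\frac{Q(s,a_i)}{\tau} - \log \mu(a_i)\Big),
\qquad a_i \sim \mu,

so the log μ(a_i) correction enters untempered, which is what the suggested version implements.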

young-geng commented 2 years ago

You are right, this is indeed a mistake. Fortunately, we use cql_temp = 1 for all environments, so this bug does not affect the algorithm in practice. I will fix it later.
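A quick way to see why cql_temp = 1 masks the bug: the two placements of the log-density correction coincide at τ = 1 and diverge otherwise. The sketch below is hypothetical (toy arrays, not the repo's shapes), and it keeps the outer τ factor in both forms so the comparison isolates where the correction sits:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jnp.array([1.0, 2.0, 3.0])          # hypothetical Q-values at sampled actions
log_mu = jnp.log(jnp.full(3, 1.0 / 3))  # hypothetical log-density of the proposal

def repo_form(temp):
    # current code: the log-density correction is divided by the temperature too
    return temp * logsumexp((q - log_mu) / temp)

def suggested_form(temp):
    # proposed fix: only the Q-values are tempered
    return temp * logsumexp(q / temp - log_mu)

for temp in (1.0, 5.0):
    print(temp, repo_form(temp), suggested_form(temp))
# temp=1.0 -> identical values; temp=5.0 -> the two estimates differ

At τ = 1 both reduce to logsumexp(q - log_mu), so results trained with the default temperature are unaffected.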