Closed: fuyw closed 2 years ago
I am confused about the cql temperature:
```python
cql_concat_q1 = jnp.concatenate([
    jnp.squeeze(cql_random_q1) - random_density,
    jnp.squeeze(cql_q1) - cql_logp,
])
cql_concat_q2 = jnp.concatenate([
    jnp.squeeze(cql_random_q2) - random_density,
    jnp.squeeze(cql_q2) - cql_logp,
])
cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.config.cql_temp, dim=1) * self.config.cql_temp
cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.config.cql_temp, dim=1) * self.config.cql_temp
```
Shouldn't it be:
```python
cql_concat_q1 = jnp.concatenate([
    jnp.squeeze(cql_random_q1) / self.config.cql_temp - random_density,
    jnp.squeeze(cql_q1) / self.config.cql_temp - cql_logp,
])
cql_concat_q2 = jnp.concatenate([
    jnp.squeeze(cql_random_q2) / self.config.cql_temp - random_density,
    jnp.squeeze(cql_q2) / self.config.cql_temp - cql_logp,
])
cql_qf1_ood = torch.logsumexp(cql_cat_q1, dim=1)
cql_qf2_ood = torch.logsumexp(cql_cat_q2, dim=1)
```
You are right, this is indeed a mistake. Fortunately, we use `cql_temp = 1` for all environments, so the bug does not affect the algorithm in practice. I will fix it later.
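To see why the bug is benign at `cql_temp = 1`, here is a minimal numeric sketch comparing the two forms of the importance-corrected logsumexp. The Q-values and log-densities are synthetic random inputs (not from the repo), and `logsumexp` is implemented inline to keep it self-contained:

```python
import numpy as np

def logsumexp(x):
    # Numerically stable log(sum(exp(x))).
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

def ood_q_buggy(q, logp, temp):
    # As in the repo: the log-density correction is also divided by temp.
    return temp * logsumexp((q - logp) / temp)

def ood_q_proposed(q, logp, temp):
    # As in the issue: only the Q-values are tempered.
    return logsumexp(q / temp - logp)

rng = np.random.default_rng(0)
q = rng.normal(size=8)      # synthetic Q-values for 8 sampled actions
logp = rng.normal(size=8)   # synthetic log-densities of those actions

# At cql_temp = 1 both reduce to logsumexp(q - logp), which is why
# the bug does not change results in practice.
print(np.isclose(ood_q_buggy(q, logp, 1.0), ood_q_proposed(q, logp, 1.0)))  # True

# At any other temperature the two estimates diverge.
print(np.isclose(ood_q_buggy(q, logp, 2.0), ood_q_proposed(q, logp, 2.0)))  # False for generic inputs
```

With `temp = 1` the division and the trailing multiplication are both no-ops, so the misplaced density correction cancels out exactly.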