young-geng / CQL

Conservative Q Learning on top of SAC
MIT License

Question about the importance sampling #1

Closed: dragon-wang closed this issue 2 years ago

dragon-wang commented 3 years ago

In Appendix F of the CQL paper, the log-sum-exp of Q(s,a) is computed with importance sampling, using actions sampled only from Unif(a) and pi(a|s). Why does this implementation also sample actions from pi(a'|s')? This confuses me.

# Log importance weights: Q(s, a) minus the log density of the proposal each action came from
# (uniform random actions, actions sampled from the policy at the next state, and at the current state).
cql_cat_q1 = torch.cat(
    [cql_q1_rand - random_density,
     cql_q1_next_actions - cql_next_log_pis.detach(),
     cql_q1_current_actions - cql_current_log_pis.detach()],
    dim=1,
)
cql_cat_q2 = torch.cat(
    [cql_q2_rand - random_density,
     cql_q2_next_actions - cql_next_log_pis.detach(),
     cql_q2_current_actions - cql_current_log_pis.detach()],
    dim=1,
)
# Temperature-scaled logsumexp over all sampled actions, averaged over the batch.
cql_min_qf1_loss = (
    torch.logsumexp(cql_cat_q1 / self.config.cql_temp, dim=1).mean()
    * self.config.cql_min_q_weight
    * self.config.cql_temp
)
cql_min_qf2_loss = (
    torch.logsumexp(cql_cat_q2 / self.config.cql_temp, dim=1).mean()
    * self.config.cql_min_q_weight
    * self.config.cql_temp
)

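
A minimal stdlib sketch of what the concatenation computes, taking cql_temp = 1 and ignoring cql_min_q_weight. The Q-values and log-probabilities are hypothetical random stand-ins, and the action range [-1, 1]^action_dim is an assumption (typical of a tanh-squashed policy, for which the uniform log density is action_dim * log(0.5)). It checks that the logsumexp over the concatenated terms equals the log of the average of three separate importance-sampling estimates, one per proposal, plus the constant log(3n):

```python
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def is_estimate(log_weights):
    """Importance-sampling estimate of Z = integral of exp(Q): mean of exp(Q - log mu)."""
    return sum(math.exp(w) for w in log_weights) / len(log_weights)


random.seed(0)
action_dim = 2  # hypothetical action dimension
n = 4           # hypothetical number of sampled actions per proposal

# Log density of Unif([-1, 1]^action_dim) (assumed action range): 0.5 per coordinate.
random_density = action_dim * math.log(0.5)

# Hypothetical stand-ins for Q-values and policy log-probabilities of the sampled actions.
q_rand = [random.uniform(-1, 1) for _ in range(n)]
q_next = [random.uniform(-1, 1) for _ in range(n)]
q_curr = [random.uniform(-1, 1) for _ in range(n)]
next_log_pis = [random.uniform(-2, 0) for _ in range(n)]
curr_log_pis = [random.uniform(-2, 0) for _ in range(n)]

# Same layout as cql_cat_q1: Q minus the log density of the proposal, for each proposal.
cat = ([q - random_density for q in q_rand]
       + [q - lp for q, lp in zip(q_next, next_log_pis)]
       + [q - lp for q, lp in zip(q_curr, curr_log_pis)])

# What the loss uses (with cql_temp = 1): logsumexp over all 3n samples at once.
concat_lse = logsumexp(cat)

# The same number, written as the log of the average of three separate IS estimates.
z_rand = is_estimate(cat[:n])
z_next = is_estimate(cat[n:2 * n])
z_curr = is_estimate(cat[2 * n:])
averaged = math.log((z_rand + z_next + z_curr) / 3) + math.log(3 * n)

print(concat_lse, averaged)
```

Since log(3n) does not depend on the network parameters, leaving it out does not change the gradient; each proposal simply contributes its own importance-weighted estimate of the same integral.
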
Jiukaishi commented 3 years ago

That confuses me too. The paper author's own CQL implementation does the same thing. Can anyone explain this? Thanks!

glorgao commented 2 years ago

I think I can offer some insight into this question.

CQL adds a conservative term designed to push down the Q-values of all valid actions. Given this motivation, the natural distribution for sampling the actions whose values are minimized is uniform. However, uniform sampling can be inefficient: many uniformly drawn actions may land where Q is low and contribute little to the log-sum-exp.

\pi(a | s) and \pi(a' | s') tend to produce actions with high Q-values (either genuinely high values, or overestimated values on out-of-distribution (OOD) actions). These are therefore the actions most worth checking first.
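
The efficiency point can be illustrated with a toy 1-D experiment. Everything here is hypothetical: a peaked Q(a) = -k(a - c)^2 on the action range [-1, 1], and a Gaussian centered at the peak standing in for the policy. Both estimators of Z = integral of exp(Q(a)) over [-1, 1] are unbiased, but the proposal that concentrates where exp(Q) is large should give a much lower variance:

```python
import math
import random

k, c = 50.0, 0.3    # hypothetical Q(a) = -k (a - c)^2, peaked at a = c
lo, hi = -1.0, 1.0  # assumed action bounds (e.g. a tanh-squashed policy)
mu, sd = 0.3, 0.15  # hypothetical Gaussian "policy" proposal centered at the peak


def integrand(a):
    """exp(Q(a)) restricted to the valid action range."""
    return math.exp(-k * (a - c) ** 2) if lo <= a <= hi else 0.0


def gauss_pdf(a):
    return math.exp(-0.5 * ((a - mu) / sd) ** 2) / (sd * math.sqrt(2 * math.pi))


# Exact Z = integral of exp(Q(a)) over [lo, hi], via the error function.
z_true = 0.5 * math.sqrt(math.pi / k) * (
    math.erf(math.sqrt(k) * (hi - c)) - math.erf(math.sqrt(k) * (lo - c)))


def estimate_uniform(n):
    # Uniform proposal: density 1 / (hi - lo), so each weight is exp(Q) * (hi - lo).
    xs = [random.uniform(lo, hi) for _ in range(n)]
    return sum(integrand(a) * (hi - lo) for a in xs) / n


def estimate_policy(n):
    # Gaussian proposal: weight exp(Q) / q(a); samples outside [lo, hi] contribute 0.
    xs = [random.gauss(mu, sd) for _ in range(n)]
    return sum(integrand(a) / gauss_pdf(a) for a in xs) / n


def mean_var(samples):
    m = sum(samples) / len(samples)
    return m, sum((s - m) ** 2 for s in samples) / (len(samples) - 1)


random.seed(0)
trials, n = 2000, 10
m_u, v_u = mean_var([estimate_uniform(n) for _ in range(trials)])
m_p, v_p = mean_var([estimate_policy(n) for _ in range(trials)])
print(f"Z={z_true:.4f} uniform: mean={m_u:.4f} var={v_u:.2e} "
      f"policy: mean={m_p:.4f} var={v_p:.2e}")
```

Uniform sampling still covers every action, which the conservative term needs; the policy-based proposals mainly reduce variance where exp(Q) is large. This is only a sketch of the variance argument, not a measurement of the CQL code.
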

I do not think the sampling strategy is critical: I think pi(a|s) alone should work, and so should pi(a'|s'). However, I have no evidence to support this conjecture.

fuyw commented 2 years ago

Yes, the sampling strategy is not important. After all, the goal is only to approximate the logsumexp with samples, and \pi(a | s) and \pi(a' | s') are simply two distributions we have at hand. Using only uniform and \pi(a | s) is also fine.
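
This can be made precise with standard importance sampling. A sketch with K proposals \mu_k and N samples from each (notation is ours, not the paper's); each proposal must assign nonzero density to every action where exp(Q) is nonzero:

```latex
Z(s) = \int \exp\big(Q(s,a)\big)\,da,
\qquad
\hat{Z}_k = \frac{1}{N}\sum_{i=1}^{N}
  \frac{\exp\big(Q(s,a_i^{(k)})\big)}{\mu_k\big(a_i^{(k)}\big)},
\quad a_i^{(k)} \sim \mu_k,
\qquad
\mathbb{E}\big[\hat{Z}_k\big] = Z(s)

\hat{Z} = \frac{1}{K}\sum_{k=1}^{K}\hat{Z}_k
\;\Longrightarrow\;
\log \hat{Z}
  = \operatorname{logsumexp}_{k,i}\Big(Q(s,a_i^{(k)}) - \log \mu_k\big(a_i^{(k)}\big)\Big)
  - \log(KN)
```

With K = 3 and \mu_1 = Unif(a), \mu_2 = \pi(\cdot | s'), \mu_3 = \pi(\cdot | s), the logsumexp term matches the code in the question for cql_temp = 1, assuming (as the variable names suggest) that all three Q tensors are evaluated at the current state s. The constant \log(KN) does not affect gradients. Dropping \mu_2 leaves a valid estimator with K = 2; the choice of proposals changes the variance, not the quantity being estimated.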