aviralkumar2907 / CQL

Code for conservative Q-learning

Potential mismatch between math and code for CQL(rho) #8

Open zhihanyang2022 opened 3 years ago

zhihanyang2022 commented 3 years ago

This is a question regarding how CQL(rho) works in terms of code 😊.

In the CQL section (starting from line 235) of /CQL/d4rl/rlkit/torch/sac/cql.py, the code first computes:

# Q-values of uniform-random actions, dataset actions, and actions sampled from
# the current policy at the next state and the current state, concatenated
# along dim 1 (the action-sample dimension)
cat_q1 = torch.cat(
    [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
)
cat_q2 = torch.cat(
    [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
)

and then uses them to compute

# temperature-scaled log-sum-exp over the concatenated action samples,
# averaged over the batch and rescaled by temp and the penalty weight
min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1).mean() * self.min_q_weight * self.temp
min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1).mean() * self.min_q_weight * self.temp
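
To check my own understanding, here is a minimal, self-contained sketch of the penalty I think these lines are building (made-up names and shapes, not the repo's exact code; if I read cql.py correctly, the Q-values of the dataset actions are additionally subtracted from this loss a few lines further down):

import torch

def cql_logsumexp_penalty(q_dataset, q_sampled, temp=1.0, min_q_weight=5.0):
    # q_dataset: (B, 1) Q-values of the (s, a) pairs in the batch
    # q_sampled: (B, M) Q-values of M actions per state drawn from the random /
    #            current-policy / next-state-policy proposals (the cat_q1 tensor)
    # temperature-scaled log-sum-exp over the action-sample dimension,
    # a sampled stand-in for log sum_a exp(Q(s, a))
    push_down = torch.logsumexp(q_sampled / temp, dim=1).mean() * temp
    # Q-values of dataset actions get pushed up by being subtracted from the loss
    push_up = q_dataset.mean()
    return min_q_weight * (push_down - push_up)

# toy usage: batch of 4 states, 30 sampled actions per state
q_dataset = torch.randn(4, 1)
q_sampled = torch.randn(4, 30)
print(cql_logsumexp_penalty(q_dataset, q_sampled))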

I'm a bit confused about this part. My main question: why can the Q-values of actions drawn from three distinct distributions (uniform random actions, actions from the current policy at the current state, and actions from the current policy at the next state), together with the Q-values of the dataset actions, all be fed into a single log-sum-exp to compute the CQL(rho) penalty from the paper?

I'm able to completely understand how CQL(H) works in the codebase, though.

loicsacre commented 3 years ago

I think they only gave the implementation of CQL(H). In their code base, the min_q_version is always set to 3, which corresponds to CQL(H). The equation with log-sum-exp is present in Appendix F (Additional Experimental Setup and Implementation Details).
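
If it helps, the trick in Appendix F is, as I read it, just importance sampling of the log-sum-exp term, which is why Q-values of actions from several proposal distributions can be mixed. A sketch, with rho denoting any proposal distribution that covers the action space (my notation, not the paper's):

\log \sum_{a} \exp\bigl(Q(s,a)\bigr)
  = \log \mathbb{E}_{a \sim \rho(\cdot \mid s)}\!\left[\frac{\exp\bigl(Q(s,a)\bigr)}{\rho(a \mid s)}\right]
  \approx \log \frac{1}{N} \sum_{i=1}^{N} \frac{\exp\bigl(Q(s,a_i)\bigr)}{\rho(a_i \mid s)},
  \qquad a_i \sim \rho(\cdot \mid s)

Dividing by the proposal density inside the exponential is the same as subtracting its log from the Q-value before the log-sum-exp, which (if I read the code correctly) is what the min_q_version = 3 branch does with the random-action and policy log-densities.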

qsa-fox commented 2 years ago

Equation 7 is missing a term, namely the KL-divergence; after adding that term you can derive the log-sum-exp.
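
A sketch of that step, as I understand it from the paper (uniform prior over a finite action set for simplicity; the notation is mine):

\max_{\mu}\;\mathbb{E}_{a \sim \mu(\cdot \mid s)}\bigl[Q(s,a)\bigr]
  - D_{\mathrm{KL}}\!\bigl(\mu(\cdot \mid s)\,\big\|\,\mathrm{Unif}(\mathcal{A})\bigr)
\quad\text{is maximized by}\quad
\mu^{*}(a \mid s) = \frac{\exp\bigl(Q(s,a)\bigr)}{\sum_{a'} \exp\bigl(Q(s,a')\bigr)}

Plugging \mu^{*} back in gives an optimal value of \log \sum_{a} \exp\bigl(Q(s,a)\bigr) - \log\lvert\mathcal{A}\rvert; dropping the constant \log\lvert\mathcal{A}\rvert leaves exactly the log-sum-exp term that CQL(H) uses and that the code estimates with sampled actions.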