I think this is OK actually. It is perhaps confusingly named q1_next_actions, but it seems to be the resampled actions from this part (the right-hand side) of the estimation of the log-sum-exp term (from Appendix F in the paper):
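For reference, the Appendix F estimate in question is, as far as I recall (exact constants and notation may differ slightly from the paper), an importance-sampled approximation of the log-sum-exp:

$$
\log \sum_a \exp Q(s, a)
\;\approx\;
\log \left(
\frac{1}{2N} \sum_{a_i \sim \mathrm{Unif}(a)}^{N} \frac{\exp Q(s, a_i)}{\mathrm{Unif}(a)}
\;+\;
\frac{1}{2N} \sum_{a_i \sim \pi_\phi(a \mid s)}^{N} \frac{\exp Q(s, a_i)}{\pi_\phi(a_i \mid s)}
\right)
$$

where the second (right-hand) sum uses actions resampled from the current policy \pi_\phi, all evaluated at the same state s.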
So it should be current state with new actions.
Greetings.
I would say that "the resampled actions from this part (right hand side)" you have mentioned instead correspond to the following:
curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
# skipped lines
q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
As for "So it should be current state with new actions": that would correspond to the lines highlighted above, I think.
Also, from the equation in your comment, the log-sum-exp is computed over multiple actions {a_i}, i \in {1, ..., num_actions}, sampled for a specific state s.
Therefore, if we were to rigorously follow that same equation and compute new_curr_actions_tensor using next_obs, the log-sum-exp should also be taken with respect to those next_obs, I think.
Nevertheless, it would still "work", since the goal of the CQL objective is to minimize the Q values for known states but "out of distribution" actions. Namely, new_curr_actions_tensor would indeed be "out of distribution" with respect to the states obs it is evaluated with.
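To make that concrete, rigorously following the equation would mean evaluating those actions against the states they were sampled for, i.e. something along these lines (a hypothetical change for illustration, not the actual repository code; q2_next_actions is assumed by symmetry with qf1):

q1_next_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf1)
q2_next_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf2)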
Sorry for the late reply. It is mathematically correct, since it is just a third term for passing action samples when computing the logsumexp. In this code version, the log-sum-exp is computed using three terms:
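Namely, Q values of (obs, uniform random actions), (obs, policy actions sampled at obs), and (obs, policy actions sampled at next_obs), each importance-weighted by its sampling density before the logsumexp. A rough sketch of how these three terms are combined (reconstructed from memory rather than quoted verbatim; names such as q1_rand, random_density and new_log_pis are my shorthand, and the actual code also applies a temperature and a min-Q weighting factor if I recall correctly):

# uniform-action term, policy-at-next_obs term, policy-at-obs term,
# each corrected by the log of its sampling density (importance weighting)
cat_q1 = torch.cat(
    [q1_rand - random_density,
     q1_next_actions - new_log_pis.detach(),
     q1_curr_actions - curr_log_pis.detach()], 1)
min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean()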
Thanks for the answer.
Greetings.
Thank you for your amazing work on Offline RL, as well as for open-sourcing the code.
This issue pertains to the computation of the lower-bounding component of the SAC-based CQL:
Namely, at line 236, the actions new_curr_actions_tensor of the policy for the next states in the batch, next_obs, are computed by feeding the latter to the policy:
https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/sac/cql.py#L236
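For readers without the repository open, that line reads roughly as follows (quoted from memory, so details such as the new_log_pis name may differ):

new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)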
When computing the corresponding Q values, however, the new_curr_actions_tensor are fed to the Q networks together with what seem to be the observations at the current time step, obs:
https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/sac/cql.py#L241
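Roughly (again from memory, so the exact form may differ slightly):

q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)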
Shouldn't it be next_obs instead of obs at those two lines 241 and 242? Or is there a specific reason we might want to use actions of the next states to compute the Q values for the current observations batch (states)? (Sampling "incorrect" actions with regard to the current observations (states) on purpose?)
Thank you for your time, and sorry for the inconvenience.