aviralkumar2907 / CQL

Code for conservative Q-learning

SAC CQL: Potential mismatch between observations and actions fed to the Q network in CQL computation. #4

Closed by dosssman 3 years ago

dosssman commented 3 years ago

Greetings.

Thank you for your amazing work on Offline RL, as well as for open-sourcing the code.

This issue pertains to the computation of the lower-bounding (CQL) term in the SAC-based CQL implementation:

        ## add CQL
        random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1) # .cuda()
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
        q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
        q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)

Namely, at line 236, the actions new_curr_actions_tensor of the policy for the next states in the batch, next_obs, are computed by feeding next_obs to the policy.

https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/sac/cql.py#L236

When the corresponding Q values are computed, however, new_curr_actions_tensor is fed to the Q networks together with what appears to be the observations of the current time step, obs:

https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/sac/cql.py#L241

Shouldn't it be next_obs instead of obs on those two lines 241 and 242? Or is there a specific reason to use actions of the next states to compute the Q values for the current batch of observations (states)? (Sampling "incorrect" actions with respect to the current observations on purpose?)
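
In code, the change I am asking about would be the following (a hypothetical two-line edit, reusing the variable and helper names from the snippet above):

        q1_next_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf2)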

Thank you for your time, and sorry for the inconvenience.

olliejday commented 3 years ago

I think this is OK actually.

It is perhaps confusingly named "q1_next_actions", but it seems to correspond to the resampled actions from this part (right-hand side) of the estimate of the log-sum-exp term (from Appendix F in the paper):

[Image: the importance-sampled estimate of the log-sum-exp term from Appendix F of the CQL paper]
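
In text form, my transcription of that estimate (please double-check against Appendix F of the paper) is roughly:

log \sum_a exp(Q(s, a)) ≈ log ( 1/(2M) [ \sum_{a_i ~ Unif(a)} exp(Q(s, a_i)) / Unif(a) + \sum_{a_i ~ \pi_\phi(a|s)} exp(Q(s, a_i)) / \pi_\phi(a_i|s) ] )

where both sums use M action samples drawn for the same state s.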

So it should be current state with new actions.

dosssman commented 3 years ago

Greetings.

I would say that "the resampled actions from this part (right hand side)" you have mentioned instead correspond to the following:

curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
# skipped lines
q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)

As for "So it should be current state with new actions": that would correspond to the lines highlighted above, I think.

Also, from the equation in your comment, the log-sum-exp is computed over multiple actions {a_i}, i \in {1, ..., num_actions}, sampled for a specific state s. Therefore, to rigorously follow that same equation, if new_curr_actions_tensor is computed using next_obs, the log-sum-exp should also be taken with respect to those next_obs, I think.
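
In symbols (my reading, not a quote from the paper), the rigorously matching term would be

\log \sum_i exp( Q(next_obs, a_i) ) / \pi(a_i | next_obs), with a_i ~ \pi(. | next_obs),

whereas the code evaluates Q(obs, a_i) with those same a_i.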

Nevertheless, it would still "work", since the goal of the CQL objective is to minimize the Q values for known states but "out-of-distribution" actions. Namely, new_curr_actions_tensor would indeed be "out of distribution" with respect to the states obs it ends up paired with (having been sampled for next_obs).

aviralkumar2907 commented 3 years ago

Sorry for the late reply. It is mathematically correct, since it is just a third source of action samples for computing the logsumexp. In this code version, the log-sum-exp is computed using three terms:

  1. Actions from the current policy. This is, I guess, clear.
  2. Uniform actions.
  3. Actions from the policy at the next state. Note that we can still use these next-state actions with the current state, since they are just action samples given to us. If we know the probabilities under which these actions were sampled, which is \pi(next_actions|next_obs), then the importance-corrected Q-function term should be Q(curr_state, next_action) - \log \pi(next_action|next_obs). This is fine: we sampled the actions from the policy at the next state, but we only use these action samples to compute the log-sum-exp of the Q-function at the current state.
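
For reference, a sketch of how these three sets of samples enter the log-sum-exp, written as a standalone helper (my paraphrase of the rlkit code; the function name, argument names, and defaults are mine, not from the repo):

import numpy as np
import torch

def cql_logsumexp_term(q_rand, q_curr_actions, q_next_actions,
                       curr_log_pis, new_log_pis, action_dim, temp=1.0):
    # All tensors are assumed to have shape (batch_size, num_samples).
    # Importance weights: the log-density of Unif([-1, 1]^action_dim) for the random
    # actions, and the (detached) policy log-probabilities for the sampled actions.
    random_log_density = np.log(0.5 ** action_dim)
    cat_q = torch.cat([
        q_rand - random_log_density,             # term 2: uniform action samples
        q_next_actions - new_log_pis.detach(),   # term 3: next-state policy samples, still evaluated at obs
        q_curr_actions - curr_log_pis.detach(),  # term 1: current-state policy samples
    ], dim=1)
    # Soft maximum of the Q values over all sampled actions, per state in the batch.
    return temp * torch.logsumexp(cat_q / temp, dim=1)
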
dosssman commented 3 years ago

Thanks for the answer.