aviralkumar2907 / CQL

Code for conservative Q-learning

QF_Loss backprops policy network #5

Open 00schen opened 3 years ago

00schen commented 3 years ago

In the CQL trainer, the policy_loss is formulated before the QF_Loss is, but the QF_Loss backprops through the policy network before policy_loss does, which causes a Torch error. Is the intended use to optimize the policy network on the policy_loss before formulating the QF_Loss (while still optimizing the policy through the QF_Loss), or to not reparameterize the policy output when formulating the QF_Loss (e.g., line 201)?
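
For concreteness, the current ordering is roughly the following (paraphrased sketch, not the exact trainer code):

# Paraphrased sketch of the current ordering (not the exact trainer code):
policy_loss = (alpha * log_pi - q_new_actions).mean()  # 1) built first

# 2) the QF_Loss terms are then built, backpropagated, and the Q optimizers
#    stepped; these terms also depend on reparameterized policy outputs

# 3) only afterwards is the policy itself optimized:
self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=False)
self.policy_optimizer.step()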

olliejday commented 3 years ago

Is this the error you are talking about? I have been trying to debug this too; I can add the full output if helpful.

/home/.../torch/autograd/__init__.py:132: UserWarning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 1]], which is output 0 of TBackward, is at version 40001; expected version 40000 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
python-BaseException
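
For reference, this kind of message is the generic symptom of calling backward() through a graph whose parameters were already updated in-place by an optimizer step. A minimal standalone sketch of the mechanism (not the CQL code):

# Minimal standalone reproduction of the in-place/version-counter error
# (illustration of the mechanism only, unrelated to the CQL trainer itself).
import torch
import torch.nn as nn

net = nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

x = torch.randn(8, 4)
out = net(x)                        # forward pass through the current weights

loss_a = out.mean()
loss_a.backward(retain_graph=True)  # populate gradients
opt.step()                          # in-place weight update (version bump)

loss_b = out.sum()                  # still tied to the pre-step weights
loss_b.backward()                   # RuntimeError: ... modified by an inplace operation

On older torch versions (< 1.5) the same pattern may run without complaint, which would match the version-dependence discussed further down this thread.
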
dosssman commented 3 years ago

In the CQL trainer, the policy_loss is formulated before the QF_Loss is, but the QF_Loss backprops the policy network before policy_loss does, which causes a Torch error.

I don't think the QF_Loss backprops through the policy network, because they use different optimizers for the policy and the Q networks, respectively. It is also hard to tell what is really causing that error without the stack trace. Furthermore, what arguments are you using for training? (The script has a lot of parameterization: max-backup, n-qs, min-q-version, with_lagrange, etc.; depending on the combination you use, there might be an unforeseen computation that causes the error.)

In any case, have you tried to move the:

self._num_policy_update_steps += 1
self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=False)
self.policy_optimizer.step()

just after the computation of the policy loss? This way, the policy network would be optimized early on, and the subsequent operations involving the policy should not cause further errors when used to compute the Q losses. (The problem might also be somewhere around the automatic entropy tuning, which also uses log_pi.)
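
Concretely, the reordering would look roughly like this (sketch only, reusing the variable names quoted in this thread; not a tested patch):

# Sketch of the suggested reordering (illustration only, not a tested patch):
# build the policy loss from the reparameterized actions ...
new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
    obs, reparameterize=True, return_log_prob=True,
)
q_new_actions = torch.min(
    self.qf1(obs, new_obs_actions),
    self.qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()

# ... and step the policy optimizer right away, before any QF_Loss is built
self._num_policy_update_steps += 1
self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=False)
self.policy_optimizer.step()

# the Q-function losses are computed and optimized afterwards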

olliejday commented 3 years ago

I'm currently testing this (small change) PR. It blocks the gradient flow to the Q functions in the policy update, which prevents the error.

dosssman commented 3 years ago

I am afraid that change will break the learning of the policy itself, because the q_new_actions.detach() in

policy_loss = (alpha*log_pi - q_new_actions.detach()).mean()

will also block the gradient flow to the policy, since q_new_actions is computed as below:

if self.num_qs == 1:
    q_new_actions = self.qf1(obs, new_obs_actions)
else:
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )

and new_obs_actions is sampled using the re-parameterization trick (.rsample) inside that self.policy() call:

new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
     obs, reparameterize=True, return_log_prob=True,
)

Therefore, the policy weights would only be updated to minimize alpha*log_pi, not to actually maximize the action value q_new_actions.

(As a personal anecdote, I made the exact same mistake when implementing SAC a while ago. It is critical not to detach the Q values when updating the policy. I think that is also the main reason the optimizers are separate for the policy and the Q networks: so that the policy's action value can be backpropagated through the Q functions without altering the weights of the latter.)
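
In other words, the standard SAC-style recipe keeps the Q networks in the policy-loss graph and relies on the separate optimizers to leave their weights untouched. A compact way to see the contrast (generic sketch, not the repo's exact code):

# Generic sketch of the contrast (not the repo's exact code):

# broken: the policy only sees the entropy term, the Q signal is cut off
policy_loss = (alpha * log_pi - q_new_actions.detach()).mean()

# correct: gradients flow through qf1/qf2 into the policy; only
# self.policy_optimizer.step() is called on this loss, so the Q-network
# weights themselves are never modified by the policy update
policy_loss = (alpha * log_pi - q_new_actions).mean()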

aviralkumar2907 commented 3 years ago

@olliejday I think the error is caused by the PyTorch version. If you try something like torch 1.4, that could fix it; anything more recent might break it. Could you please confirm whether or not this is the issue?

aviralkumar2907 commented 3 years ago

@olliejday @dosssman The Q-function detach will not work, since then the policy is not trained using the Q-function which is incorrect.

olliejday commented 3 years ago

I'm looking at what @dosssman says but in reverse (i.e., moving the Q-function evaluation rather than the policy update).

So moving

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha*log_pi - q_new_actions).mean()

to just before the policy update but after the Q updates, i.e., directly above here:

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

This stops the error and seems to match the order in the paper:

[screenshot of the update order from the CQL paper]

I'm trying it now; otherwise I will test which torch versions work.
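
Putting the pieces together, the reordered tail of the update would look roughly like the sketch below. (The qf1_loss / qf1_optimizer names are assumed from the rlkit-style trainer and may not match the code exactly; the retain_graph flags probably need revisiting once the order changes.)

# Sketch of the proposed ordering (illustration only; names such as qf1_loss
# and qf1_optimizer are assumed, and retain_graph flags may need adjusting).

# 1) Q-function updates first
self.qf1_optimizer.zero_grad()
qf1_loss.backward(retain_graph=True)
self.qf1_optimizer.step()

self.qf2_optimizer.zero_grad()
qf2_loss.backward(retain_graph=True)
self.qf2_optimizer.step()

# 2) only then evaluate the freshly updated Q networks on the policy's actions
if self.num_qs == 1:
    q_new_actions = self.qf1(obs, new_obs_actions)
else:
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
policy_loss = (alpha * log_pi - q_new_actions).mean()

# 3) policy update last
self._num_policy_update_steps += 1
self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=False)
self.policy_optimizer.step()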

Thanks

dssrgu commented 3 years ago

Hello,

Any updates on this issue? I tried using @olliejday's solution, but the results are different in some environments. For example, on hopper-expert-v0 with policy_lr=1e-4, min_q_weight=5.0, and lagrange_thresh=-1.0, the average return results are:

Unmodified code w/ torch=1.4: 3638.71
Modified code w/ torch=1.7: 3.08

olliejday commented 3 years ago

Hi, I ended up just reverting the torch version to 1.4.

dssrgu commented 3 years ago

Adding .detach() to the outputs of _get_policy_actions() and switching the update order of the policy network and the q-function networks seem to solve the issue (Tested in torch=1.7).
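
For anyone applying the same change, it amounts to something like the sketch below. (The call-site variable names and the num_actions keyword here are placeholders; the exact signature of _get_policy_actions and where the PR applies the detach may differ.)

# Hedged sketch of the fix; variable names and the num_actions keyword are
# placeholders, and the actual PR may apply the detach inside the helper.
curr_actions, curr_log_pis = self._get_policy_actions(
    obs, num_actions=num_sampled_actions, network=self.policy)
curr_actions, curr_log_pis = curr_actions.detach(), curr_log_pis.detach()

new_actions, new_log_pis = self._get_policy_actions(
    next_obs, num_actions=num_sampled_actions, network=self.policy)
new_actions, new_log_pis = new_actions.detach(), new_log_pis.detach()

# With these detached, the Q-function losses built from the sampled actions no
# longer backprop into the policy, so the Q updates can run first and the
# policy update last.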

sweetice commented 3 years ago

@dssrgu Thanks for your contribution! I have adapted your commits, but I have some questions. Does your modification work? Have you compared it with the original version (PyTorch 1.4, which works)? If so, which one is better?

dssrgu commented 3 years ago

@sweetice Hi, I did test the modified version against the original version (which was run on torch==1.4), and the two versions had similar performance on the d4rl datasets. I do not have the actual values right now, though.

Note: you may additionally have to correct the retain_graph parameters on the backward steps to match the changed update order.
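
To spell out the rule behind that: every backward() except the last one through a shared portion of the graph needs retain_graph=True, and once a loss no longer shares the graph (e.g., because its inputs were detached), the flag can be dropped. A generic standalone illustration, not lines from this repo:

# Generic illustration of the retain_graph rule (standalone, not repo code).
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
x = torch.randn(8, 4)
h = net(x)                         # shared intermediate used by both losses

loss1 = h.pow(2).mean()
loss2 = h.abs().mean()

loss1.backward(retain_graph=True)  # the graph is still needed by loss2
loss2.backward()                   # last backward through it: no flag needed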

Zhendong-Wang commented 3 years ago

@dssrgu Did you get results similar to the values reported in the D4RL paper?

I tried both the paper hyperparameters (policy_lr=3e-5, lagrange_thresh=10.0) and the ones recommended in this GitHub repo (policy_lr=1e-4, lagrange_thresh=-1.0), in PyTorch 1.4 and 1.7+, but I cannot obtain similar values in some environments. For example, there is a big difference on 'halfcheetah-medium-expert-v0', and a huge difference on the Adroit tasks, like 'pen-human', 'hammer-human', and 'door-human'.

Do you know how to set the hyperparameters to make CQL work in most cases? Thanks!

dssrgu commented 3 years ago

@Zhendong-Wang I found policy_lr=1e-4, min_q_weight=10.0, lagrange_thresh=-1.0 to work fairly well on most of the gym environments, though I used the '*-v2' datasets. The exception is 'halfcheetah-random-v2', where policy_lr=1e-4, min_q_weight=1.0, lagrange_thresh=10.0 works well. If the problem is only with the medium-expert datasets, it seems the algorithm needs to run for 3000 epochs to converge.

For Adroit task, I also could not reproduce the results...

glorgao commented 3 years ago

@dssrgu Could you give me some advice?

I used the hyperparameters you recommended, and the results on the 'medium' envs are in line with the CQL paper. However, the result on the 'walker2d-expert-v2' task stops improving after reaching 5000, while the paper reports about 7000. The return seems capped at 5000, since the curve stays flat after reaching that value.

I believe there must be something wrong in my settings, which are: mujoco200, pytorch=1.4 or 1.1, d4rl=1.1, for the walker2d-expert-v2 and walker2d-expert-v0 tasks.

Do you have any suggestions for me?

Zhendong-Wang commented 3 years ago

@cangcn Actually, with the GitHub code and the hyperparameters recommended in the Readme file, I cannot reproduce the results reported in the D4RL paper, even on the Gym tasks. I tried both 'v2' and 'v0'; the performance on 'v2' is generally better than on 'v0', but it still cannot match most of the reported results, even though D4RL mentions they used 'v0' for a fair comparison.

jihwan-jeong commented 3 years ago

Hi.

@olliejday: I think this issue shouldn't be resolved by just switching back to torch versions below 1.5 (i.e., <=1.4), because then reproducibility relies on a bug in the torch code (see this thread). According to the linked discussion, in torch < 1.5, even when the code runs and trains the network parameters, the computed gradients can be incorrect; this is fixed in torch >= 1.5.

Hopefully the PR that @dssrgu posted can solve this issue, but for some tasks it seems the results cannot be reproduced. I hope the original author @aviralkumar2907 can provide some feedback on this matter :) In the meantime, I think I'll use @dssrgu's modifications to make the code runnable.

Thanks!