ofirnabati opened this issue 4 years ago
Hi,
From what I understand of your note, the behaviour you described was intentional, and your change explains both the lower performance and the slower convergence: you now take fewer gradient steps per collected data point (slower convergence), and, more importantly, you take no gradient steps at all in some states, which explains the much lower performance. With the default configuration, you now only take gradient steps in 1/5th of the states.
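To make the two schedules concrete, here is a small illustration (assuming the default multiplier_backprop_length = 5; j indexes training iterations as in the snippet quoted later in this thread):

import torch  # not needed for this illustration, just for context

# Illustration only: which iterations trigger a parameter update
multiplier_backprop_length = 5
for j in range(10):
    retain_graph = j % multiplier_backprop_length != 0
    # original code: optimizer.step() runs every iteration
    # modified code: optimizer.step() runs only when retain_graph is False,
    # i.e. at j = 0, 5, ... -> updates at 1/5th of the iterations
    print(j, 'step' if not retain_graph else 'accumulate only')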
If you wanted to eliminate the bias, you could set multiplier_backprop_length=1, but as discussed in the paper, the smaller backprop length would lead to worse results. To counter that you could increase num_steps, which would probably bring the results back up.
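Concretely, that would be something like the following (a sketch only; which config dict holds num_steps is an assumption, and the value 25 is illustrative, not tuned):

algorithm['multiplier_backprop_length'] = 1  # graph never retained across updates -> no bias
opt['num_steps'] = 25                        # longer rollouts to recover some backprop length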
Training RNNs in RL is tricky because ideally we would backpropagate all the way back to the beginning of the episode, which is usually infeasible. At the time DVRL was published, the state of the art was simply to backpropagate for a fixed number of steps (e.g. Recurrent DQN). The implementation in my code tried to improve on that, but, as you've noted, at the expense of introducing some bias in the gradients. Since then, a lot of other (probably much better) ideas have been proposed to overcome this problem of training RNNs in RL.
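For illustration, a minimal, self-contained sketch of that fixed-length truncation (generic PyTorch, not code from this repo; the network and loss are stand-ins):

import torch
import torch.nn as nn

rnn = nn.GRUCell(input_size=4, hidden_size=8)
head = nn.Linear(8, 2)
optimizer = torch.optim.Adam(
    list(rnn.parameters()) + list(head.parameters()), lr=1e-3)

h = torch.zeros(1, 8)
k = 5  # fixed truncation length: gradients never reach further back than k steps

for window in range(20):
    losses = []
    for t in range(k):
        obs = torch.randn(1, 4)               # stand-in for an environment observation
        h = rnn(obs, h)
        losses.append(head(h).pow(2).mean())  # stand-in for an RL loss
    optimizer.zero_grad()
    torch.stack(losses).sum().backward()
    optimizer.step()
    h = h.detach()  # cut the graph so the next window starts from a constant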
Did you change the gradient updates because it was crashing the program, or because you were worried it would lead to biased results? If it's crashing, that's probably because newer versions of PyTorch no longer support the original behaviour, at least not the way it was implemented.
Hi,
Thanks for the quick response!
Yes, you are right, convergence is 5 times slower with my modifications. But I do still accumulate the gradients and only then perform the gradient step (i.e. I don't call optimizer.step() after every iteration as you do).
I use a newer version of PyTorch, which crashes due to the modification of weights between updates; as you noted, this modification is what causes the "wrong" gradient calculation, i.e. the bias.
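For reference, the failure mode can be reproduced in a few lines (a sketch of the pattern, not code from the repo):

import torch

w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()            # pow saves w for its backward pass
loss.backward(retain_graph=True)
opt.step()                       # in-place update bumps w's version counter
loss.backward()                  # RuntimeError on recent PyTorch: a variable
                                 # needed for gradient computation has been
                                 # modified by an inplace operation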
So as I see it, I could use the old PyTorch version (0.4), which doesn't complain about this bias, but that doesn't seem right. Do you have any suggestions for making it work on a newer version of PyTorch without the bias? The naive solution, I guess, would be to run it 5 times longer...
Thanks
So you not only retain the graph, but you also don't call optimizer.zero_grad() at every update step?
If you don't want to use the current code, it's tricky, and I'm not sure how best to trade off the backpropagation length, sample efficiency, and the bias/variance trade-off of the n-step estimator. I know that for value-based agents the state of the art is something like R2D2, but I don't know what people currently do for actor-critic methods.
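For reference, the core R2D2 trick is to store the recurrent state alongside each replayed sequence and "burn in" the first few steps without gradients; a rough sketch (heavily simplified: single sequence, random data, no prioritization):

import torch
import torch.nn as nn

rnn = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
q_head = nn.Linear(8, 2)

obs_seq = torch.randn(1, 20, 4)  # a replayed sequence: (batch, time, features)
h0 = torch.zeros(1, 1, 8)        # recurrent state saved when the sequence was collected

burn_in = 5
with torch.no_grad():            # burn-in: refresh the stale stored state
    _, h = rnn(obs_seq[:, :burn_in], h0)

out, _ = rnn(obs_seq[:, burn_in:], h)  # gradients flow only through this unroll
q_values = q_head(out)                 # Q-values for the training portion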
Yes. zero_grad() is called only when retain_graph is False.
# Only reset the computation graph every 'multiplier_backprop_length' iterations
retain_graph = j % algorithm['multiplier_backprop_length'] != 0
total_loss.backward(retain_graph=retain_graph)

# Step and zero the gradients only at window boundaries, so the weights are
# never modified while a retained graph is still alive
if not retain_graph:
    if opt['max_grad_norm'] > 0:
        nn.utils.clip_grad_norm_(actor_critic.parameters(), opt['max_grad_norm'])
    optimizer.step()
    optimizer.zero_grad()
    # Detach so the next window's BPTT stops here
    current_memory['states'] = current_memory['states'].detach()
Do you have any idea how I can do this on a newer version of PyTorch, where weight modifications are not allowed during BPTT?
Unfortunately not, sorry, I haven't worked with the codebase in a while. If I find something I'll let you know, but not sure when I'll have time to look into it.
Sure. Thanks.
Hi,
I'm trying to run your code in DVRL mode (with the configuration you mention in the README file), and the results are significantly lower than the ones published in your paper (convergence is also much slower).
Important note: I needed to modify your code because of a bug: you perform a gradient step in the middle of BPTT. This alters the weights while still backpropagating through the last iterations, which raises an error. To fix it, I perform the gradient step (optimizer.step()) only when retain_graph=False.
If this is the cause of the low performance, what should I do?
Thanks