openai / spinningup

An educational resource to help anyone learn deep reinforcement learning.
https://spinningup.openai.com/
MIT License

serious BUG in sac pytorch implementation #329

Open zlw21gxy opened 3 years ago

zlw21gxy commented 3 years ago

The PyTorch implementation of the SAC algorithm has a serious bug:

    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

itertools.chain returns a one-shot iterator that is exhausted after the first full pass, so every later loop over q_params sees an empty iterator.

    q_optimizer = Adam(q_params, lr=lr)
    print([x for x in q_params])  # prints []: Adam already consumed the iterator
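The exhaustion is easy to reproduce without PyTorch. In this minimal sketch, two plain lists stand in for ac.q1.parameters() and ac.q2.parameters() (the names q1_params and q2_params are hypothetical), and the first list() call plays the role of the Adam constructor consuming the iterator:

```python
import itertools

# Hypothetical stand-ins for the parameters of the two Q-networks.
q1_params = ["q1.weight", "q1.bias"]
q2_params = ["q2.weight", "q2.bias"]

q_params = itertools.chain(q1_params, q2_params)

# First pass consumes the chain, as Adam(q_params, lr=lr) does.
first_pass = list(q_params)

# Second pass: the chain is already exhausted, so nothing is left.
second_pass = list(q_params)

print(first_pass)   # ['q1.weight', 'q1.bias', 'q2.weight', 'q2.bias']
print(second_pass)  # []
```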

In fact, once q_optimizer has been initialized, q_params is empty, yet we iterate over it several times to freeze and unfreeze the Q-network parameters:

    for p in q_params:
        p.requires_grad = False  # never runs: q_params is already exhausted

A quick fix is to use a list instead:

    q_params = [x for x in ac.q1.parameters()] + [x for x in ac.q2.parameters()]

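A list, unlike a chain, survives repeated iteration. The sketch below contrasts the two, assuming a hypothetical FakeParam class that only tracks requires_grad in place of real torch.nn.Parameter objects, and a plain list() call in place of the Adam constructor:

```python
import itertools


class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter: only tracks requires_grad."""

    def __init__(self):
        self.requires_grad = True


def run(build_q_params):
    """Build q_params, consume it once, then try to freeze every Q parameter."""
    q1 = [FakeParam(), FakeParam()]
    q2 = [FakeParam(), FakeParam()]
    q_params = build_q_params(q1, q2)
    list(q_params)  # stands in for Adam(q_params, lr=lr), which consumes it
    for p in q_params:  # the freeze loop from the SAC update
        p.requires_grad = False
    return [p.requires_grad for p in q1 + q2]


buggy = run(lambda a, b: itertools.chain(a, b))
fixed = run(lambda a, b: [x for x in a] + [x for x in b])

print(buggy)  # [True, True, True, True]: nothing was frozen
print(fixed)  # [False, False, False, False]
```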
twni2016 commented 3 years ago

Yes, I also found this bug. However, it only affects the speed of backprop; it does not change the policy gradient.
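One way to check this claim: the policy gradient flows through the Q-network as a function whether or not its parameters require grad, so freezing only decides whether the Q parameters' own gradients are computed as extra work. Assuming the next Q update starts with q_optimizer.zero_grad(), those extra gradients are discarded anyway. The sketch below uses toy one-layer policy and q_net modules and a made-up loss_pi (all assumptions, not Spinning Up's networks):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a one-layer policy and a one-layer Q-network.
policy = nn.Linear(3, 1)
q_net = nn.Linear(4, 1)
obs = torch.randn(8, 3)


def policy_grads(freeze_q):
    """Return (policy weight grad, Q weight grad) after one policy-loss backward."""
    policy.zero_grad(set_to_none=True)
    q_net.zero_grad(set_to_none=True)
    for p in q_net.parameters():
        p.requires_grad = not freeze_q
    act = policy(obs)
    loss_pi = -q_net(torch.cat([obs, act], dim=1)).mean()
    loss_pi.backward()
    for p in q_net.parameters():  # unfreeze again, as the SAC update does
        p.requires_grad = True
    return policy.weight.grad.clone(), q_net.weight.grad


pi_grad_frozen, q_grad_frozen = policy_grads(freeze_q=True)
pi_grad_unfrozen, q_grad_unfrozen = policy_grads(freeze_q=False)

# Same policy gradient either way; only the extra Q-gradient work differs.
print(torch.allclose(pi_grad_frozen, pi_grad_unfrozen))  # True
print(q_grad_frozen is None)  # True: frozen Q params get no gradient
print(q_grad_unfrozen is not None)  # True: wasted work in the buggy case
```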

Roadsong commented 3 years ago

@twni2016 Hi twni2016, you mentioned that this will only affect the speed of backprop. Do you have any idea how to resolve the slow performance? I also found that the SAC implementation is extremely slow...

twni2016 commented 3 years ago

@Roadsong I don't think Spinning Up is very slow (maybe you have other issues). You can just follow @zlw21gxy's approach to fix this issue.

yf291115925 commented 2 years ago

Another quick fix is to wrap the chain in list():

    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

Alberto-Hache commented 2 years ago

This seems to be the same issue solved by two pending PRs:

For TD3: "Fix problem with empty iterator" #330
For SAC: "Fix Q-networks freezing in PyTorch SAC" #251

The fix in these PRs is:

    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

It also seems to be addressed, with a lambda function, in another pending PR, "Fixes sac critic grad freeze bug" #320:

    q_params = lambda: itertools.chain(
        *[gen() for gen in [ac.q1.parameters, ac.q2.parameters]]
    )
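With the lambda version, q_params is a function, so every use site must call it (for example Adam(q_params(), lr=lr) and for p in q_params(): ...) to get a fresh chain. A torch-free sketch of that behavior, with two hypothetical generator functions standing in for ac.q1.parameters and ac.q2.parameters:

```python
import itertools


# Hypothetical stand-ins for ac.q1.parameters and ac.q2.parameters:
# each call returns a fresh generator, like nn.Module.parameters().
def q1_parameters():
    yield from ["q1.weight", "q1.bias"]


def q2_parameters():
    yield from ["q2.weight", "q2.bias"]


# Same shape as the lambda in PR #320: a new chain is built on every call.
q_params = lambda: itertools.chain(
    *[gen() for gen in [q1_parameters, q2_parameters]]
)

first = list(q_params())   # e.g. consumed by Adam(q_params(), lr=lr)
second = list(q_params())  # a later freeze/unfreeze loop still sees everything

print(first == second)  # True
print(second)  # ['q1.weight', 'q1.bias', 'q2.weight', 'q2.bias']
```

The trade-off: the lambda avoids storing a list, but every existing `for p in q_params:` must become `for p in q_params():`, whereas the list fix leaves those loops unchanged.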

(I must say that I have NOT tried them though, as I'm now working with PPO.)