tqjxlm / Simple-DQN-Pytorch

A simplistic implementation of DQN that works under CartPole-v0 with rendered pixels as input

Dueling - Average #1

Open msiegenthaler opened 5 years ago

msiegenthaler commented 5 years ago

Hi. This is not really an issue, but since I'm doing pretty much the same thing as you did in this repo while studying reinforcement learning, I found something that I was not sure about at first but then decided is not correct:

In your dueling implementation you use return val + adv - val.mean(). I think that by doing so you're subtracting the average of the value across all samples in the batch instead of subtracting the mean advantage per sample. I did the same and the code still works and trains, but I think the subtracted term should be torch.mean(adv, 1, keepdim=True) (or equivalently adv.mean(1, keepdim=True)). My network trains a bit better with the new approach, although it does not make that much of a difference.
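A minimal sketch of the per-sample version, for illustration only. The class name DuelingHead, the layer sizes, and the single linear layer per stream are assumptions and are not taken from the repo; only the final aggregation line reflects the change proposed above.

```python
import torch
import torch.nn as nn


class DuelingHead(nn.Module):
    """Hypothetical dueling head; names and layer sizes are illustrative."""

    def __init__(self, in_features: int, n_actions: int):
        super().__init__()
        # Value stream: one scalar V(s) per sample.
        self.value = nn.Linear(in_features, 1)
        # Advantage stream: one A(s, a) per action.
        self.advantage = nn.Linear(in_features, n_actions)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        val = self.value(features)      # shape (batch, 1)
        adv = self.advantage(features)  # shape (batch, n_actions)
        # dim=1 is the action axis, so each sample subtracts its own
        # mean advantage; keepdim=True keeps shape (batch, 1) so the
        # result broadcasts against adv.
        return val + adv - adv.mean(dim=1, keepdim=True)
```

With this aggregation, averaging the returned Q-values over the action axis gives back val for each sample, which is a quick way to check the head. By contrast, val.mean() with no dim argument reduces over the whole batch to a single scalar.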

I'd like to thank you a lot for putting up this repo, including the detailed training analysis. It helps me improve my own implementation and my knowledge. Great work!

tqjxlm commented 5 years ago

Thanks for the note. You are right. The subtracted value should be the mean of the advantages; val.mean() is nonsense. Theoretically it is a serious issue. I'm not sure why it does not matter here; there must be other limitations in the implementation. I'll fix it later.
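For reference, the mean-subtracted aggregation usually given for dueling networks can be written as below. The notation is an assumption for this note and does not come from the repo: V(s) is the value stream output, A(s, a) the advantage stream output, and |𝒜| the number of actions.

```latex
Q(s, a) = V(s) + \left( A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') \right)
```

The sum runs over the actions of a single state s, which corresponds to taking the mean along the action dimension for each sample rather than over the batch.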