toshikwa / sac-discrete.pytorch

PyTorch implementation of SAC-Discrete.
MIT License

SAC #7

Closed. NikEyX closed this issue 4 years ago.

NikEyX commented 4 years ago

Hi again :)

I checked your code briefly; it seems you fixed the .mean() for the q_value loss but not for the policy loss. I raised this issue at https://www.github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/issues/54#issuecomment-634398732 and provided more details there.

I would be interested to see your results on Pong/Pacman/etc. if you apply the changes I suggested there, also with regard to the target entropy (see the sketch below). I would imagine it should solve Pong pretty convincingly?
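For reference, the heuristic I had in mind for the target entropy is something like the following: scale the maximum possible entropy (that of a uniform policy over the discrete actions) by a ratio. This is only a sketch; the function and hyperparameter names are my own, not necessarily the repo's:

```python
import numpy as np

def target_entropy(n_actions: int, target_entropy_ratio: float = 0.98) -> float:
    """Sketch: target entropy as a fraction of the maximum entropy.

    The maximum entropy of a policy over n_actions discrete actions is
    log(n_actions); note that -log(1/n) == log(n).
    """
    return -np.log(1.0 / n_actions) * target_entropy_ratio
```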

toshikwa commented 4 years ago

Hi, @NikEyX

I think policy_loss is calculated correctly: mean() is taken over the batch, not over actions, and sum() is properly taken over actions.

https://github.com/ku2482/sac-discrete.pytorch/blob/master/sacd/agent/sacd.py#L125

So it seems correct, doesn't it?
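Just to illustrate the shape logic, here is a minimal sketch (variable names are illustrative, not copied verbatim from the code):

```python
import torch

def policy_loss(action_probs: torch.Tensor,      # (B, n_actions), pi(a|s)
                log_action_probs: torch.Tensor,  # (B, n_actions), log pi(a|s)
                q: torch.Tensor,                 # (B, n_actions), Q(s, a)
                alpha: float,
                weights: torch.Tensor) -> torch.Tensor:  # (B, 1), PER weights
    # sum() runs over the action dimension (dim=1): these are expectations
    # under pi(.|s), computed exactly since the action space is discrete.
    entropies = -(action_probs * log_action_probs).sum(dim=1, keepdim=True)
    q_expected = (action_probs * q).sum(dim=1, keepdim=True)
    # mean() runs over the batch dimension, after weighting each sample
    # by its prioritized-replay importance weight.
    return (weights * (-q_expected - alpha * entropies)).mean()
```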

Anyway, thank you for asking!

NikEyX commented 4 years ago

Oh, you're right! Yes, you did do it correctly then (or at least in the same way as I suggested in the linked post). Somehow I misinterpreted your `weights` as the action probabilities, but they seem to be the importance-sampling weights from the prioritized replay buffer. That makes sense. Sorry about that.
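For anyone reading later, here is a minimal sketch of where those weights would come from, assuming standard prioritized experience replay (the function name and signature are my own, not the repo's):

```python
import numpy as np

def importance_weights(probs: np.ndarray, buffer_size: int, beta: float) -> np.ndarray:
    """Sketch: importance-sampling weights for prioritized replay.

    probs are the sampling probabilities P(i) of the transitions drawn
    from the buffer. The weight w_i = (N * P(i)) ** (-beta) corrects the
    bias from non-uniform sampling; dividing by the max keeps weights <= 1.
    """
    weights = (buffer_size * probs) ** (-beta)
    return weights / weights.max()
```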