Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

question regarding qrdqn #17

Closed guyk1971 closed 3 years ago

guyk1971 commented 3 years ago

Hi, while trying to extend QR-DQN to support double QR-DQN, I couldn't convince myself that the current implementation of the update function is correct.

Specifically, see the train function of the QRDQN class (qrdqn.py), line 165:

                # Follow greedy policy: use the one with the highest value
                next_quantiles, _ = next_quantiles.max(dim=2)

It looks like the max operation is taken over the action dimension, so each quantile independently takes its value from whichever action is maximal for that particular quantile (e.g. the first quantile could come from a=1, the second quantile from a=4, etc.). The resulting target therefore mixes quantile values from different actions.
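
To make the difference concrete, here is a standalone toy example (not code from the library) with one batch element, two quantiles and two actions:

    import torch as th

    # Toy tensor with shape (batch, n_quantiles, n_actions) = (1, 2, 2)
    # action 0 has quantiles [1., 9.] (mean 5.), action 1 has quantiles [4., 4.] (mean 4.)
    next_quantiles = th.tensor([[[1.0, 4.0],
                                 [9.0, 4.0]]])

    # Current implementation: per-quantile max over the action dimension
    per_quantile_max, _ = next_quantiles.max(dim=2)
    print(per_quantile_max)  # tensor([[4., 9.]]) -> quantile 0 from action 1, quantile 1 from action 0

    # Intended behaviour: pick one greedy action from the mean Q-value, then take all of its quantiles
    best_action = next_quantiles.mean(dim=1).argmax(dim=1)  # tensor([0])
    print(next_quantiles[0, :, best_action[0]])             # tensor([1., 9.])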

As I understand it (and as various other implementations of QR-DQN do it), the intent was to compute the Q-value of each action by averaging over that action's quantiles, and only then take the quantiles of the best next action, something like this:

            with th.no_grad():
                # Compute the quantiles of the next observation (shape: batch x n_quantiles x n_actions)
                next_quantiles = self.quantile_net_target(replay_data.next_observations)
                # Greedy action w.r.t. the mean over quantiles (the Q-value)
                best_next_actions = next_quantiles.mean(dim=1).argmax(dim=1, keepdim=True)
                # Gather the quantiles of the greedy action for every sample in the batch
                actions_index = best_next_actions[..., None].expand(batch_size, self.n_quantiles, 1)
                next_quantiles = next_quantiles.gather(dim=2, index=actions_index).squeeze(dim=2)
                # 1-step TD target
                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles

My question: does the current implementation differ by design? If so, and assuming the implementation I'm proposing here matches the publication, what is the justification for the variant currently implemented?
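
For reference, the double QR-DQN variant I'm ultimately trying to build would, as far as I understand, only change which network selects the greedy action. A rough sketch under the same assumptions (I'm assuming here that self.quantile_net is the alias for the online network, mirroring self.quantile_net_target):

            with th.no_grad():
                # Double variant (sketch): select the greedy action with the online network...
                next_quantiles_online = self.quantile_net(replay_data.next_observations)
                best_next_actions = next_quantiles_online.mean(dim=1).argmax(dim=1, keepdim=True)
                # ...but evaluate its quantiles with the target network
                next_quantiles = self.quantile_net_target(replay_data.next_observations)
                actions_index = best_next_actions[..., None].expand(batch_size, self.n_quantiles, 1)
                next_quantiles = next_quantiles.gather(dim=2, index=actions_index).squeeze(dim=2)
                # 1-step TD target
                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles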

BTW, great work with SB3!! A very useful and easy-to-work-with library! I mainly use it (and thus extend it) for offline RL settings.

Thanks a lot :)

Miffyli commented 3 years ago

Pinging @ku2482 for insight. We maintainers are mostly on vacation and won't be around until after the holidays ^^

BTW, great work with SB3!! A very useful and easy-to-work-with library!

Thanks for the kind words :). This work would not be possible without such acknowledgements, and more importantly, not without the generous contributions from many!

araffin commented 3 years ago

does the current implementation differ by design?

Hello, thanks for raising this issue. In fact, I forgot to ask @ku2482 myself why it is so... That may be a bug. I will check the paper and other implementations again, but I would tend to agree with you ;)

toshikwa commented 3 years ago

@guyk1971 Thank you for pointing it out. You're right, it's a bug.

@Miffyli @araffin I'm so sorry about that... I will create a PR to fix it later today, and I will also re-evaluate the algorithm.

Thanks.

guyk1971 commented 3 years ago

Thanks for handling this, and for the great work you're doing with Stable-Baselines3! 👏

Good luck! 🤞
