araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

[Feature Request] Multi-Discrete action spaces for PPO #19

Closed: tobiasmerkt closed this issue 7 months ago

tobiasmerkt commented 11 months ago

🚀 Feature

Currently, PPO only supports (<class 'gymnasium.spaces.box.Box'>, <class 'gymnasium.spaces.discrete.Discrete'>) as action spaces. It would be awesome if it also supported MultiDiscrete action spaces.

Motivation

For many applications (Atari), one has to choose multiple discrete actions at each time step. StableBaselines3 supports MultiDiscrete action spaces already and it would be great if sbx supported it as well.


araffin commented 11 months ago

Hello, would you like to contribute this feature?

tobiasmerkt commented 10 months ago

Hi,

I’m sorry but I believe that my current skills are not sufficient to do this. I’m currently learning RL and started working with SB3 a week ago.

Cheers, Tobias



loafthecomputerphile commented 9 months ago

If I am reading the PPO source code in SB3 and SBX correctly, the main thing we need to do is create a multi-categorical distribution class. The tensorflow_probability API that SBX uses only supplies a categorical distribution, which means we have to build our own or find a third-party implementation. I had to go into my Python installation to read the source code, since I couldn't find it on TensorFlow's GitHub. Fortunately, from what I have seen, it looks similar to the SB3 version, meaning you need a few list comprehensions or for loops. However, I am not too sure how I would implement it, and I still need to learn how to use GitHub to commit, etc.
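To illustrate the idea, here is a minimal sketch of a multi-categorical distribution written in plain JAX rather than with tensorflow_probability. It assumes the policy network outputs one flat vector of logits with `sum(nvec)` entries, which is split into one independent categorical per sub-action, similar in spirit to what SB3 does. The function names (`split_logits`, `sample`, `log_prob`, `entropy`) are hypothetical and are not part of SBX.

```python
import jax
import jax.numpy as jnp


def split_logits(flat_logits, nvec):
    """Split logits of shape (..., sum(nvec)) into one array per sub-action."""
    split_points = [sum(nvec[:i]) for i in range(1, len(nvec))]
    return jnp.split(flat_logits, split_points, axis=-1)


def sample(key, flat_logits, nvec):
    """Draw one integer action per sub-action; result has shape (..., len(nvec))."""
    keys = jax.random.split(key, len(nvec))
    actions = [
        jax.random.categorical(k, logits, axis=-1)
        for k, logits in zip(keys, split_logits(flat_logits, nvec))
    ]
    return jnp.stack(actions, axis=-1)


def log_prob(flat_logits, nvec, actions):
    """Joint log-probability: the sum of the per-sub-action log-probabilities."""
    total = 0.0
    for i, logits in enumerate(split_logits(flat_logits, nvec)):
        log_p = jax.nn.log_softmax(logits, axis=-1)
        picked = jnp.take_along_axis(log_p, actions[..., i : i + 1], axis=-1)
        total = total + picked[..., 0]
    return total


def entropy(flat_logits, nvec):
    """Joint entropy: the sum of the per-sub-action entropies (independence assumed)."""
    total = 0.0
    for logits in split_logits(flat_logits, nvec):
        log_p = jax.nn.log_softmax(logits, axis=-1)
        total = total - jnp.sum(jnp.exp(log_p) * log_p, axis=-1)
    return total
```

For a `MultiDiscrete([3, 2])` action space, `nvec` would be `(3, 2)` and the network would need 5 logits per observation. Because the sub-actions are treated as independent, the log-probability and entropy are simple sums, which is why a few loops or list comprehensions are enough.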

The distribution code can be found in the distributions folder of the tensorflow_probability package after you install SBX or tensorflow_probability. You will see the similarities between the Categorical distribution in SB3 (here) and the tensorflow_probability Categorical distribution code, and it may be easy to apply the necessary conversions for the multi-categorical distribution found at the same link. You may also have to edit the KL-divergence function.
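If an exact KL divergence between two multi-categorical distributions is needed, one possible approach is to sum the per-sub-action categorical KL terms, which is valid when the sub-actions are independent. The sketch below is plain JAX under that assumption. The function name `multi_categorical_kl` and the flat-logits layout are hypothetical and are not taken from SBX or tensorflow_probability.

```python
import jax
import jax.numpy as jnp


def multi_categorical_kl(flat_logits_p, flat_logits_q, nvec):
    """KL(p || q) for two factorized categorical distributions.

    Both inputs have shape (..., sum(nvec)); the result has shape (...).
    """
    split_points = [sum(nvec[:i]) for i in range(1, len(nvec))]
    parts_p = jnp.split(flat_logits_p, split_points, axis=-1)
    parts_q = jnp.split(flat_logits_q, split_points, axis=-1)
    kl = 0.0
    for logits_p, logits_q in zip(parts_p, parts_q):
        log_p = jax.nn.log_softmax(logits_p, axis=-1)
        log_q = jax.nn.log_softmax(logits_q, axis=-1)
        # Categorical KL for this sub-action, added to the running total.
        kl = kl + jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
    return kl
```

The divergence is zero when both sets of logits describe the same distribution, and each sub-action contributes its own non-negative term.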

I hope this comment is useful and helps others speed up the development of this feature, and of others if possible.