DeNA / HandyRL

HandyRL is a handy and simple framework based on Python and PyTorch for distributed reinforcement learning that is applicable to your own environments.
MIT License

Adding multidiscrete feature #328

Open Jogima-cyber opened 2 years ago

Jogima-cyber commented 2 years ago

Hello there! I'd like to add the capability to handle multidiscrete action spaces. Here is how I would do that:
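(The original snippet isn't reproduced here. As a rough illustration of the idea only, a multidiscrete policy can use one independent categorical head per action dimension, with the head sizes given by an nvec list. The names MultiDiscreteHead and sample_multidiscrete below are illustrative, not HandyRL's API.)

```python
import torch
import torch.nn as nn


class MultiDiscreteHead(nn.Module):
    """One independent categorical head per action dimension; sizes given by nvec."""

    def __init__(self, hidden_dim, nvec):
        super().__init__()
        self.nvec = list(nvec)
        # one linear layer per discrete sub-action
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, n) for n in self.nvec])

    def forward(self, h):
        # list of logits, one tensor of shape (batch, n) per sub-action
        return [head(h) for head in self.heads]


def sample_multidiscrete(logits_list):
    # sample each sub-action independently; the joint log-prob is the sum
    actions, log_probs = [], []
    for logits in logits_list:
        dist = torch.distributions.Categorical(logits=logits)
        a = dist.sample()
        actions.append(a)
        log_probs.append(dist.log_prob(a))
    return torch.stack(actions, dim=-1), torch.stack(log_probs, dim=-1).sum(dim=-1)
```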

YuriCat commented 2 years ago

Hi, @Jogima-cyber! Long time no see, and thanks for your great suggestion! I also experimented with the multiple-action setting this year, so I'll check whether it can be merged into the main code.

Jogima-cyber commented 2 years ago

I'm currently working on it! The first thing I'm doing is integrating a simple home-made multidiscrete environment into HandyRL so I can test my upcoming multidiscrete integration. I'll push it as soon as I'm finished (it will include a non-multidiscrete option). I'm not sure I'll need to add a multidiscrete variable to the config; maybe having nvec defined in the model is enough for the library to know it should handle the multidiscrete case. I need this feature right now, so I figured I could add it globally instead of just on my fork, if you want it.
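(For reference, a toy multidiscrete test environment could look something like the following. This is a plain standalone class for illustration only; it does not follow HandyRL's Environment interface, and the names and reward are made up. The multidiscrete flag mimics the non-multidiscrete option mentioned above.)

```python
import numpy as np


class ToyMultiDiscreteEnv:
    """Tiny test environment with a multidiscrete action space.

    nvec gives the number of choices per action dimension. With
    multidiscrete=False, a single flat discrete action is expected instead.
    """

    def __init__(self, nvec=(3, 4), episode_len=8, multidiscrete=True):
        self.nvec = np.array(nvec)
        self.episode_len = episode_len
        self.multidiscrete = multidiscrete
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros(len(self.nvec), dtype=np.float32)  # dummy observation

    def step(self, action):
        if self.multidiscrete:
            action = np.asarray(action)
            assert action.shape == self.nvec.shape and np.all(action < self.nvec)
            reward = float(np.sum(action))  # dummy reward: sum of chosen indices
        else:
            assert 0 <= action < int(np.prod(self.nvec))
            reward = float(action)
        self.t += 1
        done = self.t >= self.episode_len
        obs = np.full(len(self.nvec), self.t, dtype=np.float32)
        return obs, reward, done, {}
```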

YuriCat commented 2 years ago

@Jogima-cyber I pushed feature/multi_unit branch into my fork. https://github.com/DeNA/HandyRL/compare/master...YuriCat:HandyRL:feature/multi_unit

In this branch, each agent can output env.num_units() actions in each turn. I've confirmed that it works with TicTacToe!
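(To illustrate the shape change this implies, not as a copy of the branch's code: with N units, the policy head produces one action distribution per unit, and the agent returns a vector of N actions per turn instead of a single action. PerUnitHead and act_per_unit below are hypothetical names; in the branch the number of units comes from env.num_units().)

```python
import torch
import torch.nn as nn


class PerUnitHead(nn.Module):
    """Policy head that emits one action distribution per unit (illustrative sketch)."""

    def __init__(self, hidden_dim, num_units, num_actions):
        super().__init__()
        self.num_units = num_units
        self.num_actions = num_actions
        self.fc = nn.Linear(hidden_dim, num_units * num_actions)

    def forward(self, h):
        # (batch, num_units, num_actions): one set of logits per unit
        return self.fc(h).view(-1, self.num_units, self.num_actions)


def act_per_unit(logits):
    # Categorical broadcasts over leading dims, so this samples one action per unit
    dist = torch.distributions.Categorical(logits=logits)
    return dist.sample()  # shape (batch, num_units)
```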

Jogima-cyber commented 2 years ago

Thank you @YuriCat, that's almost exactly what I needed. I'm working on top of your fork; I just need to add a "unit_mask", because some of my independent action sets may be empty depending on the state (meaning no action can be taken within that action set).

YuriCat commented 2 years ago

@Jogima-cyber

unit_mask

Indeed! We need unit_mask if there are absent units.
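(A sketch of the kind of masking being discussed; the names and shapes below are assumptions, not the branch's actual code. The idea is that a per-unit mask zeroes out the contribution of absent units when summing per-unit log-probabilities or losses.)

```python
import torch


def masked_unit_logprob(per_unit_logprob, unit_mask):
    """Sum per-unit log-probabilities, ignoring absent units.

    per_unit_logprob: (batch, num_units) log-prob of the chosen action per unit
    unit_mask:        (batch, num_units) 1.0 for present units, 0.0 for absent
    """
    return (per_unit_logprob * unit_mask).sum(dim=-1)


# example: unit 1 of the second sample is absent
logp = torch.tensor([[-0.5, -1.2], [-0.3, -2.0]])
mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])
print(masked_unit_logprob(logp, mask))  # tensor([-1.7000, -0.3000])
```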

Jogima-cyber commented 2 years ago

By the way, I printed the shapes of some vectors, and at feature/multi_unit/train.py:l204 I got shapes that don't look right to me:

YuriCat commented 2 years ago

@Jogima-cyber .unsqueeze(-1) may be necessary for either of the two.

Jogima-cyber commented 2 years ago

Yeah, I found where to add a .unsqueeze(-1). At feature/multi_unit/train.py:l234 we have log_rhos = log_selected_t_policies.detach() - log_selected_b_policies, but log_selected_t_policies has shape (1, 32, 1, 2304, 1) while log_selected_b_policies has shape (1, 32, 1, 1, 2304), so torch applies broadcasting. I solved the issue by adding .unsqueeze(-1), and it should work for my case, but I didn't take into account all the cases HandyRL covers (RNN, turn-based training, and so on), so I don't think my fix applies generally to HandyRL.
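(To make the broadcasting pitfall concrete, here is a standalone reproduction with random values, not HandyRL's code: subtracting a (..., 1, N) tensor from a (..., N, 1) tensor silently broadcasts to (..., N, N), while aligning the trailing singleton dimension keeps the subtraction elementwise.)

```python
import torch

# the shapes in the report were (1, 32, 1, 2304, 1) and (1, 32, 1, 1, 2304);
# a smaller N keeps the demo cheap while preserving the structure
N = 5
t = torch.randn(1, 32, 1, N, 1)   # shaped like log_selected_t_policies
b = torch.randn(1, 32, 1, 1, N)   # shaped like log_selected_b_policies

# silent broadcasting: the difference blows up to (1, 32, 1, N, N)
print((t - b).shape)

# aligning the trailing singleton dimension keeps the subtraction elementwise;
# the fix discussed above applies .unsqueeze(-1) upstream, this transpose is
# only here to illustrate the aligned shapes
b_aligned = b.transpose(-1, -2)   # (1, 32, 1, N, 1)
print((t - b_aligned).shape)
```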

YuriCat commented 1 year ago

@Jogima-cyber Updated my branch feature/multi_unit.