Denys88 / rl_games

RL implementations
MIT License

MultiDiscrete action_mask cannot run #264

Status: Closed (xialuanshi closed this issue 10 months ago)

xialuanshi commented 10 months ago

My program has two action output heads, defined as a MultiDiscrete action space with dimensions 5 and 3. When I build the action_mask, I get errors whether I merge it into a single 8-dimensional mask or pass it as a list of two masks of lengths 5 and 3. In models.py, "CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)" requires each entry of action_masks to match the dimension of the corresponding logits. But in a2c_discrete.py, "action_masks = torch.BoolTensor(action_masks).to(self.ppo_device)" converts the masks to a tensor, which requires all rows to have the same length. So what data format should I use for my action_mask?
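One layout that keeps the mask rectangular is to concatenate the per-head masks into a single flat row of length 5 + 3 = 8 per environment, then split that row back by head size just before it is paired with the per-head logits. The sketch below shows only the pack/split step in plain Python. `pack_masks` and `split_mask` are hypothetical helper names, not rl_games functions, and nothing in this thread confirms that rl_games accepts this layout without changes.

```python
from typing import List, Sequence

HEAD_SIZES = (5, 3)  # sizes of the two MultiDiscrete heads from the question


def pack_masks(per_head: Sequence[Sequence[bool]]) -> List[bool]:
    """Concatenate per-head masks into one flat row (hypothetical helper)."""
    flat: List[bool] = []
    for head_mask in per_head:
        flat.extend(bool(m) for m in head_mask)
    return flat


def split_mask(flat: Sequence[bool], head_sizes: Sequence[int]) -> List[List[bool]]:
    """Split a flat mask row back into one mask per head (hypothetical helper)."""
    if len(flat) != sum(head_sizes):
        raise ValueError(f"expected {sum(head_sizes)} entries, got {len(flat)}")
    out: List[List[bool]] = []
    start = 0
    for size in head_sizes:
        out.append(list(flat[start:start + size]))
        start += size
    return out


# Example masks: True means the action is allowed.
head_a = [True, True, False, True, False]  # 5 choices, two of them forbidden
head_b = [True, False, True]               # 3 choices, one forbidden

flat = pack_masks([head_a, head_b])        # one row of length 8
restored = split_mask(flat, HEAD_SIZES)    # back to a 5-mask and a 3-mask
```

Because every row has the same length, a batch of such rows forms a regular 2-D array. On the tensor side, `torch.split(mask, [5, 3], dim=-1)` performs the same split.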

Denys88 commented 10 months ago

Hi @xialuanshi, could you convert it to a regular discrete env with masks as a temporary workaround? In that case you will have 5 × 3 = 15 actions. A good example of a masked env is: https://github.com/Denys88/rl_games/blob/a5d788a40b57c7470612ea98ae8da32754d9c760/rl_games/envs/smac_env.py#L98

I don't think I have ever used masked multi-discrete envs, so there might be bugs. I'll take a deeper look.
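A minimal sketch of that workaround, assuming the two heads are masked independently (a joint action is allowed only when both of its components are allowed). The names `encode`, `decode`, and `joint_mask` are illustrative and not part of rl_games.

```python
from typing import List, Sequence, Tuple

N_A, N_B = 5, 3  # head sizes from the question; the flat space has N_A * N_B = 15 actions


def encode(a: int, b: int) -> int:
    """Map a (head A, head B) pair to a single flat action index."""
    return a * N_B + b


def decode(index: int) -> Tuple[int, int]:
    """Recover the (head A, head B) pair from a flat action index."""
    return divmod(index, N_B)


def joint_mask(mask_a: Sequence[bool], mask_b: Sequence[bool]) -> List[bool]:
    """Build the flat mask; the loop order matches encode (a is the slow index).

    Assumption: the heads are masked independently, so a flat action is
    valid only when both of its components are valid.
    """
    return [bool(ma) and bool(mb) for ma in mask_a for mb in mask_b]


# Example masks: True means the action is allowed.
head_a = [True, True, False, True, False]
head_b = [True, False, True]

mask = joint_mask(head_a, head_b)  # 15 entries, one per flat action
```

The environment would then expose a single discrete action space of size 15, return `mask` as its action mask, and call `decode` on the chosen flat index before applying the action.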

xialuanshi commented 10 months ago

Thank you very much for your reply. The 5 and 3 example is a reduced version of my action space. If I flatten the full action space into a single discrete space by multiplying the head sizes, it becomes very large.
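To illustrate the size concern: a flattened space grows with the product of the head sizes, while a per-head mask grows only with their sum. The full action space is not given in this thread, so the larger tuple below is a purely hypothetical example.

```python
import math

reduced_heads = (5, 3)                # the reduced example from this issue
hypothetical_heads = (5, 3, 10, 10, 20)  # assumed sizes, for illustration only


def flat_size(head_sizes):
    """Number of actions after flattening into one discrete space."""
    return math.prod(head_sizes)


def per_head_mask_size(head_sizes):
    """Number of mask entries when each head keeps its own mask."""
    return sum(head_sizes)


reduced = (flat_size(reduced_heads), per_head_mask_size(reduced_heads))
larger = (flat_size(hypothetical_heads), per_head_mask_size(hypothetical_heads))
```

For the reduced case this gives 15 flat actions versus 8 mask entries; for the hypothetical five-head case it gives 30000 flat actions versus 48 mask entries.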

Denys88 commented 10 months ago

Here is my branch with small fixes. It looks like a few changes are needed in models.py and a2c_discrete.py: https://github.com/Denys88/rl_games/compare/master...DM/sa_smac Could you try it and let me know if it works?