inarikami / keras-rl2

Reinforcement learning with tensorflow 2 keras
MIT License

DDPGAgent is incompatible with MultiInputProcessor for HandReach-v0 env #30

Open alexaorrico opened 3 years ago

alexaorrico commented 3 years ago

DDPGAgent fails while training the critic model when a MultiInputProcessor is used. The failure occurs in the agent's backward method, specifically at lines 260-263:

                if len(self.critic.inputs) >= 3:
                    state1_batch_with_action = state1_batch[:]
                else:
                    state1_batch_with_action = [state1_batch]
                state1_batch_with_action.insert(self.critic_action_input_idx, target_actions)

This throws TypeError: unhashable type: 'slice' because state1_batch is a dictionary with three keys, as returned by the processor, and a dictionary cannot be sliced with [:]. The code appears to assume that state1_batch is always a list, never a dictionary. The same assumption is made a few lines further down for state0_batch.

I would love to fix this myself, but I am unsure why 3 is hardcoded in the condition, or why the number of critic inputs should make a difference. If someone is willing to explain, I'd love to understand.
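The failure can be reproduced without Keras. The following is a minimal sketch of the problem and of one possible workaround, not the library's fix. The helper build_critic_inputs and its input_order parameter are hypothetical names, and the dictionary keys and array shapes are illustrative assumptions. One plausible reading of the >= 3 check, also an assumption, is that the critic takes one action input plus the observation inputs, so three or more inputs means the state batch is already a list of arrays.

```python
import numpy as np


def build_critic_inputs(state_batch, actions, action_input_idx, input_order=None):
    """Hypothetical helper: return a new list of critic inputs with
    `actions` inserted at `action_input_idx`.

    Accepts a dict of arrays, a list/tuple of arrays, or a single array,
    instead of assuming the batch is always sliceable like a list.
    """
    if isinstance(state_batch, dict):
        # Assumption: the critic's observation inputs follow `input_order`;
        # if it is not given, fall back to the dict's insertion order.
        keys = input_order if input_order is not None else list(state_batch)
        inputs = [state_batch[key] for key in keys]
    elif isinstance(state_batch, (list, tuple)):
        inputs = list(state_batch)  # copy, so the caller's list is not mutated
    else:
        inputs = [state_batch]  # single-input critic
    inputs.insert(action_input_idx, actions)
    return inputs


# Illustrative stand-in for a processed batch with three keys.
# Key names and shapes are assumptions made for this sketch.
batch_size = 4
state1_batch = {
    "observation": np.zeros((batch_size, 6)),
    "achieved_goal": np.zeros((batch_size, 3)),
    "desired_goal": np.zeros((batch_size, 3)),
}
target_actions = np.zeros((batch_size, 2))

# Slicing a dict fails. Python versions before 3.12 raise TypeError
# (unhashable type: 'slice'); later versions raise KeyError instead.
try:
    state1_batch[:]
    slicing_error = None
except (TypeError, KeyError) as exc:
    slicing_error = type(exc).__name__

critic_inputs = build_critic_inputs(
    state1_batch,
    target_actions,
    action_input_idx=1,
    input_order=["observation", "achieved_goal", "desired_goal"],
)
```

The same helper could be applied to state0_batch. Whether the resulting order matches the critic model depends on how its inputs were declared, which this sketch does not check.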

Here is the script: hand_reach.py
