instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

[FEATURE] ScannedRNN hidden state initialisation improvement #1058

Closed: lbeyers closed this issue 8 months ago

lbeyers commented 8 months ago

Please describe the purpose of the feature. Is it related to a problem?

Every time a hidden state is initialised with ScannedRNN.initialise_carry, the caller must pass in the layer width (the hidden dimension). The variable holding this value is named differently in different places: sometimes hidden_size, sometimes actor_network.pre_torso.layer_sizes[-1]. This is particularly a problem in the evaluator, where the variable cannot be named consistently across systems.
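To make the naming problem concrete, here is a minimal, hypothetical JAX sketch. The initialise_carry function is a stand-in that only builds a zero array (it is not Mava's implementation), and the two configs are plain dicts chosen for illustration, assuming one system stores the width as hidden_size and another under actor_network.pre_torso.layer_sizes. The point is that shared evaluator code ends up special-casing every convention.

```python
import jax.numpy as jnp


def initialise_carry(batch_size: int, hidden_size: int) -> jnp.ndarray:
    # Hypothetical stand-in for ScannedRNN.initialise_carry: a zero hidden state.
    return jnp.zeros((batch_size, hidden_size))


# Two hypothetical system configs storing the same width under different names.
system_a_config = {"hidden_size": 128}
system_b_config = {"actor_network": {"pre_torso": {"layer_sizes": [128, 128]}}}


def evaluator_hidden_size(config: dict) -> int:
    # A shared evaluator must know every system's naming convention.
    if "hidden_size" in config:
        return config["hidden_size"]
    return config["actor_network"]["pre_torso"]["layer_sizes"][-1]


carry_a = initialise_carry(8, evaluator_hidden_size(system_a_config))
carry_b = initialise_carry(8, evaluator_hidden_size(system_b_config))
```

Every new system with a different layout would need another branch in evaluator_hidden_size.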

Describe the solution you'd like

When a ScannedRNN is initialised as part of a larger network, its hidden size must be recorded so that it is available whenever the hidden state has to be reinitialised. I propose the following solutions, in order of preference:
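Whichever option is preferred, the underlying requirement (record the width once, then reuse it for every reinitialisation) can be illustrated with a minimal, hypothetical JAX sketch. It is not one of the author's proposals and not Mava's API: RecurrentCore, build_core, reset_carry and evaluator_init are illustrative names, and the reset logic assumes a boolean per-environment resets array, as is common when episodes end mid-rollout.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class RecurrentCore:
    """Hypothetical wrapper that records the hidden size once, at construction."""

    hidden_size: int

    def initialise_carry(self, batch_size: int) -> jnp.ndarray:
        # The width comes from the stored field, so callers never need to
        # know which config key a particular system used.
        return jnp.zeros((batch_size, self.hidden_size))

    def reset_carry(self, carry: jnp.ndarray, resets: jnp.ndarray) -> jnp.ndarray:
        # Replace the hidden state of every batch element whose episode ended.
        fresh = self.initialise_carry(carry.shape[0])
        return jnp.where(resets[:, None], fresh, carry)


def build_core(layer_sizes: list) -> RecurrentCore:
    # Hypothetical builder: the system-specific config is read exactly once.
    return RecurrentCore(hidden_size=layer_sizes[-1])


def evaluator_init(core: RecurrentCore, num_envs: int) -> jnp.ndarray:
    # The evaluator only needs the core object, not the system's config.
    return core.initialise_carry(num_envs)


core = build_core([64, 128])
carry = evaluator_init(core, 4)
```

Because Flax linen modules are Python dataclasses, the same idea could be expressed by making the hidden size a module attribute, so it travels with the network instead of living in a separately named config variable.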

How do we know when implementation of this feature is complete?

Checklist:

Additional context

This change is too large to handle before the next deadline, since it may affect all systems and their configs.