This PR adds support for stateful agents. Here is how it works:
Two new methods are added to the agent: init_memory_batch and update_memory_batch.
init_memory_batch is called at the beginning of training to initialize the agent's memory.
agent.update_memory_batch is called after every environment step; it updates the current memory state, or reinitializes it after an episode terminates.
The current memory state is passed to agent.act_on_batch so that the agent can use it when acting.
The memory state passed to agent.act_on_batch is also recorded in the trajectory so that it can be used during training steps.
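The lifecycle described above can be sketched as a minimal collection loop. This is an illustration, not the PR's code. Only the method names init_memory_batch, update_memory_batch and act_on_batch come from this PR. Their signatures, ToyBatchEnv, ToyStatefulAgent and collect are hypothetical. The toy memory is simply a running sum of past observations.

```python
import numpy as np


class ToyBatchEnv:
    """Hypothetical batched env: all-ones observations, episodes of fixed length."""

    def __init__(self, batch_size: int, obs_dim: int, episode_len: int):
        self.batch_size = batch_size
        self.obs_dim = obs_dim
        self.episode_len = episode_len
        self.t = 0

    def reset(self) -> np.ndarray:
        self.t = 0
        return np.ones((self.batch_size, self.obs_dim))

    def step(self, actions: np.ndarray):
        self.t += 1
        # Every episode_len steps all environments terminate (and implicitly restart).
        dones = np.full(self.batch_size, self.t % self.episode_len == 0)
        return np.ones((self.batch_size, self.obs_dim)), dones


class ToyStatefulAgent:
    """Hypothetical agent whose memory is a running sum of past observations."""

    def __init__(self, memory_dim: int):
        self.memory_dim = memory_dim

    def init_memory_batch(self, batch_size: int) -> np.ndarray:
        # One zero-initialized memory vector per environment in the batch.
        return np.zeros((batch_size, self.memory_dim))

    def act_on_batch(self, observations: np.ndarray, memory: np.ndarray) -> np.ndarray:
        # Toy policy driven by the memory; a real agent would also use the observation.
        return (memory[:, 0] > 0).astype(np.int64)

    def update_memory_batch(self, memory, observations, actions, dones):
        # Fold the last observation into memory, then reset memory for finished episodes.
        new_memory = memory + observations
        new_memory[dones] = self.init_memory_batch(int(dones.sum()))
        return new_memory


def collect(agent, env, num_steps: int):
    memory = agent.init_memory_batch(env.batch_size)  # once, at the start
    obs = env.reset()
    trajectory = []
    for _ in range(num_steps):
        actions = agent.act_on_batch(obs, memory)
        # Record the memory the agent acted on, so training can start from it later.
        trajectory.append({"obs": obs, "memory": memory, "action": actions})
        next_obs, dones = env.step(actions)
        memory = agent.update_memory_batch(memory, obs, actions, dones)  # after every step
        obs = next_obs
    return trajectory
```

For example, `collect(ToyStatefulAgent(memory_dim=2), ToyBatchEnv(batch_size=4, obs_dim=2, episode_len=3), num_steps=6)` records memories whose first component goes 0, 1, 2 and then restarts at 0 after the episode boundary. In this toy, memory_dim must equal obs_dim because observations are added directly to memory.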
Here's how memory is supported in MuZero:
When reanalyzing, we replay the trajectory chunk starting from the memory state recorded for the first timestep of the chunk. This memory can be stale, since it was computed by an older version of the network, but that doesn't seem to be a big problem in practice.
Memory is aggregated along the trajectory via a transformer-based recurrent network and fused with the observation embedding.
There is also support for a mechanism that updates the initial memory state of the next trajectory chunk after a chunk has been processed, but no benefit from it has been observed in the experiments. It actually seems to make things worse, and I can't figure out why; there may be a bug, but I haven't been able to find it.
The representation function now takes a single observation and the memory state as input, instead of the whole trajectory.
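The reanalysis replay and the optional hand-off to the next chunk can be sketched as follows. This is a simplified illustration under stated assumptions, not the MuZero code in this PR. The names replay_chunk and refresh_next_chunk, the dict of start memories keyed by chunk id, and the callback signatures are all hypothetical. The sketch calls the representation function on one observation plus memory per step, which matches the new signature described above.

```python
import numpy as np


def replay_chunk(observations, actions, dones, initial_memory, init_fn, update_fn, represent_fn):
    """Recompute per-step hidden states for a stored trajectory chunk.

    Starts from `initial_memory` (the memory recorded for the chunk's first step, possibly
    stale), calls the representation function on one observation plus the current memory at
    each step, and returns the hidden states together with the memory after the last step.
    """
    memory = initial_memory
    hidden_states = []
    for obs, act, done in zip(observations, actions, dones):
        hidden_states.append(represent_fn(obs, memory))
        memory = update_fn(memory, obs, act)
        if done:
            memory = init_fn()  # episode boundary inside the chunk
    return np.stack(hidden_states), memory


def refresh_next_chunk(start_memories, next_chunk_id, final_memory):
    """Optional hand-off (hypothetical store): overwrite the next chunk's start memory."""
    if next_chunk_id in start_memories:
        start_memories[next_chunk_id] = final_memory
```

The hand-off helper corresponds to the mechanism that, as noted above, has not helped in the experiments so far. Skipping the call to refresh_next_chunk gives the default behaviour of starting each chunk from its recorded memory.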
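To make "transformer-based recurrent network" concrete, here is a generic single-head sketch. It is not the architecture used in this PR. At each step the memory slots attend over themselves plus the current observation embedding, with a residual connection and layer norm. The pooled memory is then fused with the embedding by concatenation followed by a linear layer. The slot count, the weight names and the fusion scheme are assumptions.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def attention_memory_step(memory, obs_embedding, wq, wk, wv):
    """One recurrent step: memory slots attend over [memory slots; observation embedding].

    memory: (num_slots, d), obs_embedding: (d,). Returns the updated memory (num_slots, d).
    """
    inputs = np.concatenate([memory, obs_embedding[None, :]], axis=0)  # (num_slots + 1, d)
    q, k, v = memory @ wq, inputs @ wk, inputs @ wv
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # (num_slots, num_slots + 1)
    return layer_norm(memory + weights @ v)  # residual connection + layer norm


def fuse(memory, obs_embedding, w_fuse):
    """Fuse pooled memory with the observation embedding: concat + linear + tanh."""
    pooled = memory.mean(axis=0)
    return np.tanh(np.concatenate([obs_embedding, pooled]) @ w_fuse)
```

In a JAX codebase the same computation would normally be written with learned parameters in a module library. The numpy version only shows the data flow along the trajectory: one memory update per observation, then fusion.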
Other notable changes introduced while working on memory:
Replay buffers were refactored to support ID-based trajectory lookup.
A new type of replay buffer was introduced that evicts trajectories based on their age, not just on buffer capacity.
Several new environments (memorytest*) were added for testing memory-related functionality.
Learning rate (LR) warmup can now be specified.
Observed glyphs can now be cropped around any point, not just around the center.
Backend naming changed in the pytree lib: backends are now called "jax" and "numpy" instead of "gpu" and "cpu", which describes them more accurately.
IdentityHashWrapper was removed, since it is no longer needed after the replay buffer refactoring.
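ID-based lookup and age-based eviction can be combined in one small structure, sketched below. The class and its methods are hypothetical and do not reflect the PR's replay buffer API. "Age" is assumed to be measured in training steps since insertion, and step values passed to add are assumed non-decreasing. An OrderedDict keeps trajectories oldest-first, so both age and capacity eviction pop from the front.

```python
import random
from collections import OrderedDict


class AgeLimitedReplayBuffer:
    """Hypothetical buffer with id-based lookup, evicting by age and by capacity."""

    def __init__(self, capacity: int, max_age: int):
        self.capacity = capacity
        self.max_age = max_age
        self._items = OrderedDict()  # traj_id -> (insert_step, trajectory), oldest first

    def add(self, traj_id, trajectory, step: int):
        self._items[traj_id] = (step, trajectory)
        self._items.move_to_end(traj_id)  # re-added ids become the newest entry
        self.evict(step)

    def evict(self, current_step: int):
        # Drop trajectories older than max_age (oldest entries are at the front)...
        while self._items:
            inserted, _ = next(iter(self._items.values()))
            if current_step - inserted <= self.max_age:
                break
            self._items.popitem(last=False)
        # ...then enforce the capacity limit, again oldest first.
        while len(self._items) > self.capacity:
            self._items.popitem(last=False)

    def get(self, traj_id):
        return self._items[traj_id][1]  # id-based lookup

    def ids(self):
        return list(self._items)

    def __contains__(self, traj_id):
        return traj_id in self._items

    def __len__(self):
        return len(self._items)

    def sample(self, rng=random):
        traj_id = rng.choice(self.ids())
        return traj_id, self._items[traj_id][1]
```

With capacity 3 and max_age 10, adding "a" to "d" at steps 0, 5, 8 and 12 evicts "a" by age. Adding "e" at step 13 then evicts "b" by capacity.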
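As an illustration of what an LR warmup setting typically controls, here is a linear warmup schedule. The PR does not say which schedule shape it uses, so both the linear ramp and the name warmup_lr are assumptions.

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Hypothetical linear warmup: ramps up to base_lr over warmup_steps, then stays constant."""
    if warmup_steps <= 0:
        return base_lr  # warmup disabled
    return base_lr * min(1.0, (step + 1) / warmup_steps)
```

If the training loop uses optax, a similar ramp could be built with `optax.linear_schedule` and combined with a later schedule via `optax.join_schedules`. Whether this PR does so is not stated.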
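Cropping observed glyphs around an arbitrary point can be sketched as below. Parts of the window that fall outside the map are padded with a fill value. The function name, the square window and the padding behaviour are assumptions made for illustration, not the PR's implementation.

```python
import numpy as np


def crop_around(glyphs: np.ndarray, center, size: int, fill_value) -> np.ndarray:
    """Crop a size x size window of a 2-D glyph map centered at `center` (row, col).

    Cells outside the map are set to fill_value. For even sizes the center lands at
    index size // 2 of the window.
    """
    h, w = glyphs.shape
    r, c = center
    half = size // 2
    out = np.full((size, size), fill_value, dtype=glyphs.dtype)
    top, left = r - half, c - half
    # Clamp the source rectangle to the map; max(...) keeps it empty if fully outside.
    r0 = max(top, 0)
    r1 = max(r0, min(top + size, h))
    c0 = max(left, 0)
    c1 = max(c0, min(left + size, w))
    out[r0 - top:r1 - top, c0 - left:c1 - left] = glyphs[r0:r1, c0:c1]
    return out
```

For a 5x5 map, cropping a 3x3 window at (0, 0) yields a window whose first row and first column are padding. Cropping at (2, 2) returns the central 3x3 block unchanged.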