facebookresearch / salina

a Lightweight library for sequential learning agents, including reinforcement learning
MIT License

Chunking Recurrent States and Truncated BPTT #35

Closed smorad closed 2 years ago

smorad commented 2 years ago

Hello,

I'm interested in loading and storing recurrent states for training over longer episodes. This is generally called truncated backpropagation through time (BPTT). For example, in the following case we break each trajectory into 80-timestep chunks:

    env = AutoResetGymAgent(
        make_cartpole,
        n_envs=2,
    )
    actor = Agents(
        LSTMAgent(hidden_size=32),
        QNetworkAgent(input_handle="state", num_actions=2),
        EpsilonGreedyActorAgent(epsilon=0.02),
    )
    collector = TemporalAgent(Agents(env, actor))

    ws = Workspace()
    for epoch in range(10):
        collector(ws, t=0, n_steps=80)

Currently, if an episode is longer than 80 timesteps, each new chunk starts from a recurrent state of zeros. Does Salina provide a way to load the previous recurrent state?

ludc commented 2 years ago

Hi,

To do BPTT over chunks of N timesteps, you can do as shown in the A2C implementation:

https://github.com/facebookresearch/salina/blob/ac71cadacc54ae48377c67c88d269fc05209341c/salina_examples/rl/a2c/mono_cpu_2/main.py#L134-L140

Basically, the idea is: you sample N timesteps in the workspace, then you copy the last timestep to position 0, and you continue sampling from timestep 1. This lets you 'split' trajectories into pieces without losing continuity in the acquisition. Each backward pass will thus backpropagate over the last N timesteps.
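To make the mechanics concrete, here is a minimal pure-Python sketch of the "copy the last timestep to position 0, then continue from timestep 1" pattern. It does not use Salina: the dict-of-lists buffer, the `collect` function and the counting "recurrent state" are hypothetical stand-ins, and `copy_n_last_steps` here is a toy re-implementation of the behaviour described above, not the library's code.

```python
# Toy illustration only: a dict of lists stands in for a time-indexed workspace.
# All names below are hypothetical and not part of Salina.

def copy_n_last_steps(buffer, n):
    """Copy the last n timesteps of every variable to the first n slots."""
    if n <= 0:
        return
    for values in buffer.values():
        values[:n] = values[-n:]


def collect(buffer, t, n_steps):
    """Fill slots t .. t+n_steps-1 of the buffer.

    The toy 'recurrent state' counts how many steps the recurrence has seen:
    it is the previous slot's value plus one, and starts from zero at slot 0.
    """
    hidden = buffer["hidden"]
    for i in range(t, t + n_steps):
        prev = 0 if i == 0 else hidden[i - 1]
        hidden[i] = prev + 1


def run(n_chunks, chunk_len, carry_over):
    """Collect n_chunks chunks and return a snapshot of the buffer after each."""
    buffer = {"hidden": [0] * chunk_len}
    chunks = []
    for epoch in range(n_chunks):
        if carry_over and epoch > 0:
            # Keep the last timestep at position 0, then resume from t=1.
            copy_n_last_steps(buffer, 1)
            collect(buffer, t=1, n_steps=chunk_len - 1)
        else:
            # Restarting at t=0 resets the recurrent state to zero.
            collect(buffer, t=0, n_steps=chunk_len)
        chunks.append(list(buffer["hidden"]))
    return chunks


reset_chunks = run(3, 4, carry_over=False)
carried_chunks = run(3, 4, carry_over=True)
print(reset_chunks)    # [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
print(carried_chunks)  # [[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]]
```

With the carry-over, each chunk after the first contributes N-1 new timesteps, and slot 0 holds the state the previous chunk ended on.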

The copy_n_last_steps method copies the last n timesteps in the workspace to the first n positions, and can be used with n>1 to generate sliding windows, as is done for instance in R2D2.
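As a rough sketch of the n>1 case, the snippet below builds overlapping windows over a stream of timestep indices. It is plain Python rather than Salina code: `sliding_windows`, the window length and the overlap are made-up illustration values, and the `copy_n_last_steps` helper is again a toy version operating on a single list.

```python
# Toy illustration only: one list stands in for a single workspace variable.
# Names and sizes are hypothetical, not taken from Salina or R2D2.

def copy_n_last_steps(values, n):
    """Copy the last n entries of the list to its first n positions."""
    if n > 0:
        values[:n] = values[-n:]


def sliding_windows(n_windows, window_len, overlap):
    """Return successive windows that share `overlap` timesteps with the previous one."""
    window = [None] * window_len
    next_step = 0  # global index of the next timestep to acquire
    out = []
    for w in range(n_windows):
        start = 0
        if w > 0:
            # Keep the tail of the previous window at the front of this one.
            copy_n_last_steps(window, overlap)
            start = overlap
        for slot in range(start, window_len):
            window[slot] = next_step
            next_step += 1
        out.append(list(window))
    return out


windows = sliding_windows(n_windows=3, window_len=5, overlap=2)
print(windows)  # [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [6, 7, 8, 9, 10]]
```

Each window after the first re-uses the final `overlap` timesteps of its predecessor and acquires `window_len - overlap` new ones.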

Does this answer your question?

smorad commented 2 years ago

Yes, this is precisely it. Thanks!