alex-petrenko / sample-factory

High throughput synchronous and asynchronous reinforcement learning
https://samplefactory.dev
MIT License

Recurrent policy details #108

Closed · MarcoMeter closed 3 years ago

MarcoMeter commented 3 years ago

Hi @alex-petrenko !

I've got a brief question on your recurrent policy implementation. Are you feeding complete episode trajectories to the learner's RNN?

That's what it looks like to me, because I can't find any padding or a fixed sequence length for truncated BPTT. All I found was PyTorch's PackedSequence object.

alex-petrenko commented 3 years ago

Hi @MarcoMeter !

We feed a single rollout to the recurrent policy. This is controlled by two CLI parameters: --rollout and --recurrence. Typically they are equal, e.g. with --rollout=32 and --recurrence=32 we collect 32-timestep sequences on the actors and send them to the learner, where the entire 32-timestep sequence is processed by the RNN policy. We do not feed the entire episode to the RNN policy.
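For intuition, here is a minimal sketch (my own illustration, not Sample Factory's actual code) of how the rollout relates to the recurrence length: the rollout is processed in BPTT chunks of recurrence steps, and when recurrence equals the rollout length the loop runs exactly once over the whole sequence. All tensor names and shapes below are assumptions for the example.

```python
import torch
from torch import nn

rollout, recurrence, D, H = 32, 32, 16, 64
obs = torch.randn(rollout, 1, D)   # (time, batch, features) from one actor
rnn = nn.GRU(input_size=D, hidden_size=H)

h = torch.zeros(1, 1, H)           # initial hidden state for this rollout
outputs = []
for start in range(0, rollout, recurrence):
    chunk = obs[start:start + recurrence]
    out, h = rnn(chunk, h)         # gradients flow through `recurrence` steps
    h = h.detach()                 # truncate BPTT at the chunk boundary
    outputs.append(out)
core_outputs = torch.cat(outputs)  # (rollout, 1, H)
```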

Please run the Sample Factory train script with the --help argument for detailed information about these parameters (or look in the code).

Obviously, if an episode termination flag occurs in the middle of a rollout, we split the rollout into two sequences that are fed to the RNN policy separately. This is done via PackedSequence, PyTorch's mechanism for accelerating RNNs on unequal-length sequences.
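To illustrate the idea, here is a small self-contained sketch (again my own illustration, not the library's code) of splitting one rollout at a done flag and packing the resulting unequal-length pieces; the feature tensor, done tensor, and dimensions are assumed for the example.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

T, D, H = 32, 16, 64           # rollout length, feature dim, hidden dim
features = torch.randn(T, D)   # per-step encoder outputs for one rollout
dones = torch.zeros(T)
dones[9] = 1.0                 # an episode ends at timestep 9

# Split the rollout into sub-sequences that never cross an episode boundary.
boundaries = (dones.nonzero(as_tuple=True)[0] + 1).tolist()
starts = [0] + boundaries
ends = boundaries + [T]
chunks = [features[s:e] for s, e in zip(starts, ends) if e > s]

# PackedSequence lets cuDNN process the unequal-length chunks in one call.
packed = pack_sequence(chunks, enforce_sorted=False)
rnn = nn.GRU(input_size=D, hidden_size=H)
packed_out, _ = rnn(packed)

# Unpack back to (max_len, num_chunks, H) if per-step outputs are needed.
outputs, lengths = pad_packed_sequence(packed_out)
print(outputs.shape, lengths)  # e.g. torch.Size([22, 2, 64]) tensor([10, 22])
```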

You can check the shapes of head_outputs and mb.rnn_states to convince yourself that this is the case :)

Please let me know if this is useful, I'm happy to help!

github-actions[bot] commented 3 years ago

This issue is stale because it has been open for 30 days with no activity.