Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

[Question] how to use "lstm_states" from rollout_buffer to reconstruct LSTM states during training #222

Closed: DeepRowLie closed this issue 5 months ago

DeepRowLie commented 6 months ago

❓ Question

Hi all! I hope to integrate an RNN (LSTM/GRU) into off-policy algorithms (SAC and TD3), without multiprocessing like A3C. So I looked at the SB3-contrib code for RecurrentPPO and the RecurrentPPO documentation you recommended. In SB3-contrib, RecurrentPPO puts lstm_states into the rollout_buffer when collecting transitions. During training, sequences that do not start at the beginning of an episode use these stale lstm_states to reconstruct the LSTM states. I'm confused about this.
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/9f333ffc34280fd16f438ff303f7b3f7792b0068/sb3_contrib/ppo_recurrent/ppo_recurrent.py#L351C1-L356C18
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/9f333ffc34280fd16f438ff303f7b3f7792b0068/sb3_contrib/common/recurrent/policies.py#L198C9-L207C36
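
For context, here is my reading of the linked lines as a simplified sketch (the function and variable names below are mine, not the actual SB3-contrib code): the hidden/cell states saved in the rollout buffer are used to initialize the LSTM for each sampled sequence, and are zeroed wherever the sequence begins a new episode.

```python
import torch as th

def init_lstm_states(saved_hidden, saved_cell, episode_starts):
    # saved_hidden, saved_cell: LSTM states stored in the rollout buffer at the
    # first step of each sampled sequence, shape (n_layers, n_seq, hidden_size)
    # episode_starts: 1.0 where the sequence begins a new episode, shape (n_seq,)
    not_start = (1.0 - episode_starts).view(1, -1, 1)
    # Reuse the (stale) stored states, but reset them to zeros at episode boundaries
    hidden = saved_hidden * not_start
    cell = saved_cell * not_start
    return hidden, cell
```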

Here are my questions:

  1. Why should we use the outdated lstm_states to reconstruct LSTM states instead of just initializing them, when the sequence doesn't start from the beginning? After several epochs, the lstm_states stored in the rollout_buffer may no longer accurately reflect the historical information, since the network has evolved. Does the trust-region algorithm keep the updated networks similar enough that the old states are still usable?
  2. When applying an RNN to SAC's replay buffer, how should the LSTM states be reconstructed? Should the sequences be designed to always start from the beginning of each episode?


araffin commented 5 months ago

Hello, SAC with RNN and PPO with RNN will be quite different because PPO is on-policy (so the data collected is discarded after one update).

> Why should we use the outdated lstm_states to reconstruct LSTM states instead of just initializing them, when the sequence doesn't start from the beginning? When applying an RNN to SAC's replay buffer, how should the LSTM states be reconstructed?

This is mostly to have a better initialization for the LSTM states than constant or random values. An alternative that is especially relevant for off-policy algorithms is to use warmup steps (see the R2D2 paper) to initialize the LSTM states before doing any gradient update. However, this requires storing more data and assumes the episode is long enough to perform those warmup steps.
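
For illustration, a rough sketch of that warmup/burn-in idea (the names and shapes are illustrative, not an SB3 API): the first steps of each sampled sequence are replayed through the current LSTM without gradients just to refresh the hidden state, and the loss is computed only on the remaining steps.

```python
import torch as th
import torch.nn as nn

def burn_in_then_features(lstm: nn.LSTM, obs_seq: th.Tensor, burn_in: int):
    # obs_seq: already-encoded observations, shape (seq_len, batch_size, feature_dim)
    # Unroll the first `burn_in` steps with the *current* network, no gradients,
    # only to produce a fresh hidden state (instead of reusing a stale stored one).
    with th.no_grad():
        _, hidden = lstm(obs_seq[:burn_in])
    # Gradients flow only through the post-burn-in part of the sequence,
    # which is the part used to compute the actor/critic losses.
    features, _ = lstm(obs_seq[burn_in:], hidden)
    return features
```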

See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/201 for some pointers for SAC.

DeepRowLie commented 5 months ago

thanks