vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Are you interested in PRs for improvements in performance of PPO LSTM script? #276

Open thomasbbrunner opened 1 year ago

thomasbbrunner commented 1 year ago

Problem Description

The current PPO with LSTM script, ppo_atari_lstm.py, steps through the LSTM sequentially, i.e. each step of the sequence is processed individually:

new_hidden = []
for h, d in zip(hidden, done):
    # reset the recurrent state for environments whose episode just ended,
    # then advance the LSTM by a single timestep
    h, lstm_state = self.lstm(
        h.unsqueeze(0),
        (
            (1.0 - d).view(1, -1, 1) * lstm_state[0],
            (1.0 - d).view(1, -1, 1) * lstm_state[1],
        ),
    )
    new_hidden += [h]

This method is very slow compared to sending the entire sequence of observations to the LSTM:

h, lstm_state = self.lstm(hidden, lstm_state)

This usually cannot be done in RL, as we have to reset the hidden states when an episode ends.

Other implementations of PPO use a trick: a sequence containing several trajectories is split into several shorter sequences that each contain only one trajectory. This is accomplished by cutting the input sequence at every done flag and padding the resulting sub-sequences to a common length. This can be visualized as:

Original sequences: [ [a1, a2, a3, a4 | a5, a6],
                      [b1, b2 | b3, b4, b5 | b6] ]

Split sequences:    [ [a1, a2, a3, a4],
                      [a5, a6, 0, 0],
                      [b1, b2, 0, 0],
                      [b3, b4, b5, 0],
                      [b6, 0, 0, 0] ]

With this trick, it is possible to process all trajectories and batches in a single call to the LSTM.
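For illustration, here is a rough sketch of how the split-and-pad step might look in PyTorch. The function name split_and_pad, the tensor shapes, and the zero initial state are assumptions for the example, not the implementation from this issue; in particular, chunks that begin at the start of a rollout but mid-episode would need the carried-over hidden state, and padded timesteps must be masked out of the loss.

# A rough sketch of the split-and-pad idea (names, shapes, and the zero initial
# state are assumptions for illustration, not the implementation from the issue).
import torch
from torch.nn.utils.rnn import pad_sequence

def split_and_pad(features, dones):
    # features: (num_envs, seq_len, hidden_dim); dones: (num_envs, seq_len),
    # where dones[e, t] == 1 means the state is reset before step t (cleanrl's convention)
    chunks = []
    for env_feats, env_dones in zip(features, dones):
        boundaries = (env_dones == 1.0).nonzero(as_tuple=True)[0].tolist()
        start = 0
        for end in [*boundaries, env_feats.shape[0]]:
            if end > start:
                chunks.append(env_feats[start:end])
                start = end
    # pad every trajectory to the same length -> (num_trajs, max_len, hidden_dim)
    return pad_sequence(chunks, batch_first=True)

# A single LSTM call then processes all trajectories at once.
lstm = torch.nn.LSTM(input_size=64, hidden_size=64, batch_first=True)
padded = split_and_pad(torch.randn(4, 128, 64), torch.zeros(4, 128))
out, _ = lstm(padded)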

Proposal

I implemented a version of the script that uses this trick to process sequences with one call. In my setup, it led to a 4x improvement in training speed, at the cost of higher memory usage (about 2x). The final performance of the policy is similar to that of the original script.

Would you be interested in adding this script to the repo? Should I make a PR to create a new file using this trick?

vwxyzjn commented 1 year ago

Thanks @thomasbbrunner. This looks like a really cool trick. CC @araffin. This idea seems to be related to https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53#issuecomment-1093361537.

Would you be interested in comparing this technique in JAX? We are gradually adopting JAX, which is much faster (see https://github.com/vwxyzjn/cleanrl/pull/227#issuecomment-1243112578). The current PPO + LSTM implementation is slow due to the Python for loop, but we might be able to speed it up considerably using JIT and JAX.

In that sense, I would be interested to see the performance difference between doing split sequences and original sequences using JAX — if there is no significant difference when using JIT, then it might not be worth doing this technique.
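As a rough illustration of that point, a jitted per-step loop in JAX might look something like the sketch below using jax.lax.scan. The hand-rolled LSTM cell, parameter layout, and function names are assumptions for the example, not cleanrl code; the point is only that the reset-on-done loop can be fused into a single compiled scan, avoiding Python-level per-step overhead.

# Minimal JAX sketch (not cleanrl's code): a jitted scan over timesteps with
# reset-on-done, replacing the Python for loop. Cell math and names are assumed.
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x):
    # params = {"wi": (in_dim, 4*hidden), "wh": (hidden, 4*hidden), "b": (4*hidden,)}
    h, c = carry
    gates = x @ params["wi"] + h @ params["wh"] + params["b"]
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h

@jax.jit
def scan_lstm(params, init_carry, xs, dones):
    # xs: (seq_len, batch, in_dim); dones: (seq_len, batch)
    def step(carry, inp):
        x_t, d_t = inp
        # zero the recurrent state wherever an episode just ended
        carry = jax.tree_util.tree_map(lambda s: (1.0 - d_t)[:, None] * s, carry)
        return lstm_step(params, carry, x_t)

    final_carry, hs = jax.lax.scan(step, init_carry, (xs, dones))
    return final_carry, hs  # hs: (seq_len, batch, hidden_dim)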

thomasbbrunner commented 1 year ago

Thanks for the interest!

I am not very familiar with JAX. Currently, I don't have the time to look into this, but I am interested in it and I will make time for it in the near future.

vwxyzjn commented 1 year ago

Sounds good, thank you!