hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Training LSTMs involves lots of data transformation #158

Open ernestum opened 5 years ago

ernestum commented 5 years ago

I looked at how exactly LSTMs are trained with PPO2 and found that a lot of unnecessary data transformations happen:

  1. Trajectories are sampled by the Runner. At the end of its run method, the data is flattened from shape [num_steps, num_envs, x] to [num_steps * num_envs, x] after swapping the first two dimensions:
    arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
  2. In the learn method of PPO2, a hard-to-understand mechanism is used to shuffle the sampled trajectory data without breaking the temporal ordering within each environment's sequence. It relies on a very large flat_indices array:
    flat_indices = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps)
  3. In the optimization step, the data has to be disentangled again by the batch_to_seq function, which reconstructs the shape [n_steps, num_envs, x] so that the LSTM graph can be built (this is exactly the format in which the trajectories were sampled in step 1):
    input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
  4. For further processing, the data is converted back to the flat version using seq_to_batch:
    rnn_output = seq_to_batch(rnn_output)

All of this seems overly complex and potentially slow to me (a sketch of the full round-trip follows below). This is why I would like to open the discussion here on how matters could be improved. Please set your ideas free :-)
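
To make the round-trip concrete, here is a minimal NumPy sketch of the shape transformations in steps 1-4; batch_to_seq and seq_to_batch are simplified stand-ins for the actual TensorFlow helpers, and all names and sizes are illustrative:

    import numpy as np

    num_steps, num_envs, feat = 8, 4, 3

    # Step 1: the Runner samples data with shape [num_steps, num_envs, feat] ...
    rollout = np.random.randn(num_steps, num_envs, feat)
    # ... and flattens it by swapping the first two axes and reshaping.
    flat = rollout.swapaxes(0, 1).reshape(num_steps * num_envs, feat)

    # Step 2: shuffle whole environments via flat_indices, so that indices stay
    # contiguous in time within each environment.
    flat_indices = np.arange(num_envs * num_steps).reshape(num_envs, num_steps)
    env_order = np.random.permutation(num_envs)
    minibatch = flat[flat_indices[env_order].ravel()]

    # Step 3: batch_to_seq (simplified) undoes the flattening so an LSTM can
    # step through time, i.e. back to [num_steps, num_envs, feat].
    sequence = minibatch.reshape(num_envs, num_steps, feat).swapaxes(0, 1)

    # Step 4: seq_to_batch (simplified) flattens the LSTM output once more.
    rnn_output = sequence  # placeholder for the actual LSTM output
    flat_again = rnn_output.swapaxes(0, 1).reshape(num_envs * num_steps, feat)
    assert flat_again.shape == (num_envs * num_steps, feat)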

ernestum commented 5 years ago

My first thought was that the runner should keep the data untouched and we should feed it to the policy in the format [num_steps, num_envs, x], roughly as sketched below.
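
A minimal, illustrative sketch of that idea; the toy RNN cell and the weight names are only placeholders for the actual recurrent policy:

    import numpy as np

    num_steps, num_envs, obs_dim, hidden_dim = 8, 4, 3, 16

    # The Runner keeps the rollout in its native layout ...
    obs = np.random.randn(num_steps, num_envs, obs_dim)  # [num_steps, num_envs, x]

    # ... and the recurrent policy consumes that layout directly by stepping over
    # axis 0 and carrying the hidden state across steps, so no swapaxes/reshape,
    # flat_indices, batch_to_seq or seq_to_batch would be needed in between.
    W_x = np.random.randn(obs_dim, hidden_dim)
    W_h = np.random.randn(hidden_dim, hidden_dim)
    hidden = np.zeros((num_envs, hidden_dim))            # one state per environment
    for step_obs in obs:                                 # step_obs: [num_envs, x]
        hidden = np.tanh(step_obs @ W_x + hidden @ W_h)  # toy cell instead of the LSTM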

What do you think?

araffin commented 5 years ago

Yes, I completely agree that the LSTM code is overcomplicated (that is also the reason I avoid using recurrent policies for now ^^"...). However, I need a bit more time to give you insightful feedback. Ping me again in two weeks if I haven't answered by then ;)

araffin commented 5 years ago

Referencing that PR here: https://github.com/openai/baselines/pull/859