Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

PPORecurrent mini batch size inconsistent #113

Open b-vm opened 2 years ago

b-vm commented 2 years ago

I am using PPORecurrent with the RecurrentDictRolloutBuffer. In https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/103 it is mentioned that the batch size is intended to be constant. However, this does not seem to be the case.

I did some experiments where I printed the action batch sizes coming from the _get_samples() function.

Exp 1: batch_size < sequence length

I initialize PPORecurrent with batch_size = 10 and sample only 2000 steps for debugging purposes. The mean sequence length is 32.3.

[screenshot of the printed batch sizes]

It looks like whenever n_seq = 2 and two different sequences are appended, the batch size is higher than the specified 10.

Exp 2: batch_size > sequence length

I initialize PPORecurrent with batch_size = 100. The mean sequence length this time is 31.7.

[screenshot of the printed batch sizes]

The batch size is now always higher than the specified 100, and different every time.

Is this the intended behavior? It seems wrong to me since it is stated explicitly in the aforementioned issue that batch size is intended to be constant:

Actually no, the main reason is that you want to keep a mini batch size constant (otherwise you will need to adjust the learning rate for instance).

Also, while we are at it, what is the reason for implementing mini batching, instead of feeding batches of whole sequences to the model?


araffin commented 2 years ago

Hello,

The batch size is now always higher than the specified 100, and different every time.

my guess is that it is because of padding/masking.

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/a9735b9f317be4283e56d221e19087b926ca9ec0/sb3_contrib/ppo_recurrent/ppo_recurrent.py#L369

rollout_data.observations[mask].shape should be constant and equal to the desired value.
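For instance, a quick sanity check (a hypothetical debug print inside RecurrentPPO's training loop, assuming rollout_data comes from the recurrent rollout buffer's get(batch_size)) would look like this:

# inside the training loop, right where the padding mask is built
mask = rollout_data.mask > 1e-8
print(rollout_data.actions.shape)        # padded size, varies from batch to batch
print(rollout_data.actions[mask].shape)  # first dimension should always equal batch_size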

what is the reason for implementing mini batching, instead of feeding batches of whole sequences to the model?

If you feed whole sequences, then the mini batch size is not constant, which makes it harder to tune hyperparameters. Or, as done in OpenAI Baselines, it imposes some restrictions on the number of steps, minibatches and environments.

If you feed sequences, you also don't know in advance how much memory you need for the storage, which makes it less efficient (related to https://github.com/DLR-RM/stable-baselines3/issues/1059), as you will need to use lists and not fixed-size numpy arrays.
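To illustrate the storage point (a toy sketch, not the buffer's actual code): with a fixed rollout length the whole buffer can be preallocated, while whole variable-length sequences cannot.

import numpy as np

n_steps, n_envs, obs_shape = 128, 8, (4,)

# fixed-size preallocation: the memory needed is known before collecting data
buffer = np.zeros((n_steps, n_envs) + obs_shape, dtype=np.float32)

# storing whole episodes instead: their lengths are unknown up front,
# so you would need growing Python lists of variable-length arrays
episodes = []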

b-vm commented 2 years ago

Thanks for the answers! I checked and you are right, the masking always yields the correctly sized obs/act. I am still a bit confused about the purpose of padding in this case, as the data is fed sequentially into the model (seq_length * obs_size), instead of in 3D batches like (batch_size * seq_length * obs_size). Is there any resource I can read through to learn more about this?

araffin commented 2 years ago

the purpose of padding in this case, as the data is fed sequentially into the model (seq_length * obs_size)

What you want to feed is (batch_size, obs_shape) sequentially: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/a9735b9f317be4283e56d221e19087b926ca9ec0/sb3_contrib/common/recurrent/buffers.py#L231

So, if arrays could be of any size, then you would feed (n_seq, seq_length, obs_shape), where seq_length varies from one sequence to another. However, it is not possible (and not efficient if you replace numpy/torch array with list) to have an axis with variable size, so we padded each sequence in order to have a fixed size (max_seq_length).
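As a toy illustration of the padding (not the buffer's actual code):

import torch as th
from torch.nn.utils.rnn import pad_sequence

# three episodes of different lengths, obs_shape = (4,)
sequences = [th.randn(5, 4), th.randn(3, 4), th.randn(8, 4)]

# padding gives a fixed-size (max_seq_length, n_seq, *obs_shape) tensor
padded = pad_sequence(sequences)                           # shape: (8, 3, 4)
# mask marking the real (non-padded) timesteps
mask = pad_sequence([th.ones(len(s)) for s in sequences])  # shape: (8, 3)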

This shape is needed here: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/a9735b9f317be4283e56d221e19087b926ca9ec0/sb3_contrib/common/recurrent/policies.py#L177-L179

Is there any resource I can read through to learn more about this?

Not really... but you can compare that implementation to the SB2 implementation: https://github.com/hill-a/stable-baselines/blob/14630fbac70aaa633f8a331c8efac253d3ed6115/stable_baselines/ppo2/ppo2.py#L368-L380

In SB2, there is no padding but there is a constraint on the number of steps and number of minibatches, so that the sequence length is always the same.
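As an illustration of that constraint (illustrative numbers, not actual SB2 code): each minibatch is made of whole environment trajectories, so the sequence length is fixed by construction.

n_envs, n_steps, n_minibatches = 8, 128, 4
assert n_envs % n_minibatches == 0
envs_per_batch = n_envs // n_minibatches  # 2 whole sequences per minibatch
seq_length = n_steps                      # every sequence has the same length (128)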

b-vm commented 1 year ago

Thanks for your elaborate reply. It makes much more sense now.

Although my models are training very well with SB3, I was still running into problems where it was taking way too long to train. PPORecurrent with an LSTM was running at only 104 fps on a 24 core machine, against normal PPO with an MLP at around 1000 fps on the same machine. Some profiling showed that the majority of compute time was spent on the backward pass. I was only able to speed it up to 181 fps by moving to a GPU machine (12 core + V100), which is still too slow for my project.

So, in an effort to speed up training, I forked sb3-contrib and implemented a version that trains on batches of whole sequences by feeding an (n_seq, seq_length, obs_shape) batch directly to the Torch API, instead of using the Python for loop in _process_sequence.
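A minimal sketch of the idea (illustrative only, not the fork's actual code): run the LSTM once over whole padded sequences instead of stepping through time in Python.

import torch as th
import torch.nn as nn

# toy shapes for illustration
n_seq, max_seq_length, obs_dim = 8, 50, 12
lstm = nn.LSTM(input_size=obs_dim, hidden_size=64)

# padded_obs: (max_seq_length, n_seq, obs_dim), zero-padded to the longest episode
padded_obs = th.zeros(max_seq_length, n_seq, obs_dim)
features, _ = lstm(padded_obs)  # a single call over all timesteps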

This made the backward pass 4 times faster: on my laptop GPU it went from 20 s to 5 s. I then ran this in a few configurations on our cluster to find out what the real speedup is in terms of fps, and to confirm that the models actually are able to train well. It achieved around 500 fps on the same 12 core + V100 machine, about 2.8 times faster than the 181 fps before. The models also achieved the same rewards with similar trajectories. Results of two 24 h runs are shown below; green is standard sb3-contrib with batch size 512, blue is my fork of sb3-contrib with batches of 8 whole sequences.

[reward curves for the two 24 h runs]

However, it is not possible (and not efficient if you replace numpy/torch array with list) to have an axis with variable size, so we padded each sequence in order to have a fixed size (max_seq_length).

I was able to do this without resorting to lists, keeping the same setup for the rollout buffer as in sb3. It is just a matter of indexing. The issue regarding variable sequence lengths in a batch is solved by padding the sequences so the tensor becomes a cube again. Below is the code I used to generate the batches:

# yields batches of whole padded sequences; tensors are swapped to time-major
# (sequence_length, batch_size, data_length) before being yielded
# assumes `import torch as th` and `from torch.nn.utils.rnn import pad_sequence`
for indices in batch_sampler:
    # slice the flat rollout data back into whole episodes and pad them to the same length
    def pad_episodes(data):
        return pad_sequence(
            [
                th.tensor(data[self.episode_start_indices[i]:self.episode_start_indices[i + 1]], device=self.device)
                for i in indices
            ],
            batch_first=True,
        )

    obs_batch = {key: pad_episodes(self.observations[key]) for key in self.observations}
    actions_batch = pad_episodes(self.actions)
    old_values_batch = pad_episodes(self.values)
    old_log_probs_batch = pad_episodes(self.log_probs)
    advantages_batch = pad_episodes(self.advantages)
    returns_sequences = [
        th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i + 1]], device=self.device)
        for i in indices
    ]
    returns_batch = pad_sequence(returns_sequences, batch_first=True)
    # masks: ones over the real timesteps of each episode, zeros over the padding
    masks_batch = pad_sequence([th.ones_like(r) for r in returns_sequences], batch_first=True)

    yield RecurrentDictRolloutBufferSequenceSamples(
        observations={key: th.swapaxes(obs_batch[key], 0, 1) for key in obs_batch},
        actions=th.swapaxes(actions_batch, 0, 1),
        old_values=th.swapaxes(old_values_batch, 0, 1),
        old_log_prob=th.swapaxes(old_log_probs_batch, 0, 1),
        advantages=th.swapaxes(advantages_batch, 0, 1),
        returns=th.swapaxes(returns_batch, 0, 1),
        masks=th.swapaxes(masks_batch, 0, 1),
    )

If you are interested in seeing more about this, here is my fork. The main changes to look at are in buffers.py for the RecurrentSequenceDictRolloutBuffer, and any function with "whole_sequence" in the name. Also, I can run more benchmarks if you want.

It's a crude implementation just to test how well it works, so I only implemented the functionality I need for my project. If you're interested in this method, I could do a proper implementation and submit a pull request for it. From what I am seeing, it is not only a lot faster, but the code is also much simpler.

araffin commented 1 year ago

with an LSTM was running at only 104 fps on a 24 core machine, against normal PPO with an MLP at around 1000 fps on the same machine.

This is expected. That's why you should try PPO with framestack first as we recommend in the doc. See also the small study PPO LSTM vs PPO framestack: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4

batch directly to the Torch API, instead of using the Python for loop in _process_sequence.

you mean using: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/c75ad7dd58b7634e48c9e345fca8ebb06af3495e/sb3_contrib/common/recurrent/policies.py#L182-L187 I guess.

If you do not take the done signal into account, you are not treating sequences properly (you are treating n_steps from one env as a big sequence, no matter how many episodes were in there).

confirm that the models actually are able to train well.

Good to hear. Did you try on envs from https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 ? It sounds like you might not need PPO LSTM.

b-vm commented 1 year ago

That's why you should try PPO with framestack first as we recommend in the doc.

I have indeed read that; however, I am expanding on prior research, so I am limited to LSTMs.

you mean using: stable-baselines3-contrib/sb3_contrib/common/recurrent/policies.py Lines 182 to 187 in c75ad7d I guess.

Yes, you are right; that would probably also work very well in terms of speed. However, that code will never be called when seq_length < batch_size.

If you do not take the done signal into account, you are not treating sequences properly (you are treating n_steps from one env as a big sequence, no matter how many episodes were in there).

I am using the done signal to cut the rollout buffer data into the original sequences, which are then batched.
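A toy sketch of that cutting step for a flattened, single-env buffer (illustrative only, mirroring the episode_start_indices used in the code above):

import numpy as np

# 1 marks the first step of each episode in the flattened rollout data
episode_starts = np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])
episode_start_indices = np.flatnonzero(episode_starts)  # -> array([0, 4, 7])
# episode i spans data[episode_start_indices[i] : episode_start_indices[i + 1]]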

did you try on envs from

Yes, I tried on gym BipedalWalker v3, which went from 752 fps to 1992 fps, a 2.6 times speedup:

[training curve plot]

Orange is standard PPORecurrent and red is the 3D batching version. The results look quite similar, but I didn't spend time on tuning hyperparameters for the 3D batching version. The code for running this test can be found here.

It looks like 3D batching is between 2.5 and 3 times faster than what is currently implemented in sb3_contrib, while keeping the same learning behaviour. If you're interested I can run more tests with different params/different envs.

araffin commented 1 year ago

Yes, I tried on gym BipedalWalker v3, which went from 752 fps to 1992 fps, a 2.6 times speedup: If you're interested I can run more tests with different params/different envs.

I meant on the envs from the benchmark (https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4) that actually require recurrence to work (because speed was removed from the observation).

They are defined in the RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/import_envs.py#L47-L62

BipedalWalker v3,

Mmh, BipedalWalker v3 doesn't require recurrence to be solved normally, and -90 is far from the optimal performance (around 300).

I am using the done signal to cut the rollout buffer data into the original sequences, which are then batched.

I need to take a look at that in more detail. The speedup is definitely interesting, but I'm concerned about the correctness of the implementation.

b-vm commented 1 year ago

Fair points. I reran the test on BipedalWalker-v3 with the proper hyperparams (from sb3 zoo), and also ran a test on PendulumNoVel-v1. Here are the results:

BipedalWalker-v3:

[training curves] Orange is sb3_contrib and blue is my fork. Final fps values are 1098 vs 6001 respectively, so about 5.5 times faster with similar performance.

PendulumNoVel-v1:

[training curves] Red is sb3_contrib; green, pink, and blue are my fork with 2, 4, and 8 sequences per batch respectively. Final fps values are 339, 2397, 2790, and 3086, so anywhere from 7 to 9 times faster with similar performance. The results also match the link you sent.

I can run the other envs in the link too if you want to see those, but these results are pretty convincing to me and match what I am observing in the Mujoco env I am working on. Also you could run them yourself by adding them to this script.

You can find my fork here if you want to take a closer look. So far I have implemented it for the Box and Dict observation spaces and the Box action space; other spaces probably don't work yet.

araffin commented 1 year ago

thanks for the additional results =), could you open a draft PR? (that would make it easier for me to review the code)

b-vm commented 1 year ago

Of course! Here it is: #118 Let me know if you want to see any changes.

araffin commented 1 year ago

Thanks for the PR, I think I finally understand why it works. This is indeed a nice way to accelerate PPO LSTM. I have some remarks though:

b-vm commented 1 year ago

This is indeed a nice way to accelerate PPO LSTM.

Cool! Glad to hear that.

you are probably never sampling the first sequence nor the last one (likely to create issues), see

Good catch. I fixed it so that it also samples the last sequence in the rollout buffer: https://github.com/b-vm/stable-baselines3-contrib/blob/18ace01f01957c8b61c07fe5a195200ac3b7c12b/sb3_contrib/common/recurrent/buffers.py#L441-L443

It should now sample all sequences in the rollout buffer. However, it will drop the final batch in an epoch if n_sequences mod batch_size != 0, to prevent very small batches from causing updates. In the next epoch the batch indices are randomized again, so it is unlikely that the same sequences will end up in that final batch again.

the batch size doesn't have the same meaning anymore, here it corresponds to the number of sequences sampled and it has a variable length (which is not great for the learning rate, but experiments seem to still be working)

Agreed. It would be more correct to adjust the learning rate to whatever the current batch size is. In a sense it is now kind of like a crude form of learning rate decay, because the effect single timesteps have on weight updates decreases (roughly) as a function of episode length.
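A possible adjustment (hypothetical, not part of the fork) would be to scale the learning rate by the number of real, non-padded timesteps in the current batch relative to a nominal batch size:

import torch as th

def scaled_lr(base_lr: float, masks_batch: th.Tensor, nominal_batch_size: int = 512) -> float:
    # masks_batch is 1 for real timesteps and 0 for padding
    actual_batch_size = masks_batch.sum().item()
    return base_lr * actual_batch_size / nominal_batch_size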

Would love to help out integrating if you decide on making this part of the package.

araffin commented 1 year ago

However, it will drop the final batch in an epoch if n_sequences mod batch_size != 0, to prevent very small batches from causing updates.

I'm not sure I understand why the last batch would be much smaller than the rest (because you sample whole sequences every time).

Would love to help out integrating if you decide on making this part of the package.

I think I'm interested in integrating it (but not replacing current sampling). However, some work is still needed to make the implementation cleaner and more integrated with the current code (i.e., remove duplicated code). Once the code is ready, I will need to run additional benchmarks to check the performance vs the current default.

b-vm commented 1 year ago

I'm not sure I understand why the last batch would be much smaller than the rest (because you sample whole sequences every time).

For example if it tries to sample batches of 8 sequences out of a total of 81 sequences, it will result in 10 full batches of 8 sequences, and 1 small batch with only 1 sequence. This is especially a problem when that particular sequence is very short, like only 1 or a few timesteps. There is a nice example here in the PyTorch docs.
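The same behaviour can be sketched with PyTorch's samplers (illustrative; not necessarily the exact sampler setup in the fork):

from torch.utils.data import BatchSampler, SubsetRandomSampler

# 81 sequences sampled in batches of 8: drop_last=True skips the final 1-sequence batch,
# and the indices are reshuffled every epoch
sampler = BatchSampler(SubsetRandomSampler(range(81)), batch_size=8, drop_last=True)
print(sum(1 for _ in sampler))  # 10 batches of 8 sequences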

Also, although it is true that the batching algorithm never cuts up sequences and keeps them whole, some sequences are cut up anyway because they started during a previous rollout, so their first part is simply not present in the rollout buffer.

I think I'm interested in integrating it (but not replacing current sampling).

Cool! In that case, would you want to make it optional to enable this batching method? Or make it a separate algorithm? I agree that it needs work; the current implementation was just to test it out without breaking any other code.

When we decide on a general architecture I can start working on it.

araffin commented 1 year ago

For example if it tries to sample batches of 8 sequences out of a total of 81 sequences, it will result in 10 full batches of 8 sequences, and 1 small batch with only 1 sequence.

yes but sequences are of variable length (so you could sample 1 long sequence vs 3 short sequences).

would you want to make it optional to enable this batching method?

make it optional (at least try, we will make separate algorithms if it becomes too complex). I also need to check the issue with the flatten layer, I'm surprised it worked before.

b-vm commented 1 year ago

I integrated the code better with the existing code and removed the duplicates; it is now about 120 lines smaller. Enabling it is still optional through the whole_sequence flag.

Curious what you think regarding the flatten layer.

It also seems to me that we could simplify the _process_sequence function if we make whole sequences the default. I believe _process_sequence will only have to deal with single step sequences now that the forward passes are dealt with in evaluate_actions_whole_sequence.

beuricav commented 10 months ago

Has no further work been done here? I'm curious to see progress on the main sb3-contrib integration.

araffin commented 10 months ago

Has no further work been done here? I'm curious to see progress on the main sb3-contrib integration.

Please have a look at https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/118

Help and further testing is needed.