pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Discussion] Remember TorchRL: the state of memory in TorchRL #2325

Open · matteobettini opened 1 month ago

matteobettini commented 1 month ago

Remember TorchRL: the state of memory in TorchRL

Hello! This is a discussion post to recap the state of memory models in TorchRL: what's doable, what's not doable, what is the way to do things, and what is missing.

Why

The goal of this post is to outline what we can improve in the library to give users access to state-of-the-art memory models for tackling partially observable (PO) problems.

Disclaimer

I am not a memory guy. In fact, I cannot even remember any birthdays and have to put them in my calendar. The idea of this post is to stimulate discussion with memory researchers. So if anything I write here is incorrect, please point it out so I can learn and fix it.

The state of memory in RL

Traditional architectures are GRUs and LSTMs (both available in this library). Although GRUs have been shown to achieve good results on PO problems (https://arxiv.org/abs/2303.01859), they have problems related to the way they process the sequence.

In fact, to run LSTMs and GRUs over a whole sequence in a single call, you need a fixed sequence length without any terminations within it. The way this has traditionally been tackled (and is tackled in this library) is to split and pad trajectories to a fixed length that does not contain any dones.

This has some major issues. Most importantly, the sequence length is a hyperparameter that directly affects the memory horizon: a low value prevents remembering things far in the past, while a high value causes heavy padding and inefficiency.
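For concreteness, here is a minimal sketch of the split-and-pad pattern in plain PyTorch (toy shapes and a toy rollout only, not TorchRL's actual code path):

import torch
from torch.nn.utils.rnn import pad_sequence

# Toy example: split a flat rollout at the done flags, pad the resulting
# trajectories to a common length, and run the RNN in one call over time.
obs = torch.randn(12, 4)                      # [T_total, obs_dim]
done = torch.zeros(12, dtype=torch.bool)
done[[2, 7, 11]] = True                       # trajectories of lengths 3, 5, 4
splits = (torch.where(done)[0] + 1).tolist()  # [3, 8, 12]
trajs = torch.tensor_split(obs, splits[:-1], dim=0)
padded = pad_sequence(list(trajs), batch_first=True)  # [n_trajs, max_len, obs_dim]
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
out, _ = lstm(padded)                         # single batched call, padded steps included
# In practice the padded length is fixed up front, and that fixed length is
# exactly the hyperparameter discussed above.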

Recently, a new class of sequence models, sometimes called Linear Recurrent Models or Linear Transformers, has been introduced. These models can be run in parallel over the time dimension with subquadratic space complexity. Examples are S5 and Fast and Forgetful Memory.

Most importantly, these models do not require a fixed sequence length (and thus no padding) and can be run on consecutive terminating trajectories.
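The reason these models parallelise over time is that the recurrence is associative. Below is a minimal sketch for the diagonal/elementwise case (illustrative names only): the for loop is the reference computation that a parallel associative scan (e.g. jax.lax.associative_scan in JAX) replaces with O(log T) depth.

import torch

def combine(e1, e2):
    # e = (a, b) represents the affine map h -> a * h + b.
    # Composing e1 followed by e2 gives another map of the same form,
    # and this operator is associative, which is what a parallel scan exploits.
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

def sequential_scan(a, b, h0):
    # Reference for-loop implementation of h_t = a_t * h_{t-1} + b_t.
    # a, b: [T, D]; h0: [D]. A parallel scan computes the same prefixes
    # with combine() in O(log T) depth instead of T sequential steps.
    hs, h = [], h0
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        hs.append(h)
    return torch.stack(hs)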

To be precise, LSTM and GRU could also be used without padding, but this would mean calling them in a for loop in their "Cell" version, which would lead to more correct results but higher runtime cost (as they could not be batched over time).
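For illustration, a minimal sketch of this cell-version pattern with a plain torch.nn.GRUCell (not TorchRL's GRUModule), resetting the hidden state wherever an is_init flag marks a new episode:

import torch
from torch import nn

gru_cell = nn.GRUCell(input_size=4, hidden_size=8)

def rollout_gru(x, is_init):
    # x: [B, T, 4]; is_init: [B, T, 1] boolean, True at the first step of each episode.
    h = x.new_zeros(x.shape[0], 8)
    outs = []
    for t in range(x.shape[1]):
        h = torch.where(is_init[:, t], torch.zeros_like(h), h)  # reset at episode starts
        h = gru_cell(x[:, t], h)                                # one step, batched over B only
        outs.append(h)
    return torch.stack(outs, dim=1)  # [B, T, 8]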

What is available in TorchRL

Models:

Tensordict modules:

Tutorials:

Replay buffer:

What is doable in TorchRL

What is not available in TorchRL

Implementing Linear Recurrent Models efficiently in torch is currently not possible due to the lack of a Parallel Associative Scan (https://github.com/pytorch/pytorch/issues/95408), a feature that, by contrast, is available in JAX and has enabled progress in RL memory research.

These models could still be implemented in their "cell" version (a for loop instead of a parallel scan). This would be quite inefficient, although it would still lead to better results than traditional architectures.

What are the next steps?

Easy

I believe we should make it easier for users to approach memory models in TorchRL. This could be done by adding a new tutorial that shows how to use the SliceSampler in a memory context (maybe with a GRU, since the other tutorial uses an LSTM).
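Such a tutorial could start from something like the sketch below (a rough sketch only; the SliceSampler keyword arguments are quoted from memory of the current API and should be checked against the docs):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, SliceSampler, TensorDictReplayBuffer

# Store a flat stream of transitions and sample contiguous sub-trajectories (slices) from it.
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10_000),
    sampler=SliceSampler(num_slices=4, end_key=("next", "done")),
    batch_size=32,  # 4 slices of 8 consecutive steps each
)
data = TensorDict(
    {"obs": torch.randn(100, 4), ("next", "done"): torch.zeros(100, 1, dtype=torch.bool)},
    batch_size=[100],
)
data["next", "done"][[24, 49, 74, 99]] = True  # four trajectories of 25 steps each
rb.extend(data)
sample = rb.sample().reshape(4, 8)  # [n_slices, time], ready to feed to a memory model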

It would also be cool to have a tutorial that shows users how to use the current memory models without padding (in their cell version), which is a more correct, although less efficient, implementation.

Medium

We could implement Linear Recurrent Models in their inefficient version, avoiding the use of the Parallel Associative Scan.
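As a concrete illustration of what such an "inefficient version" could look like (this is a generic diagonal linear recurrence, not S5 itself; all names are illustrative):

import torch
from torch import nn

class DiagonalLinearRNN(nn.Module):
    # Minimal sketch of a diagonal linear recurrent layer in its for-loop ("cell") form:
    # h_t = a * h_{t-1} + in_proj(x_t), y_t = out_proj(h_t), with h reset at episode starts.
    def __init__(self, input_size, state_size, output_size):
        super().__init__()
        self.log_a = nn.Parameter(torch.zeros(state_size))  # decay kept in (0, 1) via sigmoid
        self.in_proj = nn.Linear(input_size, state_size, bias=False)
        self.out_proj = nn.Linear(state_size, output_size, bias=False)

    def forward(self, x, is_init, h=None):
        # x: [B, T, input_size]; is_init: [B, T, 1] boolean episode-start flags.
        a = torch.sigmoid(self.log_a)
        if h is None:
            h = x.new_zeros(x.shape[0], self.log_a.shape[0])
        ys = []
        for t in range(x.shape[1]):
            h = torch.where(is_init[:, t], torch.zeros_like(h), h)  # reset hidden state
            h = a * h + self.in_proj(x[:, t])
            ys.append(self.out_proj(h))
        return torch.stack(ys, dim=1), h

Once a parallel associative scan is available, the same recurrence could be computed over the whole sequence at once.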

Hard

Introduce the Parallel Associative Scan in PyTorch (https://github.com/pytorch/pytorch/issues/95408), allowing us to code state-of-the-art memory models and catch up with JAX.

Conclusions

I hope to have given a complete picture of the state of memory in TorchRL. If anything is missing or incorrect, please point it out and I will update the comment.

I am happy to put coding work into a direction once we have identified one. My only limitation is that I would probably not be able to add torch C++ or CUDA kernels without a significant learning experience.

cc @vmoens @albertbou92 @smorad @EdanToledo

vmoens commented 1 month ago

Thanks for posting this!

What kind of speed-up does the associative scan provide for these models in JAX (jax.jit vs torch eager / jax eager)? That would tell us what a satisfactory implementation should look like.

vmoens commented 1 month ago

For instance, with torchrl and tensordict nightly and torch 2.4, I ran the following code

import torch
from torch.utils.benchmark import Timer
from torchrl.modules import GRUCell  # same class as torchrl.modules.tensordict_module.rnn.GRUCell

device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
V = 3        # vector size (extra leading dimension to vmap over)
B = 4        # batch size
N_IN = 32    # input feature size
N_OUT = 200  # hidden size
T = 10       # sequence length

gru_cell = GRUCell(input_size=N_IN, hidden_size=N_OUT, device=device)

def call_gru(x, h):
    h_out = gru_cell(x, h)
    return h_out
batched_call = torch.vmap(call_gru)  # vmap over the leading V dimension

x = torch.randn(V, B, N_IN, device=device)
h0 = torch.zeros(V, B, N_OUT, device=device)
with torch.no_grad():
    h1 = batched_call(x, h0)

input = torch.randn(V, B, T, N_IN, device=device)
h0 = torch.zeros(V, B, N_OUT, device=device)

def time_batched_call_gru(x, h):
    # unroll the vmapped GRUCell over the time dimension
    hs = []
    for x_ in x.unbind(-2):
        h = batched_call(x_, h)
        hs.append(h)
    return torch.cat(hs, -2)

print(Timer("time_batched_call_gru(input, h0).sum().backward()", globals=globals()).adaptive_autorange())

def time_call_gru(x, h):
    # unroll the plain GRUCell over time; vmap is applied to the whole unrolled function below
    hs = []
    for x_ in x.unbind(-2):
        h = gru_cell(x_, h)
        hs.append(h)
    return torch.cat(hs, -2)

batched_time_call_gru = torch.vmap(time_call_gru)

print(Timer("batched_time_call_gru(input, h0).sum().backward()", globals=globals()).adaptive_autorange())

time_batched_call_gru_c = torch.compile(time_batched_call_gru, mode="reduce-overhead", fullgraph=True)
# time_batched_call_gru_c = torch.compile(time_batched_call_gru, fullgraph=True)
time_batched_call_gru_c(input, h0)
time_batched_call_gru_c(input, h0)
print(Timer("time_batched_call_gru_c(input, h0).sum().backward()", globals=globals()).adaptive_autorange())

batched_time_call_gru_c = torch.compile(batched_time_call_gru, mode="reduce-overhead", fullgraph=True)
# batched_time_call_gru_c = torch.compile(batched_time_call_gru, fullgraph=True)
batched_time_call_gru_c(input, h0)
batched_time_call_gru_c(input, h0)
print(Timer("batched_time_call_gru_c(input, h0).sum().backward()", globals=globals()).adaptive_autorange())

and get these very nice results with compile (about 6x faster than eager on an H100 machine):

<torch.utils.benchmark.utils.common.Measurement object at 0x7fd95d72ca30>
time_batched_call_gru(input, h0).sum().backward()
  Median: 5.48 ms
  IQR:    0.05 ms (5.46 to 5.51)
  4 measurements, 10 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd95d72d810>
batched_time_call_gru(input, h0).sum().backward()
  Median: 5.79 ms
  IQR:    0.07 ms (5.76 to 5.83)
  4 measurements, 10 runs per measurement, 1 thread
/home/vmoens/.conda/envs/torchrl/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:150: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd9235aca60>
time_batched_call_gru_c(input, h0).sum().backward()
  Median: 861.05 us
  IQR:    68.60 us (834.74 to 903.35)
  9 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd8f6f1c3d0>
batched_time_call_gru_c(input, h0).sum().backward()
  Median: 968.07 us
  IQR:    84.75 us (945.94 to 1030.69)
  11 measurements, 1 runs per measurement, 1 thread
matteobettini commented 1 month ago

Thanks @vmoens, it is nice to see this speed-up. Just to understand, what are we vmapping over? I imagine we cannot vmap over time, as later elements in the for loop depend on earlier ones, and if I understand correctly the normal GRUCell can already deal with a batch size. Also, what we need to benchmark is the scaling in the time dimension in the presence of resets (scaling to higher T, where the for loop will start to show its limitations).

In any case, this for-loop implementation is how I would implement the models. The idea is that in JAX the associative scan leverages the associative property of the recurrent computation, and thus it should be way faster than a for loop.

I agree we need to benchmark this. One thing I could do is implement a recurrent model (e.g., S5) with the for loop above and compare it with the jax counterpart.

@smorad if you can help at all in the benchmarking of a torch for vs a jax parallel associative scan that would help a lot, otherwise I'll try to do my best

vmoens commented 1 month ago

The vmap is done over the batch (first dimension). In fact, LSTM isn't batched over time; it's just a C++ / cuDNN kernel that runs faster than a for loop.

I don't need a very precise benchmark for now. I just need a good description of the model and a rough idea of the speed-up from an eager for loop to jit + associative scan.

smorad commented 1 month ago

Thanks for posting this!

What kind of speed-up does the associative scan provide for these models in JAX (jax.jit vs torch eager / jax eager)? That would tell us what a satisfactory implementation should look like.

Runtime becomes $O(\log n)$ instead of $O(n)$ for a sequence of length $n$ (given that the GPU has $n$ parallel threads).

smorad commented 1 month ago

From https://openreview.net/forum?id=KTfAtro6vP, which uses torch <2.0 and garbage 2080TIs. Your results with an updated version of torch and better GPUs probably look significantly different. But maybe ~1 order of magnitude faster?

[screenshot of benchmark results from the paper]
albertbou92 commented 1 month ago

I am happy to help too.

I agree that a SliceSampler tutorial would be nice.

As I understand it, the current DQN tutorial stores unpadded, fixed-length trajectories that can contain done flags. Then, when the LSTMModule processes them, it splits and pads the data into separate trajectories so they can be processed in a single call. The main problem with the tutorial's approach is that the stored trajectories are fixed. This is not ideal, as we would want more diversity in the sampled trajectories.

I am happy to take a shot at a SliceSampler tutorial that solves the previous problem. I haven't used it before, so it's a good opportunity to try it out.

To me it also makes sense to provide and explain a script showing the LSTM/GRU cell speed-up with torch.compile, if possible.

I have no experience with these latest memory models that do not require a fixed sequence length, so I guess I need to catch up a bit in that direction.

It would be super interesting to train on the same problem with the different approaches and benchmark both speed and performance.

matteobettini commented 1 month ago

It would also be nice to have modules that do not split and pad internally, but instead process the sequence in a for loop, reading the is_init flag. I have implemented a GRU like this for BenchMARL (https://github.com/facebookresearch/BenchMARL/blob/6ee79436f268e65488ee0ce8b3152174afd7c029/benchmarl/models/gru.py#L30-L63).

albertbou92 commented 1 month ago

What do you mean when you say that this approach leads to more correct results? With padded sequences, is the final agent performance worse, or is it just that processing padded time steps is inefficient?

Regarding the code, I have another question, but I might be wrong. For long sequences, wouldn't it be faster to reshape your [B, T] to [B * T], split it into consecutive trajectories, pass them sequentially through a GRU module, and finally reshape back? Since it runs the for loop in C++ / cuDNN, as Vincent mentioned.

smorad commented 1 month ago

For long sequences, wouldn't it be faster to reshape your [B, T] to [B * T], split it into consecutive trajectories, pass them sequentially through a GRU module, and finally reshape back? Since it runs the for loop in C++ / cuDNN, as Vincent mentioned.

Unlikely, because a GRU is $O(n)$. In this case, you are choosing between $O(T)$ and $O(B T)$ time complexity. In fact, this is one of the reasons transformers were developed -- long sequences are costly with an LSTM/GRU.

matteobettini commented 1 month ago

What do you mean when you say that this approach leads to more correct results? With padded sequences, is the final agent performance worse, or is it just that processing padded time steps is inefficient?

There is a nice explanation of the problems with padding here: https://arxiv.org/abs/2402.09900

albertbou92 commented 1 month ago

What do you mean when you say that this approach leads to more correct results? With padded sequences, is the final agent performance worse, or is it just that processing padded time steps is inefficient?

There is a nice explanation of the problems with padding here: https://arxiv.org/abs/2402.09900

very nice paper!

matteobettini commented 1 month ago

OK, as a general comment from this discussion: when I have time, I'll try to implement S5 for TorchRL. Then, if and when the associative scan lands, we can rewrite it with that.

albertbou92 commented 1 month ago

And what about your current implementation of the GRU with the for loop? Is it practical to run in terms of time? Can it be accelerated with compile?

matteobettini commented 1 month ago

It is basically like your implementation, with the difference that is_init is passed to the forward and, in the time for loop, you check it to see if you need to reset the hidden states to 0. You can see it in the BenchMARL models.

Compile totally fails for me, as I then need to vmap the model multiple times and compile complains with some messages I don't understand, so I gave up. The basic module is compilable (similar to yours), but I only got slowdowns after compilation.