instadeepai / flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX
https://instadeepai.github.io/flashbax/
Apache License 2.0
209 stars 10 forks

Needing help understanding trajectory buffer #29

Closed tobiaswuerth closed 4 months ago

tobiaswuerth commented 4 months ago

Similar to #23, I'm not sure exactly how to use this library or how to find the right buffer type for my needs. I've read through the documentation and experimented with the sample colabs.

Looking at the trajectory buffer, what confuses me is that one can specify either max_size or max_length_time_axis, but not both. Depending on which I set, the buffer state shape changes accordingly. But I haven't been able to get it to produce exactly what I need.

For example, let's say I have an environment which returns observations with shape (84, 84, 3), and I want to collect trajectories with at least 40 steps. Additionally, I want to store 100'000 trajectories in my buffer. In the past I set up a buffer with shape (100000, 40, 84, 84, 3). I can achieve this shape by setting add_batch_size=100000 and max_length_time_axis=40, but this requires me to add batches with a leading dimension of 100000, which I obviously don't want, since I might run 100 environments in parallel, but not 100'000. Adding batches with a smaller leading dimension raises an exception: Tree leaf 'obs' has a shape prefix different from expected.
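To illustrate the shape mismatch described above, here is a plain-Python sketch of the prefix check (illustrative only, not Flashbax internals):

```python
# Each pytree leaf passed to the buffer's add must start with the configured
# prefix (add_batch_size, num_timesteps_added, ...). A batch from 100 parallel
# environments does not match a buffer configured with add_batch_size=100000.
expected_prefix = (100000, 1)              # from add_batch_size=100000
actual_batch_shape = (100, 1, 84, 84, 3)   # 100 parallel envs, 1 step each
print(actual_batch_shape[:2] == expected_prefix)  # False -> shape-prefix error
```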

Either I misunderstand something, or this use case cannot be handled. Can you provide me with some insights?

In addition, when handling episodes, not all episodes yield the same number of samples. Some episodes terminate before the minimum required length is reached (say after 13 steps, when trajectories of 40 steps are desired), so the episode should either a) be discarded or b) zero-padded. Furthermore, if an episode terminates on any other step (say 56), this results in 1 full trajectory and 1 partially filled trajectory. Ideally, one could reuse the last steps of the first full trajectory to complete/prefix the 2nd trajectory and thereby make it complete as well. Alternatively, if the environment's max_episode_length is not in the thousands, one could just store all steps along one dimension instead of splitting them into separate trajectories. This would allow flexibly sampling n timesteps from/to any timestep within that one episode (not just 0-40, 41-80, etc., but rather 11-51, 27-67, etc., any 40-step range), with the constraint that samples are not allowed to cross episode boundaries.

Is logic like this built in? I haven't seen anything regarding this.

Thanks for the info!

EdanToledo commented 4 months ago

Hey @tobiaswuerth, thanks for the issue. Let's break down what you're trying to achieve.

You want to sample trajectories of length 40 from a buffer, and you're dealing with variable-length trajectories. Here are your options:

  1. Padding: This needs to be handled by your code. Flashbax does not perform padding as it focuses on auto-reset episodes and rollouts managed by the user with discount/done/terminal variables.
  2. Discarding: Similar to padding, discarding logic must be implemented by you.
  3. Reusing Steps: Flashbax supports this. Data is added contiguously, so additional steps from one trajectory can precede the next. If you set sample_sequence_length = 40 and period=40, you will reuse steps from previous trajectories.
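To make option 3 concrete, here is a small sketch (plain Python, an illustrative model rather than the Flashbax API) of which windows along the time axis are candidates for sampling, given sample_sequence_length and period:

```python
def window_starts(max_length_time_axis, sample_sequence_length, period):
    """Start indices of the candidate sample windows along the time axis
    (illustrative model, not Flashbax internals)."""
    last_start = max_length_time_axis - sample_sequence_length
    return list(range(0, last_start + 1, period))

# With period == sample_sequence_length the windows tile the axis without
# overlap, so trailing steps of one episode share a window with the head
# of whatever episode is written right after it.
print(window_starts(120, 40, 40))  # [0, 40, 80]
```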


Episode Boundaries:

Handling episode boundaries (preventing sampled sequences from crossing them) can be challenging in Flashbax/JAX. The period argument is a potential solution, but masking/padding should be managed by the user. Usually, in the normal RL use case, this is handled by masking out states after a terminal transition is reached.
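One common pattern, sketched here with NumPy (hypothetical helper name, not part of Flashbax): build a validity mask that zeroes everything after the first terminal flag in a sampled sequence:

```python
import numpy as np

def post_terminal_mask(dones):
    """dones: (T,) array of 0/1 terminal flags for one sampled sequence.
    Returns a (T,) mask that is 1 up to and including the first terminal
    step and 0 afterwards, so later (cross-episode) steps are ignored
    when computing losses."""
    after = np.cumsum(dones) - dones  # > 0 strictly after the first done
    return (after == 0).astype(np.float32)

mask = post_terminal_mask(np.array([0, 0, 1, 0, 0]))
print(mask)  # [1. 1. 1. 0. 0.]
```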

tobiaswuerth commented 4 months ago

Thanks @EdanToledo for the prompt response!

Initializing the buffer like this:

import flashbax as fbx

buffer = fbx.make_trajectory_buffer(
    min_length_time_axis=40,
    sample_batch_size=64,
    max_size=40000,
    add_batch_size=128,
    sample_sequence_length=40,
    period=1,
)

with the observation shape described, I get an experience state shape of (128, 312, 84, 84, 3), where 128 is the add_batch_size (e.g. for parallel environments) and 312 is the dynamically calculated time-axis length derived from max_size to reach the specified 40000 steps (128 * 312 = 39936 ≈ 40000). This, I understand.

Assuming I only add one timestep at a time (i.e. (128, 1, 84, 84, 3)), one per environment per step, this would presumably just add sequentially until the 312-step time axis is full and the index restarts at 0. Correct?

Can you explain to me what the period argument does exactly? Assuming I would specify period=10, does this mean that sampled sequences will only start from a multiple of 10? e.g. 10:, 20:, 30: etc. on the time axis?

As far as I can follow this logic, it generally ensures that t+1 comes after t in the buffer, but this does not hold once the end of the buffer is reached, right? It also does not guarantee that no episode-crossing occurs, unless I add extra logic in the algorithm that always adds a fixed number of steps (including padding if needed) in conjunction with period=n, so that only self-contained episodes are retrieved. Working with RNNs, consistency across time is crucial for me.

If all of this is correct, I'm starting to understand the design choice behind how you treat the time axis, because until now it had a different semantic meaning for me.

EdanToledo commented 4 months ago

Assuming I only add one timestep at a time (i.e. (128, 1, 84, 84, 3)), one per environment per step, this would presumably just add sequentially until the 312-step time axis is full and the index restarts at 0. Correct?

Yep!
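As a sanity check of that mental model, the write pattern can be sketched in plain Python (illustrative only, not Flashbax internals):

```python
# Illustrative ring-buffer write pattern: each add writes one timestep per
# parallel environment at write_index, which wraps modulo
# max_length_time_axis once the time axis is full.
max_length_time_axis = 312
write_index = 0
for step in range(700):  # run for more steps than the time axis holds
    # state.experience[:, write_index] = new_timestep  (conceptually)
    write_index = (write_index + 1) % max_length_time_axis
print(write_index)  # 700 % 312 == 76
```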

Can you explain to me what the period argument does exactly? Assuming I would specify period=10, does this mean that sampled sequences will only start from a multiple of 10? e.g. 10:, 20:, 30: etc. on the time axis?

Sure. The period dictates the number of timesteps between possible sampled sequence start points, i.e. the degree of overlap between sampled data. Practically, yes, this means what you state.
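For example (a plain-Python sketch with illustrative numbers, not the library API), with sample_sequence_length=40 and period=10, consecutive candidate windows share 30 timesteps of data:

```python
sample_sequence_length = 40
period = 10
starts = list(range(0, 50, period))        # candidate window starts
overlap = sample_sequence_length - period  # data shared by adjacent windows
print(starts, overlap)  # [0, 10, 20, 30, 40] 30
```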

As far as I can follow this logic, it generally ensures that t+1 comes after t in the buffer, but this does not hold once the end of the buffer is reached, right?

So here is where it becomes a little nuanced: t+1 will always come after t. In the case of a wrap-around, we do not let you sample a sequence of data that would cause a sudden break in temporal consistency. Concretely, you cannot sample a sequence where the starting index plus the sample sequence length, modulo the max time axis length, is greater than the current buffer write index. This means that all sampled sequences always have temporal consistency. If this is not the case, there is a bug in the code, but we have written tests and these aspects perform as expected.
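One reading of that validity rule, sketched as a hypothetical helper (illustrative only; the actual check lives inside Flashbax's sampling code):

```python
def is_temporally_consistent(start, sample_sequence_length,
                             max_length_time_axis, write_index):
    """A window is invalid if it wraps past the end of the time axis AND its
    wrapped end runs beyond the current write index, which would splice the
    newest data onto the oldest (illustrative interpretation)."""
    end = (start + sample_sequence_length) % max_length_time_axis
    wraps = start + sample_sequence_length > max_length_time_axis
    return not (wraps and end > write_index)

print(is_temporally_consistent(50, 20, 100, write_index=10))  # True: no wrap
print(is_temporally_consistent(90, 30, 100, write_index=10))  # False: wraps past write index
```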

It also does not guarantee that no episode-crossing occurs, unless I add extra logic in the algorithm that always adds a fixed number of steps (including padding if needed) in conjunction with period=n, so that only self-contained episodes are retrieved. Working with RNNs, consistency across time is crucial for me.

Regarding episode boundary crossing: yes, there is no guarantee that you won't sample sequences that cross episode boundaries without extra logic. I'm not exactly sure of the nature of the work you are performing with RNNs, but my recommendation would be to look into performing hidden-state resets internally in the architecture based on terminal flags. This allows sampling sequences of any length without caring whether there is an episode boundary, as all the data remains valid. For an example, check the rec_ppo.py file in either Stoix (single-agent) or Mava (multi-agent).
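The hidden-state reset pattern can be sketched like this (toy NumPy code with hypothetical names; see rec_ppo.py in Stoix or Mava for real implementations):

```python
import numpy as np

def rnn_rollout(cell, h0, inputs, dones):
    """Unroll a recurrent cell over a sampled sequence, zeroing the hidden
    state whenever the previous timestep was terminal, so sequences may
    freely cross episode boundaries."""
    h, outputs, prev_done = h0, [], 0.0
    for x, done in zip(inputs, dones):
        h = h * (1.0 - prev_done)  # reset state at the start of a new episode
        h = cell(h, x)
        outputs.append(h)
        prev_done = done
    return np.stack(outputs)

# Toy "cell" that just accumulates inputs, to show where resets land.
cell = lambda h, x: h + x
out = rnn_rollout(cell, np.zeros(1), [np.ones(1)] * 5, [0, 0, 1, 0, 0])
print(out.ravel())  # [1. 2. 3. 1. 2.]
```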

tobiaswuerth commented 4 months ago

Thanks @EdanToledo for taking the time and providing these high quality answers to my questions! Highly appreciated.

I think I understand now, thanks for explaining it to me.

Thank you also for referencing your other project; I think I found the place where you handle RNN hidden-state resets. This certainly looks interesting. I'm just getting into JAX and don't know a lot of the ecosystem yet. Your code is quite extensive, so I'll need more time to dig through it.

The project I'm working on needs to calculate a transformed retrace loss, essentially compounding future rewards for TD-error calculation. Having to deal with potentially multiple episodes per trajectory would make the calculation more involved and the results unreliable, since not every calculation considers the same number of future steps. Or maybe I just don't know how to do it yet.
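A toy illustration of that concern (a plain discounted n-step return with terminal masking, not the actual transformed retrace loss):

```python
import numpy as np

def n_step_return(rewards, dones, gamma=0.99):
    """Discounted sum of future rewards within one sampled sequence,
    truncated at terminal steps so returns never leak across episodes.
    Steps near the end of the sequence see fewer future steps, which is
    the inconsistency described above."""
    G = 0.0
    ret = np.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G * (1.0 - dones[t])
        ret[t] = G
    return ret

print(n_step_return(np.array([1.0, 1.0, 1.0, 1.0]),
                    np.array([0.0, 1.0, 0.0, 0.0]), gamma=1.0))
# [2. 1. 2. 1.]
```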

I'll see what I come up with. Thank you :)