Closed tobiaswuerth closed 4 months ago
Hey @tobiaswuerth, thanks for the issue. Let's break down what you're trying to achieve.
You want to sample trajectories of length 40 from a buffer, and you're dealing with variable-length trajectories. Here are your options:

- With `sample_sequence_length=40` and `period=40`, you will reuse steps from previous trajectories.
- Set `add_batch_size` to 100 and `max_length_time_axis` to 40. This requires adding 100 trajectories in parallel.
- Set `add_batch_size` to 50 and `max_length_time_axis` to 80 to store two consecutive trajectories, reducing the need to add 100 trajectories at once.

Note that `max_size` calculates the required buffer size based on transitions, not trajectories. For example, to store 1000 trajectories of 40 steps using 10 parallel environments, set `max_size = 1000 * 40 = 40,000`. This would dynamically make `max_length_time_axis = 4000`. Does this make sense?

Handling episode boundaries (preventing sampled sequences from crossing them) can be challenging in Flashbax/JAX. The `period` argument is a potential solution, but masking/padding must be managed by the user. In the typical RL use case, this is handled by masking states after a terminal transition is reached.
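To illustrate the masking idea mentioned above, here is a minimal NumPy sketch (user-side logic, not Flashbax API): given the terminal flags of a sampled sequence, build a 0/1 mask that keeps every step up to and including the first terminal and zeroes out everything after it, so later steps don't contribute to the loss.

```python
import numpy as np

def valid_step_mask(dones: np.ndarray) -> np.ndarray:
    """Return a 0/1 mask that is 1 up to and including the first
    terminal step of each sequence, and 0 afterwards.

    dones: (batch, time) array of 0/1 terminal flags.
    """
    # Number of terminals seen *before* each step.
    terminals_before = np.cumsum(dones, axis=1) - dones
    return (terminals_before == 0).astype(np.float32)

dones = np.array([[0, 0, 1, 0, 0]])
print(valid_step_mask(dones))  # [[1. 1. 1. 0. 0.]]
```

Multiplying per-step losses by this mask makes steps after an episode boundary inert, which is usually all that is needed for feed-forward losses.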
Thanks @EdanToledo for the prompt response!
Initializing the buffer like this:
```python
buffer = fbx.make_trajectory_buffer(
    min_length_time_axis=40,
    sample_batch_size=64,
    max_size=40000,
    add_batch_size=128,
    sample_sequence_length=40,
    period=1,
)
```
with the observation shape described, I get a state shape for experiences of `(128, 312, 84, 84, 3)`,
where 128 is the batch size (e.g. for parallel environments) and 312 is the time-axis length dynamically calculated from `max_size` (128 * 312 = 39,936 ≈ 40,000). This, I understand.
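The dynamically derived time-axis length described here can be reproduced with integer division (a sketch of the presumed calculation, not the library's actual internals):

```python
max_size = 40_000
add_batch_size = 128

# The per-row time-axis length is (presumably) chosen so that
# add_batch_size * max_length_time_axis <= max_size.
max_length_time_axis = max_size // add_batch_size
print(max_length_time_axis)                   # 312
print(add_batch_size * max_length_time_axis)  # 39936
```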
Assuming I only add 1 timestep at a time (i.e. `(128, 1, 84, 84, 3)`, one for each environment each step), this would presumably just add it sequentially until the 312-step time axis is reached and the index restarts at 0. Correct?
Can you explain to me what the `period` argument does exactly? If I specify `period=10`, does this mean that sampled sequences will only start from a multiple of 10, e.g. index 10, 20, 30, etc. on the time axis?
As far as I can follow this logic, it generally ensures that t+1 comes after t in the buffer, but this does not hold once the end of the buffer is reached, right? It also does not guarantee that no episode-crossing occurs, unless I add extra logic in the algorithm that always adds a fixed number of steps (including padding if needed) in conjunction with `period=n`, so that only self-contained episodes are retrieved. Working with RNNs, consistency across time is crucial for me.
If all of this is correct, then I start to understand your design choice for the time axis, because until now it had a semantically different meaning for me.
> Assuming I only add 1 timestep at a time (i.e. `(128, 1, 84, 84, 3)`, one for each environment each step), this would presumably just add it sequentially until the 312-step time axis is reached and the index restarts at 0. Correct?
Yep!
> Can you explain to me what the `period` argument does exactly? If I specify `period=10`, does this mean that sampled sequences will only start from a multiple of 10, e.g. index 10, 20, 30, etc. on the time axis?
Sure. The `period` dictates the number of timesteps between possible sample start points, i.e. the degree of overlap between sampled sequences. Practically, this means exactly what you state.
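Concretely, the candidate start indices under a given `period` are just the multiples of the period that leave room for a full sequence (a sketch, not the library's sampling code):

```python
def candidate_starts(max_length_time_axis: int,
                     sample_sequence_length: int,
                     period: int) -> list:
    """Start indices from which a full sequence could be sampled."""
    last_valid_start = max_length_time_axis - sample_sequence_length
    return list(range(0, last_valid_start + 1, period))

print(candidate_starts(100, 40, 10))  # [0, 10, 20, 30, 40, 50, 60]
```

With `period=1` every index is a candidate (maximum overlap between samples); with `period == sample_sequence_length` the sampled sequences are disjoint.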
> As far as I can follow this logic, it generally ensures that t+1 comes after t in the buffer, but this does not hold once the end of the buffer is reached, right?
This is where it becomes a little nuanced: t+1 will always come after t. In the case of a wrap-around, we do not let you sample the sequence of data that would cause a sudden drop in temporal consistency. Concretely, you cannot sample a sequence for which the starting index plus the sample sequence length, modulo the max time-axis length, is greater than the current buffer write index. This means that all sampled sequences always have temporal consistency. If that were not the case, it would be a bug in the code, but we have written tests and these aspects perform as expected.
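My reading of that rule can be written as a small predicate (a sketch under the assumption that `write_index` is the current write position on the time axis; the real implementation may differ in details):

```python
def can_sample_from(start: int, sample_sequence_length: int,
                    max_length_time_axis: int, write_index: int) -> bool:
    """A wrapping sequence is only safe if it ends at or before the
    current write index, i.e. it lands on data written this pass;
    non-wrapping sequences are always temporally consistent."""
    end = (start + sample_sequence_length) % max_length_time_axis
    wraps = start + sample_sequence_length > max_length_time_axis
    return (not wraps) or end <= write_index

print(can_sample_from(300, 40, 312, 50))  # True: wraps to 28 <= 50
print(can_sample_from(300, 40, 312, 10))  # False: wraps to 28 > 10
```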
> It also does not guarantee that no episode-crossing occurs, unless I add extra logic in the algorithm that always adds a fixed number of steps (including padding if needed) in conjunction with `period=n`, so that only self-contained episodes are retrieved. Working with RNNs, consistency across time is crucial for me.
Regarding episode-boundary crossing: yes, without extra logic there is no guarantee that you won't sample sequences that cross episode boundaries. I'm not exactly sure of the nature of the work you are doing with RNNs, but my recommendation would be to look into performing hidden-state resets internally in the architecture based on terminal flags. This allows you to sample sequences of any length without caring about episode boundaries, since all the data remains valid. For an example, check the rec_ppo.py file in either Stoix (single agent) or Mava (multi agent).
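The hidden-state-reset idea can be sketched with a plain NumPy pseudo-RNN (this is not the actual rec_ppo.py code; `cell_fn` and the convention that `dones[t]` flags an episode boundary before step t are illustrative assumptions):

```python
import numpy as np

def rnn_over_sequence(inputs, dones, hidden, cell_fn):
    """Run a recurrent cell over time, zeroing the carried hidden
    state at every episode boundary, so sampled sequences may
    freely cross episode boundaries."""
    outputs = []
    for x_t, done_t in zip(inputs, dones):
        # Reset the carried state at episode starts.
        hidden = np.where(done_t, np.zeros_like(hidden), hidden)
        hidden = cell_fn(hidden, x_t)
        outputs.append(hidden)
    return np.stack(outputs), hidden

# Toy cell that just accumulates inputs: the reset at t=2 is visible.
out, _ = rnn_over_sequence(np.ones((5, 1)),
                           np.array([0, 0, 1, 0, 0]),
                           np.zeros(1),
                           lambda h, x: h + x)
print(out.flatten())  # [1. 2. 1. 2. 3.]
```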
Thanks @EdanToledo for taking the time and providing these high-quality answers to my questions! Highly appreciated.
I think I understand now, thanks for explaining it to me.
Thank you also for referencing your other project; I think I found the place where you handle RNN hidden-state resets. This looks interesting for sure. I'm just getting into JAX and don't know much of the ecosystem yet. Your code is quite extensive, so I'll need more time to dig through it.
The project I'm working on needs to calculate a transformed Retrace loss, essentially compounding future rewards for the TD-error calculation. Having to deal with potentially multiple episodes per trajectory would make the calculation more involved and the results unreliable, since not every calculation considers the same number of future steps. Or maybe I just don't know how to do it yet.
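As a sketch of what "compounding future rewards" means here, a masked n-step discounted return can be computed backwards in one pass (a simplified stand-in for the full transformed Retrace loss, which additionally involves truncated importance weights):

```python
import numpy as np

def discounted_returns(rewards, dones, gamma=0.99):
    """G[t] = sum_k gamma^k * r[t+k], truncated at the first terminal.

    rewards, dones: (time,) arrays; dones marks terminal steps.
    """
    returns = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # A terminal at t cuts off accumulation from t+1 onwards.
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns

r = np.array([1.0, 1.0, 1.0, 1.0])
d = np.array([0.0, 1.0, 0.0, 0.0])  # episode ends at step 1
print(discounted_returns(r, d, gamma=0.5))
```

The `(1.0 - dones[t])` factor is what makes sequences containing a terminal safe to use: steps before the boundary never see rewards from the next episode.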
I'll see what I come up with. Thank you :)
Similarly to #23, I'm not sure exactly how to use this library or find the correct buffer type for my case. I've read through the documentation and experimented with the sample Colabs.
Looking at the trajectory buffer, what confuses me is that one can specify either `max_size` or `max_length_time_axis`, but not both. Depending on which I set, the buffer state shape changes accordingly, but I haven't been able to get it to exactly what I need.

For example, say I have an environment which returns observations with shape `(84, 84, 3)`, and I want to collect trajectories with at least 40 steps. Additionally, I want to store 100,000 trajectories in my buffer. In the past I would set up a buffer with shape `(100000, 40, 84, 84, 3)`. I can achieve this shape by setting `add_batch_size=100000` and `max_length_time_axis=40`, but this requires me to add batches of size 100,000, which I obviously don't want, since I might run 100 environments in parallel but not 100,000. Adding batches with a smaller shape raises an exception: `Tree leaf 'obs' has a shape prefix different from expected`.

Either I misunderstand something or this use case cannot be handled. Can you provide me with some insights?
In addition, when handling episodes, not all episodes yield the same number of samples. Some episodes terminate before the minimum required length is reached (say after 13 steps, when 40-step trajectories are desired), so the episode should either a) be discarded or b) zero-padded. Furthermore, if an episode terminates at any other step count (say 56), this results in one full trajectory and one partially filled trajectory. Ideally one could reuse the last steps of the first full trajectory to prefix and complete the second one. Alternatively, if the environment's `max_episode_length` is not in the thousands, one could just store all steps of an episode along one dimension instead of splitting it into separate trajectories, which would allow flexibly sampling n timesteps from any position within that episode (not just `0-40`, `41-80`, etc., but rather `11-51`, `27-67`, any 40-step range really). This is regarding samples not being allowed to cross episode boundaries.

Is logic like this built in? I haven't seen anything regarding it.
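The flexible sampling described here (any 40-step window that stays inside one episode) can be sketched as user-side logic; as far as I can tell it is not built into the library, and the `dones` layout below is an illustrative assumption:

```python
import numpy as np

def valid_window_starts(dones: np.ndarray, window: int) -> list:
    """All start indices t such that steps t .. t+window-1 contain no
    episode boundary (a terminal at the window's last step is fine,
    since the episode ends exactly there)."""
    n = len(dones)
    starts = []
    for t in range(n - window + 1):
        # A terminal anywhere before the last step would split the window.
        if not np.any(dones[t:t + window - 1]):
            starts.append(t)
    return starts

dones = np.zeros(100)
dones[55] = 1  # one episode ends at step 55
starts = valid_window_starts(dones, 40)
print(starts[:2], starts[-2:])  # [0, 1] [59, 60]
```

Sampling uniformly from `starts` then yields exactly the `11-51`, `27-67` style windows described above, never one that crosses the boundary at step 55.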
Thanks for the info!