Vectorize the rollout loop. With GNS-10-64 on 2D RPF, I could fit batch_size=10 on a 48GB GPU at little additional cost per step. Larger batches also fit, but beyond that point the time per step grows roughly linearly with batch size.
This should not break any existing functionality, since the change only vectorizes the loop: batch_size=1 reproduces the previous behavior.
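The sketch below illustrates, under stated assumptions, what "vectorizing the rollout loop" can look like in JAX. It is not the code from this PR. `step_fn` is a hypothetical toy stand-in for one learned-simulator step (the real GNS forward pass and integrator are not shown), and `rollout` / `batched_rollout` are illustrative names. The unbatched rollout is a `jax.lax.scan` over time steps, and `jax.vmap` maps it over a leading batch axis of initial conditions, so batch_size=1 gives the same trajectory as the unbatched loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for one simulator step (NOT the repository's GNS model):
# maps positions of shape (num_particles, dim) to the next positions.
def step_fn(params, positions):
    return positions + params["dt"] * jnp.tanh(positions)

def rollout(params, initial_positions, num_steps):
    """Unbatched rollout: (N, dim) -> trajectory of shape (num_steps, N, dim)."""
    def body(pos, _):
        next_pos = step_fn(params, pos)
        return next_pos, next_pos  # (new carry, per-step output)
    _, traj = jax.lax.scan(body, initial_positions, None, length=num_steps)
    return traj

def batched_rollout(params, initial_positions_batch, num_steps):
    """Vectorized rollout: (B, N, dim) -> (B, num_steps, N, dim).

    params are shared across the batch; only the initial conditions are mapped.
    """
    return jax.vmap(lambda x0: rollout(params, x0, num_steps))(initial_positions_batch)

# num_steps sets the scan length, so it must be static under jit.
batched_rollout_jit = jax.jit(batched_rollout, static_argnums=2)

params = {"dt": 0.01}
x0 = jax.random.normal(jax.random.PRNGKey(0), (10, 64, 2))  # batch_size=10, 64 particles, 2D
trajs = batched_rollout_jit(params, x0, 100)
print(trajs.shape)  # (10, 100, 64, 2)
```

Because `vmap` only adds a batch axis, slicing `x0[:1]` and taking the single output trajectory matches `rollout(params, x0[0], 100)` up to floating-point tolerance.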
EDIT (05.01.2024)
Running batch_size=1 takes 22 ms/step on the above setup (averaged over 100 batched rollouts of length 100, without reallocation).
Running batch_size=10 takes 28 ms/step.
=> roughly 8x speedup per generated trajectory on that setup.
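The "roughly 8x" figure follows from dividing the batched step time by the batch size: one batched step advances 10 trajectories, so each trajectory costs 28 / 10 = 2.8 ms/step versus 22 ms/step unbatched. A small helper (hypothetical name, for illustration only) makes the calculation explicit:

```python
def per_trajectory_speedup(ms_per_step_single, ms_per_step_batched, batch_size):
    """Speedup in wall-clock time per generated trajectory step."""
    # One batched step advances batch_size trajectories at once.
    per_traj_batched = ms_per_step_batched / batch_size
    return ms_per_step_single / per_traj_batched

speedup = per_trajectory_speedup(22.0, 28.0, 10)
print(f"{speedup:.2f}x")  # 7.86x
```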