A big reorganization that makes dataset generation much more efficient with brax envs. The upshot: dataset generation for the cart-pole previously took >70s; it now takes <10s.
Also reorganizes how we generate data in a saner way:
Langevin sampling during training is defined in generation.py rather than utils.py
Fixes a PRNG bug that was leading to auto-correlation in the training set (yikes!)
The noise is now annealed at every sampling step, rather than running Langevin in a "staircase" pattern (several steps at each fixed noise level)
The whole observation sequence and total cost are stored in the dataset. This should make visualizations much easier: no need to roll out the control tape again just to see what went on.
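
For reviewers, here is a minimal sketch of the two sampling changes above (a fresh PRNG key per step, and a noise level that drops at every step). It is not the code in generation.py: the names `annealed_langevin` and `toy_score`, the geometric schedule, the `step_size * sigma**2` scaling, and all default values are assumptions made for illustration, and the toy score stands in for a learned score network.

```python
import jax
import jax.numpy as jnp


def toy_score(u, sigma):
    # Hypothetical stand-in for a learned score: the exact score of a
    # standard normal target blurred with noise of scale sigma.
    return -u / (1.0 + sigma**2)


def annealed_langevin(rng, score_fn, u_init, sigma_max=1.0, sigma_min=0.01,
                      num_steps=100, step_size=0.1):
    """Langevin sampling where the noise level is lowered at every step."""
    # One sigma per step (assumed geometric schedule), so there are no
    # plateaus at a fixed noise level, i.e. no "staircase".
    sigmas = jnp.geomspace(sigma_max, sigma_min, num_steps)
    # One independent key per step. Reusing the same key across steps or
    # across samples gives correlated noise.
    step_keys = jax.random.split(rng, num_steps)

    def step(u, inputs):
        sigma, key = inputs
        noise = jax.random.normal(key, u.shape)
        alpha = step_size * sigma**2  # assumed step-size scaling
        u_next = u + alpha * score_fn(u, sigma) + jnp.sqrt(2.0 * alpha) * noise
        return u_next, u_next

    # scan carries the sample forward and stacks every intermediate sample.
    u_final, trajectory = jax.lax.scan(step, u_init, (sigmas, step_keys))
    return u_final, trajectory
```

To draw a batch of independent samples with this sketch, split the top-level key once per sample (for example `jax.random.split(rng, batch_size)`) and `jax.vmap` over the keys, rather than passing the same key to every sample.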