ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Environment support #7

Closed: edoust closed this issue 3 weeks ago

edoust commented 1 month ago

Hey, great work on this one!

Using your codebase, I trained on a few jittable environments (mainly gymnax and some self-written ones), and it works quite well. However, it requires some changes to the training script, as well as additional environment wrappers for the (TensorBoard) logging to work.

Do you plan to have full support for a jax-based framework like gymnax or brax? If so, any thoughts on which library to support?

Having support for jittable environments (without Gym wrappers) might make it possible to scan over the whole training loop with jax.lax.scan.
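
For what it's worth, here is a minimal sketch of what I mean, assuming a gymnax environment and a placeholder random policy standing in for the actual TD-MPC2 agent (whose act/update functions are omitted here):

```python
# Minimal sketch (assumptions: gymnax's CartPole-v1 and a random placeholder
# policy instead of the actual TD-MPC2 agent). It shows how a jittable
# environment lets jax.lax.scan drive the rollout loop end to end.
import jax
import gymnax

env, env_params = gymnax.make("CartPole-v1")

def rollout(rng, num_steps=1000):
    rng, reset_key = jax.random.split(rng)
    obs, state = env.reset(reset_key, env_params)

    def step_fn(carry, _):
        rng, obs, state = carry
        rng, action_key, step_key = jax.random.split(rng, 3)
        # Placeholder policy; a real run would query the agent here.
        action = env.action_space(env_params).sample(action_key)
        next_obs, next_state, reward, done, _ = env.step(
            step_key, state, action, env_params)
        return (rng, next_obs, next_state), (obs, action, reward, done)

    _, trajectory = jax.lax.scan(step_fn, (rng, obs, state), None, length=num_steps)
    return trajectory  # stacked (obs, actions, rewards, dones)

trajectory = jax.jit(rollout)(jax.random.PRNGKey(0))
```

Because every step is pure and array-valued, the whole rollout (and, in principle, the agent update) can live inside a single jitted scan instead of a Python loop over gym wrappers.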

ShaneFlandermeyer commented 3 weeks ago

I know we have been discussing this in your pull request (#6), but I would like to summarize things here for posterity.

My goal is to make the TDMPC2 agent completely agnostic to the underlying environment library. We're doing this by giving the user full control over the observation encoder definition and by ensuring that the agent/world model operates only on Jax arrays rather than on higher-level (possibly environment-specific) objects or metadata.
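
To make that concrete, here is a minimal sketch of the kind of encoder a user would define. The names here (StateEncoder, latent_dim) are purely illustrative, and the exact way the module is passed to the agent follows the driver script rather than anything shown in this thread:

```python
# Illustrative user-defined observation encoder as a Flax module. It maps raw
# observation arrays to the latent representation the world model consumes;
# the constructor call wiring it into the agent is omitted because its exact
# signature is not part of this discussion.
import jax
import jax.numpy as jnp
import flax.linen as nn

class StateEncoder(nn.Module):
    latent_dim: int = 512  # hypothetical latent size

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = nn.Dense(256)(obs)
        x = nn.elu(x)
        return nn.Dense(self.latent_dim)(x)

encoder = StateEncoder()
dummy_obs = jnp.zeros((1, 24))  # e.g. a flat proprioceptive observation
params = encoder.init(jax.random.PRNGKey(0), dummy_obs)
latent = encoder.apply(params, dummy_obs)  # shape (1, 512)
```

Since the encoder is just a Flax module over arrays, it doesn't matter whether those observations originally came from gymnasium, gymnax, or anything else.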

On the other hand, the main driver script is not intended to support all environments. We designed it as a minimal example of how to use the library with common environment setups (currently DMControl and gymnasium) that can be easily hacked and extended for whatever use case one may have in mind.

This is just my general design philosophy and nothing is set in stone, so I'm happy to discuss this further if you would like! Otherwise, I will close this issue and merge your PR.