keraJLi / rejax

Apache License 2.0

Workflow for using this library on other environments? #6

Closed: raymondchua closed this issue 5 months ago

raymondchua commented 5 months ago

Hi, what is the workflow for using this library with environments that are not in the config?

I guess we would need to make a wrapper for the environment?

keraJLi commented 5 months ago

Hi there! I'm not exactly sure what you mean, but I'll try to answer you as best as possible.

  1. If your environment exists in Gymnax: Simply pass the environment name to any config class, e.g. config = PPOConfig.create(env="Breakout-MinAtar")

  2. If your environment does not exist in Gymnax, but adheres to its interface (i.e. implements step(self, key, state, action, params), etc.) you can pass an instance of the environment to the config class, e.g.

    custom_env = YourCustomGymnaxEnv(...)
    config = PPOConfig.create(env=custom_env)
  3. If your environment is written in pure JAX (all its functions are jittable) but does not adhere to the Gymnax interface, you should create a wrapper around it. I have written such a wrapper for Brax environments; you can find it in rejax/brax2gymnax.py. You can then pass the wrapped environment as in 2.

  4. If your environment is not written in pure JAX, the wrapper becomes slightly more complicated, since you would have to wrap all of its functions in JAX callbacks (jax.pure_callback or jax.experimental.io_callback, where applicable). Note that this (kind of) defeats the purpose of using Rejax, since the callbacks will be called sequentially, not taking advantage of the parallelism that vmap and pmap provide.
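To make case 3 more concrete, here is a minimal sketch of such a wrapper, not the code in rejax/brax2gymnax.py. It assumes a hypothetical pure-JAX environment (CounterEnv with init, transition and observe methods, all invented for illustration) and adapts it to Gymnax-style reset(key, params) and step(key, state, action, params) signatures. It also mirrors Gymnax's convention of auto-resetting finished episodes and returning (obs, state, reward, done, info) from step. A wrapper that Rejax can actually consume would most likely also need the action/observation spaces and default parameters that Gymnax environments expose; those are left out here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CounterState(NamedTuple):
    # Hypothetical state of the toy environment below.
    pos: jnp.ndarray  # integer position on a line
    t: jnp.ndarray  # number of steps taken in the current episode


class CounterEnv:
    """Hypothetical pure-JAX environment that does NOT follow the Gymnax interface."""

    max_steps = 10

    def init(self, key):
        pos = jax.random.randint(key, (), 0, 5)
        return CounterState(pos=pos, t=jnp.array(0, dtype=jnp.int32))

    def transition(self, state, action):
        # Action 1 moves right, anything else moves left; reaching 0 ends the episode.
        pos = state.pos + jnp.where(action == 1, 1, -1)
        new_state = CounterState(pos=pos, t=state.t + 1)
        reward = (pos == 0).astype(jnp.float32)
        done = jnp.logical_or(pos == 0, new_state.t >= self.max_steps)
        return new_state, reward, done

    def observe(self, state):
        return jnp.stack([state.pos, state.t]).astype(jnp.float32)


class GymnaxStyleWrapper:
    """Exposes Gymnax-style reset/step signatures on top of CounterEnv (sketch only)."""

    def __init__(self, env):
        self.env = env

    def reset(self, key, params=None):
        state = self.env.init(key)
        return self.env.observe(state), state

    def step(self, key, state, action, params=None):
        next_state, reward, done = self.env.transition(state, action)
        # Auto-reset finished episodes: take the freshly reset state wherever
        # done is True, otherwise keep the stepped state.
        _, reset_state = self.reset(key, params)
        next_state = jax.tree_util.tree_map(
            lambda r, s: jnp.where(done, r, s), reset_state, next_state
        )
        return self.env.observe(next_state), next_state, reward, done, {}


# Usage: every method is pure JAX, so a batch of environments can be vmapped and jitted.
env = GymnaxStyleWrapper(CounterEnv())
reset_key, step_key = jax.random.split(jax.random.PRNGKey(0))
num_envs = 8
obs, state = jax.vmap(env.reset)(jax.random.split(reset_key, num_envs))
actions = jnp.zeros(num_envs, dtype=jnp.int32)  # always move left
batched_step = jax.jit(jax.vmap(env.step))
obs, state, reward, done, info = batched_step(
    jax.random.split(step_key, num_envs), state, actions
)
```

Because nothing in the wrapper leaves JAX, jax.vmap and jax.jit apply to it directly, which is the parallelism that case 4 gives up.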
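For case 4, a minimal sketch of the callback approach. Here numpy_step is a hypothetical environment step written in plain NumPy, and jax.pure_callback lets jitted code (a jax.lax.scan rollout in this example) call it by declaring the shapes and dtypes of its outputs up front. This version passes the state explicitly so the callback stays pure. For an environment that mutates internal state, jax.experimental.io_callback (for example with ordered=True) is likely the more appropriate of the two callbacks mentioned in case 4.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_step(pos, action):
    """Hypothetical environment step in plain NumPy; JAX cannot trace this code."""
    new_pos = np.asarray(pos + (1 if int(action) == 1 else -1), dtype=np.int32)
    reward = np.asarray(1.0 if new_pos == 0 else 0.0, dtype=np.float32)
    done = np.asarray(new_pos == 0, dtype=np.bool_)
    return new_pos, reward, done


# pure_callback needs the shapes/dtypes of the callback's outputs in advance.
STEP_RESULT_SHAPES = (
    jax.ShapeDtypeStruct((), jnp.int32),  # new position
    jax.ShapeDtypeStruct((), jnp.float32),  # reward
    jax.ShapeDtypeStruct((), jnp.bool_),  # done flag
)


def jax_step(pos, action):
    # Leaves the compiled program, runs numpy_step on the host, and returns arrays.
    return jax.pure_callback(numpy_step, STEP_RESULT_SHAPES, pos, action)


@jax.jit
def rollout(pos0, actions):
    def body(pos, action):
        new_pos, reward, done = jax_step(pos, action)
        return new_pos, (reward, done)

    return jax.lax.scan(body, pos0, actions)


# Usage: start at position 3 and always move left (action 0) for five steps.
final_pos, (rewards, dones) = rollout(jnp.int32(3), jnp.zeros(5, dtype=jnp.int32))
```

Every step round-trips to the host, so under jax.vmap the callback would typically be invoked once per batch element unless the host code is itself vectorized (recent JAX versions ask you to choose this through the vmap_method argument of jax.pure_callback). That is the sequential bottleneck described in case 4.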

I hope this helps! Let me know if you have any more questions.

raymondchua commented 5 months ago

Awesome, thanks for your reply @keraJLi. That is the information I was looking for! I am using Craftax, which I think falls under case 2.

keraJLi commented 5 months ago

I'm glad I could help!