Closed raymondchua closed 5 months ago
Hi there! I'm not exactly sure what you mean, but I'll try to answer as best I can.
1. If your environment exists in Gymnax, simply pass the environment name to any config class, e.g.
config = PPOConfig.create(env="Breakout-MinAtar")
2. If your environment does not exist in Gymnax but adheres to its interface (i.e. implements step(self, key, state, action, params), etc.), you can pass an instance of the environment to the config class, e.g.
custom_env = YourCustomGymnaxEnv(...)
config = PPOConfig.create(env=custom_env)
3. If your environment is written in pure JAX (all its functions are jittable) but does not adhere to the Gymnax interface, you should create a wrapper around it. I have written such a wrapper for Brax environments; you can find it in rejax/brax2gymnax.py. You can then pass the wrapped environment as in 2.
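To illustrate the wrapper idea for a pure-JAX environment, here is a minimal sketch. It is not the code in rejax/brax2gymnax.py. CounterEnv and GymnaxStyleWrapper are hypothetical names invented for this example, and the wrapper only duck-types the reset/step signatures and return tuples as I understand Gymnax to define them. A wrapper that Rejax accepts would most likely need more than this (for example the rest of the Gymnax interface and auto-reset on episode end), so treat it as a starting point only.

```python
import jax
import jax.numpy as jnp


class CounterEnv:
    """Hypothetical pure-JAX environment with a non-Gymnax API."""

    def init(self, key):
        # The whole state is a scalar counter starting at 0, 1 or 2.
        return jax.random.randint(key, (), 0, 3)

    def update(self, state, action):
        next_state = state + action
        reward = -jnp.abs(next_state - 10).astype(jnp.float32)
        done = next_state >= 10
        return next_state, reward, done


class GymnaxStyleWrapper:
    """Assumed adapter: exposes Gymnax-style reset/step signatures."""

    def __init__(self, env):
        self.env = env

    def reset(self, key, params=None):
        state = self.env.init(key)
        obs = state.astype(jnp.float32)
        return obs, state

    def step(self, key, state, action, params=None):
        next_state, reward, done = self.env.update(state, action)
        obs = next_state.astype(jnp.float32)
        # Order follows the (obs, state, reward, done, info) convention.
        return obs, next_state, reward, done, {}


env = GymnaxStyleWrapper(CounterEnv())
key = jax.random.PRNGKey(0)

# Both methods are jittable because the wrapped functions are pure JAX.
obs, state = jax.jit(env.reset)(key)
obs, state, reward, done, info = jax.jit(env.step)(key, state, jnp.int32(1))
```

Because every function involved is jittable, the wrapped environment can still be traced and compiled as a whole, which is the point of keeping the wrapper in pure JAX.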
4. If your environment is not written in pure JAX, the wrapper becomes slightly more complicated, since you would have to wrap all of its functions in JAX callbacks (jax.pure_callback or jax.experimental.io_callback, where applicable). Note that this (kind of) defeats the purpose of using Rejax, since the callbacks will be called sequentially, not taking advantage of the parallelism that vmap and pmap provide.
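As a rough illustration of the callback approach for a non-JAX environment, the sketch below routes a NumPy step function through jax.pure_callback so that it can be called from jitted code. The functions numpy_step and step are hypothetical stand-ins for a real environment's dynamics, and a full wrapper would need the same treatment for reset and any other method that leaves JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_step(state, action):
    """Hypothetical non-JAX dynamics, executed on the host with NumPy."""
    next_state = (np.asarray(state) + np.asarray(action)).astype(np.float32)
    reward = np.asarray(-np.abs(next_state).sum(), dtype=np.float32)
    return next_state, reward


def step(state, action):
    """JAX-traceable step that defers the real work to the host."""
    # pure_callback needs the output shapes and dtypes up front,
    # because JAX cannot trace through the NumPy code.
    result_shapes = (
        jax.ShapeDtypeStruct(state.shape, jnp.float32),  # next_state
        jax.ShapeDtypeStruct((), jnp.float32),  # reward
    )
    return jax.pure_callback(numpy_step, result_shapes, state, action)


state = jnp.zeros(3, dtype=jnp.float32)
action = jnp.ones(3, dtype=jnp.float32)
next_state, reward = jax.jit(step)(state, action)
```

jax.pure_callback is meant for functions without side effects. If the environment keeps hidden internal state, jax.experimental.io_callback is likely the better fit.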
I hope this helps! Let me know if you have any more questions.
Awesome, thanks for your reply @keraJLi. That is the information I was looking for! I am using Craftax, which I think fits case #2.
I'm glad I could help!
Hi, what is the workflow for using this library with environments that are not in the config?
I guess we would need to make a wrapper for the environment?