RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Jittable `Environment` class #18

Closed by RobertTLange 3 years ago

RobertTLange commented 3 years ago

Similar to how distributions work in distrax, I want to change the API so that it is built around a jittable environment class. For example:

env = gymnax.make('env_name')
obs, state = env.reset(key)
obs, state, reward, done, info = env.step(key, state, action)

Hence, the environment parameters are "absorbed" into the class instance. This should not be too difficult as long as we are careful about how the instance is treated as a pytree.
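To make the pytree concern concrete, here is a minimal sketch of what such a class could look like. It is not the gymnax implementation: `ToyEnv`, its `step_size` and `max_steps` parameters, and the point-mass dynamics are all made up for illustration. The only library calls used are `jax.jit`, `jax.random`, and `jax.tree_util.register_pytree_node_class`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyEnv:
    """Hypothetical 1-D point-mass environment (illustration only, not gymnax)."""

    def __init__(self, step_size=0.1, max_steps=5):
        self.step_size = step_size  # dynamic parameter, exposed as a pytree leaf
        self.max_steps = max_steps  # static parameter, kept as auxiliary data

    def tree_flatten(self):
        # Leaves are traced by jit; auxiliary data must be hashable.
        return (self.step_size,), (self.max_steps,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])

    def reset(self, key):
        pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
        state = {"pos": pos, "t": jnp.zeros((), dtype=jnp.int32)}
        return state["pos"], state

    def step(self, key, state, action):
        noise = 0.01 * jax.random.normal(key, ())
        pos = state["pos"] + self.step_size * action + noise
        t = state["t"] + 1
        new_state = {"pos": pos, "t": t}
        reward = -jnp.abs(pos)
        done = t >= self.max_steps
        return pos, new_state, reward, done, {}


env = ToyEnv()
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))

# The instance is passed through jit as an ordinary pytree argument.
jit_reset = jax.jit(lambda env, key: env.reset(key))
jit_step = jax.jit(lambda env, key, state, action: env.step(key, state, action))

obs, state = jit_reset(env, key_reset)
obs, state, reward, done, info = jit_step(env, key_step, state, jnp.asarray(1.0))
```

The design choice here is which parameters become leaves and which become auxiliary data. Leaves such as `step_size` are traced, so changing their value does not trigger recompilation, while changing auxiliary data such as `max_steps` does. Jitting a bound method directly (for example `jax.jit(env.step)`) also works without any registration, but then the parameters are closed over and baked into the compiled function as constants.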

RobertTLange commented 3 years ago

Addressed in dd3e7288fb9fcc094442e48b7050de90d58d0d82.