Hi @StoneT2000, is this for a gymnax env?
The gymnax step wrapper function simultaneously resets and steps the environment in order to perform autoresets (avoiding if-conditionals on `done`). `lax.select` then performs a kind of masked selection between the two results. But it requires the "stepped" env state and the "reset" env state to share the same data types.
So you have to make sure that the outputs of `reset_env` and `step_env` always have the same data types. This should be covered for all "native" gymnax envs by the test suite.
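To make the autoreset pattern concrete, here is a minimal sketch of it. It is not gymnax's actual code: `EnvState`, `reset_env`, `step_env` and `step_with_autoreset` are hypothetical stand-ins. Both branches are computed on every call, and `lax.select` is mapped over every leaf of the state pytree, which is why each leaf must have the same dtype in the reset and stepped states.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax


class EnvState(NamedTuple):
    # Hypothetical environment state (not gymnax's real class).
    pos: jnp.ndarray   # float32, shape (2,)
    time: jnp.ndarray  # int32 scalar step counter


def reset_env(key):
    # Hypothetical reset: random position, counter starts at 0 (int32).
    pos = jax.random.uniform(key, (2,), dtype=jnp.float32)
    return EnvState(pos=pos, time=jnp.zeros((), dtype=jnp.int32))


def step_env(state, action):
    # Hypothetical dynamics: move by `action`; episode ends after 5 steps.
    new_state = EnvState(pos=state.pos + action, time=state.time + 1)
    done = new_state.time >= 5
    return new_state, done


@jax.jit
def step_with_autoreset(key, state, action):
    # Compute BOTH branches unconditionally; no Python `if` on a traced value.
    state_st, done = step_env(state, action)
    state_re = reset_env(key)
    # Choose leaf by leaf: reset state where done, stepped state otherwise.
    # lax.select raises a TypeError if a leaf's dtype differs between branches.
    state = jax.tree_util.tree_map(
        lambda re, st: lax.select(done, re, st), state_re, state_st
    )
    return state, done


key = jax.random.PRNGKey(0)
state = reset_env(key)
action = jnp.ones((2,), dtype=jnp.float32)
for _ in range(5):
    key, sub = jax.random.split(key)
    state, done = step_with_autoreset(sub, state, action)
print(bool(done), int(state.time))  # True 0 (the 5th step triggered a reset)
```

If `reset_env` instead returned `time` as, say, a float32 zero, the `tree_map` would fail while tracing, because `lax.select` does not promote mismatched dtypes.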
oops I was being very dumb haha, thanks.
I'm currently using the latest JAX version (0.4.8), and it reports that
Any idea what's going on? It seems like using `jnp.where` instead is a simple fix.
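The error text itself is not shown here, so the following sketch assumes it comes from `lax.select`'s dtype check. It contrasts the two calls on a hypothetical state field that is int32 after reset but float32 after a step: `lax.select` rejects the mismatch, while `jnp.where` applies JAX's usual type promotion.

```python
import jax.numpy as jnp
from jax import lax

done = jnp.array(True)
reset_val = jnp.zeros((3,), dtype=jnp.int32)  # hypothetical field from reset
step_val = jnp.ones((3,), dtype=jnp.float32)  # same field from step

# lax.select requires identical dtypes and does not promote.
try:
    lax.select(done, reset_val, step_val)
except TypeError as err:
    print("lax.select failed:", err)

# jnp.where promotes (int32, float32 -> float32) and broadcasts `done`.
out = jnp.where(done, reset_val, step_val)
print(out.dtype)  # float32
```

The trade-off is that `jnp.where` hides the mismatch: the returned state can end up with a different dtype than `reset_env` produced. Making `reset_env` and `step_env` return matching dtypes addresses the underlying cause.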