RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0
613 stars 61 forks

Potential bug due to lax.select usage in step function #52

Closed StoneT2000 closed 1 year ago

StoneT2000 commented 1 year ago

I'm currently using the latest JAX version (0.4.8), and it reports the following error:

  File "/home/stao/mambaforge/envs/robojax_brax/lib/python3.8/site-packages/gymnax/environments/environment.py", line 43, in step
    state = jax.tree_map(
  File "/home/stao/mambaforge/envs/robojax_brax/lib/python3.8/site-packages/gymnax/environments/environment.py", line 44, in <lambda>
    lambda x, y: jax.lax.select(done, x, y), state_re, state_st
TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

Any idea what's going on? Using jnp.where instead seems like a simple fix.
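For reference, the difference between the two functions can be reproduced in isolation. This is a minimal sketch, not the gymnax code: `x` and `y` are made-up stand-ins for one field of the reset state and the stepped state.

```python
import jax
import jax.numpy as jnp

done = jnp.array(True)
x = jnp.array(1.0, dtype=jnp.float32)  # hypothetical field from the reset state
y = jnp.array(2, dtype=jnp.int32)      # the same field from the stepped state

# lax.select is strict: both cases must share one dtype, otherwise it raises.
try:
    jax.lax.select(done, x, y)
    select_failed = False
except TypeError:
    select_failed = True

# jnp.where promotes both inputs to a common dtype instead of raising.
z = jnp.where(done, x, y)
print(select_failed, z.dtype)
```

Note that jnp.where only hides the mismatch: the selected field comes back with the promoted dtype, so the state's dtype can differ from what step_env alone would produce.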

RobertTLange commented 1 year ago

Hi @StoneT2000, is this for a gymnax env?

The gymnax step wrapper function simultaneously resets and steps the environment in order to perform auto-resets (to avoid if-conditionals when done). lax.select then performs a masked selection between the two results, but it requires the “stepped” env state and the “reset” env state to share the same data types.

So you have to make sure that the outputs of reset_env and step_env always have the same data types. This should be covered for all “native” gymnax envs by the test suite.
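The auto-reset pattern described above can be sketched roughly as follows. This is an illustration under assumptions, not the actual gymnax implementation: the toy state (a dict with `pos` and `time`), the episode length of 3, and the simplified `reset_env`/`step_env` signatures are all hypothetical. The point is that every leaf of the reset state has the same dtype as the matching leaf of the stepped state.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment, not part of gymnax.
def reset_env():
    # Explicit dtypes so they match what step_env returns.
    return {"pos": jnp.zeros((), dtype=jnp.float32),
            "time": jnp.zeros((), dtype=jnp.int32)}

def step_env(state, action):
    new_state = {"pos": state["pos"] + action,  # stays float32
                 "time": state["time"] + 1}     # stays int32
    done = new_state["time"] >= 3               # toy episode length
    return new_state, done

@jax.jit
def step(state, action):
    # Always compute both the stepped and the reset state, then pick
    # one per leaf. This avoids a Python-level `if done` under jit.
    state_st, done = step_env(state, action)
    state_re = reset_env()
    state = jax.tree_util.tree_map(
        lambda x, y: jax.lax.select(done, x, y), state_re, state_st
    )
    return state, done

state = reset_env()
for _ in range(3):
    state, done = step(state, jnp.float32(0.5))
# After the third step the episode ends and the state is reset.
```

If reset_env created `time` with `jnp.zeros(())` instead (float32 by default), the lax.select call would fail with the dtype error from the traceback.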

StoneT2000 commented 1 year ago

oops I was being very dumb haha, thanks.