Closed taodav closed 6 months ago
Hi, thank you for spotting this bug!
If this is true and this fixes it, we then need to bump the version of PacMan from PacMan-v0 to PacMan-v1. Could you please make the following changes to the registry and documentation?
I've updated my commit to bump the version of PacMan to v1. Let me know if I've missed anything!
That's perfect, thanks! I'm struggling to reproduce the issue as I am not finding NaNs in different tests I've done. Would you have a small reproduction of the NaNs that you would be able to share? Thank you!
I did a bit of digging, and essentially if the function is jitted, then jnp.inf * False returns 0 (which works with PacMan), whereas it should return NaN:
https://github.com/google/jax/issues/12233#issuecomment-1238401228
The NaNs don't show up in the action_mask, but they essentially zero out the action_mask for ghosts. I put a breakpoint at the return statement of check_ghost_wall_collisions to see the NaNs in invert_mask. This only shows up if you set the environment variable JAX_DISABLE_JIT=1, which turns jit off.
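For reference, a minimal sketch of the arithmetic described above, independent of Jumanji. The mask values and the helper name penalty are made up for illustration. The jitted result is only printed, not asserted, since per the linked issue it depends on how XLA simplifies the expression and may vary between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical boolean mask (not taken from the PacMan code).
invert_mask = jnp.array([True, False, True, False])


def penalty(mask):
    # IEEE 754: True * inf is inf, but False * inf (0 * inf) is NaN.
    return mask * jnp.inf


# With jit disabled, each operation runs on its own and the NaNs are visible.
with jax.disable_jit():
    eager = penalty(invert_mask)

# Under jit, XLA may simplify the multiplication (see the linked issue).
jitted = jax.jit(penalty)(invert_mask)

print("jit disabled:", eager)
print("jit enabled: ", jitted)
print("NaN positions without jit:", jnp.isnan(eager))
```

Comparing the two printed arrays is a quick way to check whether a given JAX version shows the discrepancy.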
Here is the script that I run, with the environment variable JAX_DISABLE_JIT=1:
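import jax
from jumanji.environments.routing.pac_man import PacMan

if __name__ == "__main__":
    # jax.disable_jit() is a context manager and does nothing when called bare,
    # so disable jit through the config flag instead.
    jax.config.update("jax_disable_jit", True)

    seed = 2024
    key = jax.random.PRNGKey(seed)
    reset_key, key = jax.random.split(key)

    env = PacMan()
    state, tstep = env.reset(reset_key)
    next_state, tstep = env.step(state, 1)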
When I train with and without the fix, I get the exact same learning curves (same loss at every step), hinting that the behavior has not changed. I wonder if the NaN behavior depends on the version of JAX? I'm happy to merge this change as it is a cleaner implementation but if the environment behavior has not changed, then we should probably not bump the version. Would you have a way to show that the environment produces NaNs before the fix and not after this change?
Yes, the behavior would be the same, since according to this thread, XLA returns 0 instead of NaN for False * jnp.inf when things are JIT'ed, which just so happens to be the intended behavior in the code. The issue comes with debugging: when JAX_DISABLE_JIT=1, you have weird issues with ghosts going through walls. Here's an animated example:
If this doesn't warrant a version bump, then I'm more than happy to change the version back to v0.
Oh that makes complete sense. Since the behavior of the non-jitted environment changes, let's then bump the version.
Thank you for your contribution!
In the PacMan environment, when trying to calculate all the valid actions a ghost could take (in check_ghost_wall_collisions in pac_man/utils.py), the invert_mask * jnp.inf call was producing an array with NaNs wherever invert_mask == 0 (0 * inf is NaN), rather than inf only where invert_mask == 1. This led to all actions being valid for ghosts. Instead, what this line should be doing is a jnp.where call that conditionally replaces all 1's in invert_mask with jnp.inf.
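As a rough illustration of the change described above (this is not the actual Jumanji code: the mask values, the helper name mask_to_penalty, and the 0.0 fill value are assumptions made for the sketch):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the mask built in check_ghost_wall_collisions.
invert_mask = jnp.array([1, 0, 0, 1])


def mask_to_penalty(mask):
    # Select inf where the mask is set and 0 elsewhere.
    # No multiplication by inf takes place, so 0 * inf never occurs.
    return jnp.where(mask == 1, jnp.inf, 0.0)


# The result should not depend on whether jit is enabled.
with jax.disable_jit():
    eager = mask_to_penalty(invert_mask)
jitted = jax.jit(mask_to_penalty)(invert_mask)

print(eager)
print(jitted)
```

Because jnp.where selects values instead of multiplying them, the jitted and non-jitted paths are expected to agree, which is the property the debugging discussion in this thread relies on.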