instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

fix: pacman ghost valid action calculations result in NaNs #241

Closed · taodav closed this 6 months ago

taodav commented 6 months ago

In the PacMan environment, when calculating all the valid actions a ghost can take (in check_ghost_wall_collisions in pac_man/utils.py), the invert_mask * jnp.inf call was producing NaNs wherever invert_mask == 0, since 0 * jnp.inf evaluates to NaN. This led to all actions being treated as valid for ghosts.

Instead, this line should use a jnp.where call that conditionally replaces all 1s in invert_mask with jnp.inf.
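
A minimal sketch of the difference, using hypothetical mask values rather than the exact code from pac_man/utils.py:

import jax.numpy as jnp

# Hypothetical invert_mask values (not taken from the actual environment).
invert_mask = jnp.array([1.0, 0.0, 1.0, 0.0])

# Buggy: 0 * inf is NaN under IEEE arithmetic, so every entry where
# invert_mask == 0 becomes NaN instead of staying 0 (when run eagerly).
buggy = invert_mask * jnp.inf

# Fixed: only entries where invert_mask == 1 become inf; the rest stay 0.
fixed = jnp.where(invert_mask == 1, jnp.inf, 0.0)

print(buggy)  # [inf nan inf nan] with jit disabled
print(fixed)  # [inf  0. inf  0.]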

CLAassistant commented 6 months ago

CLA assistant check
All committers have signed the CLA.

clement-bonnet commented 6 months ago

Hi, thank you for spotting this bug! If this is true and this change fixes it, then we need to bump the version of PacMan from PacMan-v0 to PacMan-v1. Could you please make the corresponding changes to the registry and documentation?
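
For reference, the bump would look something like the following in the environment registry. This is a sketch assuming jumanji's register() helper; the exact file and call signature should be checked against the repo:

from jumanji.registration import register

# Re-register PacMan under a new version id, since the fix changes the
# (non-jitted) environment behavior.
register(id="PacMan-v1", entry_point="jumanji.environments:PacMan")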

taodav commented 6 months ago

I've updated my commit to bump the version of PacMan to v1. Let me know if I've missed anything!

clement-bonnet commented 6 months ago

That's perfect, thanks! I'm struggling to reproduce the issue, as I am not finding NaNs in the different tests I've run. Would you have a small reproduction of the NaNs that you could share? Thank you!


taodav commented 6 months ago

I did a bit of digging: essentially, if the function is jitted, then jnp.inf * False returns 0 (which happens to work for PacMan), whereas it should return NaN:

https://github.com/google/jax/issues/12233#issuecomment-1238401228

The NaNs don't show up in the action_mask itself, but they essentially zero out the action_mask for ghosts. I put a breakpoint at the return statement of check_ghost_wall_collisions to see the NaNs in invert_mask. They only show up if you set the environment variable JAX_DISABLE_JIT=1, which turns jit off.
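
A minimal way to see the discrepancy (the jitted result reflects the XLA behavior described in the linked issue, and may vary across JAX versions):

import jax
import jax.numpy as jnp

def times_inf(x):
    return x * jnp.inf

x = jnp.array(False)

# Eager execution follows IEEE semantics: 0 * inf -> NaN.
with jax.disable_jit():
    print(times_inf(x))  # nan

# Under jit, XLA simplifies the multiplication and returns 0 instead.
print(jax.jit(times_inf)(x))  # 0.0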

taodav commented 6 months ago

Here is the script that I ran, with the environment variable JAX_DISABLE_JIT=1 set:

import jax
from jumanji.environments.routing.pac_man import PacMan

if __name__ == "__main__":
    # Disable jit globally. Note: jax.disable_jit() returns a context
    # manager, so calling jax.disable_jit(True) outside a `with` block is
    # a no-op; the config update (or the JAX_DISABLE_JIT=1 environment
    # variable) is what actually turns jit off.
    jax.config.update("jax_disable_jit", True)

    seed = 2024
    key = jax.random.PRNGKey(seed)
    reset_key, key = jax.random.split(key)

    env = PacMan()

    state, tstep = env.reset(reset_key)

    next_state, tstep = env.step(state, 1)

clement-bonnet commented 6 months ago

When I train with and without the fix, I get exactly the same learning curves (same loss at every step), which hints that the behavior has not changed. I wonder if the NaN behavior depends on the JAX version? I'm happy to merge this change as it is a cleaner implementation, but if the environment behavior has not changed, then we should probably not bump the version. Would you have a way to show that the environment produces NaNs before the fix and not after this change?

taodav commented 6 months ago

Yes, the behavior under jit would be the same: according to the JAX issue linked above, XLA returns 0 instead of NaN for False * jnp.inf when things are jitted, which just so happens to be the intended behavior in this code. The issue comes up when debugging: with JAX_DISABLE_JIT=1, you get weird issues with ghosts going through walls. Here's an animated example:

[Animated GIF: unjit_pacman, showing ghosts passing through walls when jit is disabled]

If this doesn't warrant a version bump, then I'm more than happy to change the version back to v0.

clement-bonnet commented 6 months ago

Oh, that makes complete sense. Since the behavior of the non-jitted environment changes, let's bump the version after all.

Thank you for your contribution!