FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

GRU_HIDDEN_DIM/FC_DIM_SIZE and NUM_STEPS still tied to each other #80

Closed: satpreetsingh closed this issue 6 months ago

satpreetsingh commented 6 months ago

This issue still exists after the latest updates: https://github.com/FLAIROx/JaxMARL/issues/68

Setting GRU_HIDDEN_DIM and/or FC_DIM_SIZE to any value other than NUM_STEPS in baselines/IPPO/config/ippo_rnn_mpe.yaml, or when running python baselines/MAPPO/mappo_rnn_mpe.py, throws an error.

Example output with GRU_HIDDEN_DIM = 64:

python baselines/IPPO/ippo_rnn_mpe.py 

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/util.py", line 287, in wrapper
    return cached(config.trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/util.py", line 280, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 155, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 171, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(16, 1), (16, 128), (16, 64)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 458, in main
    out = train_jit(rng)
          ^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 153, in train
    network_params = network.init(_rng, init_hstate, init_x)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 72, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/flax/core/axes_scan.py", line 120, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 44, in __call__
    rnn_state = jnp.where(
                ^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 1141, in where
    return util._where(acondition, if_true, if_false)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 448, in _where
    condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 407, in _broadcast_arrays
    result_shape = lax.broadcast_shapes(*shapes)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Incompatible shapes for broadcasting: shapes=[(16, 1), (16, 128), (16, 64)]
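The final ValueError is a plain broadcasting failure inside jnp.where: a (16, 1) mask broadcasts against either a (16, 128) or a (16, 64) operand on its own, but 128 and 64 cannot be broadcast against each other. A minimal reproduction using only the shapes copied from the traceback (the variable names are illustrative, not taken from ippo_rnn_mpe.py):

```python
import jax.numpy as jnp

# Shapes copied from the traceback: a (16, 1) reset mask and two (16, D) operands.
mask = jnp.zeros((16, 1), dtype=bool)
operand_128 = jnp.zeros((16, 128))
operand_64 = jnp.zeros((16, 64))

# Each operand broadcasts fine against the (16, 1) mask on its own...
ok_a = jnp.broadcast_shapes(mask.shape, operand_128.shape)  # (16, 128)
ok_b = jnp.broadcast_shapes(mask.shape, operand_64.shape)   # (16, 64)

# ...but the two operands disagree in their last dimension (128 vs 64),
# so jnp.where cannot broadcast all three arrays together.
try:
    jnp.where(mask, operand_128, operand_64)
    message = None
except ValueError as err:
    message = str(err)

print(ok_a, ok_b)
print(message)
```

So whatever produces the two branches of the jnp.where call in ScannedRNN must agree on the last dimension; in this run one branch is 64 wide (matching GRU_HIDDEN_DIM = 64) and the other is 128 wide.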

Algorithmically (for either IPPO or MAPPO), I see no reason why these values should be tied.
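To illustrate that point, here is a minimal pure-JAX sketch of a scanned GRU with episode resets in which the hidden width, the input (embedding) width and the rollout length are all independent. This is not JaxMARL's ScannedRNN and the function names are hypothetical; it assumes the mismatch comes from re-initialising the carry with a width taken from somewhere other than the carry itself, so here the reset branch builds the fresh carry with zeros_like(h).

```python
import jax
import jax.numpy as jnp


def init_gru_params(key, input_dim, hidden_dim):
    """Hypothetical helper: weights for one GRU cell; hidden_dim need not equal input_dim."""
    k_x, k_h = jax.random.split(key)
    return {
        # Input-to-hidden weights for the reset, update and candidate gates (stacked).
        "w_x": jax.random.normal(k_x, (input_dim, 3 * hidden_dim)) * input_dim ** -0.5,
        # Hidden-to-hidden weights for the same three gates.
        "w_h": jax.random.normal(k_h, (hidden_dim, 3 * hidden_dim)) * hidden_dim ** -0.5,
        "b": jnp.zeros((3 * hidden_dim,)),
    }


def gru_step(params, h, x):
    """Standard GRU update; h: (B, hidden_dim), x: (B, input_dim)."""
    hd = h.shape[-1]
    gx = x @ params["w_x"] + params["b"]
    gh = h @ params["w_h"]
    r = jax.nn.sigmoid(gx[:, :hd] + gh[:, :hd])                   # reset gate
    z = jax.nn.sigmoid(gx[:, hd:2 * hd] + gh[:, hd:2 * hd])       # update gate
    n = jnp.tanh(gx[:, 2 * hd:] + r * gh[:, 2 * hd:])             # candidate state
    return (1.0 - z) * n + z * h


def scanned_gru(params, h0, xs, resets):
    """Hypothetical helper. xs: (T, B, input_dim), resets: (T, B) bool, h0: (B, hidden_dim)."""

    def step(h, inputs):
        x, reset = inputs
        # Zero the carry where an episode just ended. The fresh carry takes its
        # shape from h itself, so it always matches the hidden width.
        h = jnp.where(reset[:, None], jnp.zeros_like(h), h)
        h = gru_step(params, h, x)
        return h, h

    return jax.lax.scan(step, h0, (xs, resets))


# Rollout length, batch, input width and hidden width are all chosen independently.
T, B, input_dim, hidden_dim = 8, 16, 128, 64
params = init_gru_params(jax.random.PRNGKey(0), input_dim, hidden_dim)
xs = jnp.ones((T, B, input_dim))
resets = jnp.zeros((T, B), dtype=bool)
h0 = jnp.zeros((B, hidden_dim))

h_final, hs = scanned_gru(params, h0, xs, resets)
print(h_final.shape, hs.shape)  # (16, 64) (8, 16, 64)
```

Deriving the reset value from the carry rather than from the input keeps the two jnp.where branches the same shape by construction, whatever the embedding width or rollout length.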

amacrutherford commented 6 months ago

Ah, cheers for spotting this; see #82.