Changing GRU_HIDDEN_DIM and/or FC_DIM_SIZE to any value other than NUM_STEPS in baselines/IPPO/config/ippo_rnn_mpe.yaml throws an error; the same happens when running python baselines/MAPPO/mappo_rnn_mpe.py.
Example output with GRU_HIDDEN_DIM = 64:
python baselines/IPPO/ippo_rnn_mpe.py
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/util.py", line 287, in wrapper
    return cached(config.trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/util.py", line 280, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 155, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 171, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(16, 1), (16, 128), (16, 64)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 458, in main
    out = train_jit(rng)
          ^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 153, in train
    network_params = network.init(_rng, init_hstate, init_x)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 72, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/flax/core/axes_scan.py", line 120, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/JaxMARL/baselines/IPPO/ippo_rnn_mpe.py", line 44, in __call__
    rnn_state = jnp.where(
                ^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 1141, in where
    return util._where(acondition, if_true, if_false)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 448, in _where
    condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/satsingh/miniforge3/envs/jaxmarl/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 407, in _broadcast_arrays
    result_shape = lax.broadcast_shapes(*shapes)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Incompatible shapes for broadcasting: shapes=[(16, 1), (16, 128), (16, 64)]
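The failing call can be reproduced outside the training script. The sketch below uses NumPy, since jax.numpy.where follows the same broadcasting rules, with the three shapes from the error message. It assumes the shapes correspond, in order, to the arguments of the jnp.where call at line 44 (reset condition, freshly initialized carry, current carry); the variable names and the try_where helper are hypothetical.

```python
import numpy as np

# Shapes taken from the error message; the role of each array is an assumption.
batch, carry_dim, reset_dim = 16, 64, 128

resets = np.zeros((batch, 1), dtype=bool)    # condition, shape (16, 1)
reset_state = np.zeros((batch, reset_dim))   # fresh carry built with width 128
rnn_state = np.zeros((batch, carry_dim))     # current carry, GRU_HIDDEN_DIM = 64


def try_where(cond, if_true, if_false):
    """Return the selected array, or the error message if the shapes cannot broadcast."""
    try:
        return np.where(cond, if_true, if_false)
    except ValueError as err:
        return str(err)


# Trailing dims 1, 128 and 64 have no common broadcast size, so this fails.
print(try_where(resets, reset_state, rnn_state))

# Building the fresh carry with the current carry's own width avoids the conflict.
ok = try_where(resets, np.zeros_like(rnn_state), rnn_state)
print(ok.shape)  # (16, 64)
```

The mismatch is between the two branches of the select (128 vs 64), not with the (16, 1) condition, which broadcasts against either width.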
Algorithmically (for either IPPO or MAPPO), I see no reason why these values should be tied.
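To illustrate that nothing in the algorithm ties these values together, here is a minimal NumPy sketch of a generic GRU step with episode resets. It is not Flax's nn.GRUCell or the JaxMARL ScannedRNN, and all names and sizes are hypothetical. The input-to-hidden weights have shape (input_dim, hidden_dim), so the embedding width (FC_DIM_SIZE), the GRU width (GRU_HIDDEN_DIM) and the rollout length (NUM_STEPS) can all differ.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_gru_params(rng, input_dim, hidden_dim):
    """Input-to-hidden weights are (input_dim, hidden_dim); hidden-to-hidden are (hidden_dim, hidden_dim)."""
    params = {}
    for gate in ("r", "z", "n"):
        params[f"W_i{gate}"] = rng.normal(scale=0.1, size=(input_dim, hidden_dim))
        params[f"W_h{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
        params[f"b_{gate}"] = np.zeros(hidden_dim)
    return params


def gru_step(params, h, x, resets):
    """One GRU step. h: (batch, hidden_dim), x: (batch, input_dim), resets: (batch,) bool."""
    # Reset finished episodes to zeros sized from the carry itself, not from x.
    h = np.where(resets[:, np.newaxis], np.zeros_like(h), h)
    r = sigmoid(x @ params["W_ir"] + h @ params["W_hr"] + params["b_r"])  # reset gate
    z = sigmoid(x @ params["W_iz"] + h @ params["W_hz"] + params["b_z"])  # update gate
    n = np.tanh(x @ params["W_in"] + params["b_n"] + r * (h @ params["W_hn"]))  # candidate
    return (1.0 - z) * n + z * h


# Hypothetical sizes: embedding 128, GRU 64, and a rollout length matching neither.
rng = np.random.default_rng(0)
batch, fc_dim, gru_dim, num_steps = 16, 128, 64, 5
params = init_gru_params(rng, input_dim=fc_dim, hidden_dim=gru_dim)

h = np.zeros((batch, gru_dim))
for t in range(num_steps):
    x = rng.normal(size=(batch, fc_dim))
    resets = np.ones(batch, dtype=bool) if t == 0 else np.zeros(batch, dtype=bool)
    h = gru_step(params, h, x, resets)

print(h.shape)  # (16, 64)
```

If ScannedRNN in ippo_rnn_mpe.py derives the GRU width or the reset carry from the input embedding (for example ins.shape[1]) rather than from GRU_HIDDEN_DIM, that would explain the (16, 128) vs (16, 64) mismatch. This is an assumption the traceback does not confirm, but if it holds, passing the hidden size explicitly or reading it from the carry should decouple the settings.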
This issue still exists after the latest updates: https://github.com/FLAIROx/JaxMARL/issues/68