This PR adds a JAX-based iAMaLGaM implementation, which runs on different hardware backends (CPU, GPU, TPU). It leverages anticipated mean shift and adaptive variance scaling, and can iteratively estimate either the full covariance matrix or a diagonal approximation.
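For intuition, here is a minimal sketch of the two mechanisms (not the PR's exact update rules — see the linked `full_iamalgam.py` for those; the hyperparameter names `delta_ams`, `eta_inc`, `eta_dec`, `theta_sdr` follow the iAMaLGaM literature and are assumptions here):

```python
import jax.numpy as jnp


def anticipated_mean_shift(x, mean, mean_prev, c_mult,
                           delta_ams=2.0, ams_frac=0.35):
    """Shift a fraction of the sampled offspring along the recent mean shift,
    anticipating that the search distribution keeps moving in that direction."""
    n_ams = int(ams_frac * x.shape[0])
    shift = delta_ams * c_mult * (mean - mean_prev)
    return x.at[:n_ams].add(shift)


def adaptive_variance_scaling(c_mult, improved, sdr,
                              eta_inc=1.1, eta_dec=0.9, theta_sdr=1.0):
    """Adapt the covariance multiplier c_mult: grow it while improvements land
    more than theta_sdr standard deviations from the mean (SDR trigger),
    shrink it when no improvement occurs. Simplified: the full rule also
    tracks a no-improvement stretch counter."""
    grow = jnp.logical_and(improved, sdr > theta_sdr)
    shrink = jnp.logical_not(improved)
    c_mult = jnp.where(grow, eta_inc * c_mult,
                       jnp.where(shrink, eta_dec * c_mult, c_mult))
    # As in the standard algorithm, the multiplier never drops below 1.
    return jnp.maximum(c_mult, 1.0)
```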
It struggles on the harder Brax task and with the larger parameter search spaces in MNIST. At this point I am not entirely sure whether this is due to a lack of scalability of the method itself or to insufficient hyperparameter exploration.
Finally, I ran into problems when trying to execute the experiments for the multi-agent Waterworld task on a 4-GPU machine (see the command-line output below):
File "/cognition/home/RobTLange/anaconda/envs/snippets/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
File "/cognition/home/RobTLange/anaconda/envs/snippets/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3015, in _reshape_shape_rule
if not core.same_shape_sizes(np.shape(operand), new_sizes):
File "/cognition/home/RobTLange/anaconda/envs/snippets/lib/python3.8/site-packages/jax/core.py", line 1519, in same_shape_sizes
return 1 == divide_shape_sizes(s1, s2)
File "/cognition/home/RobTLange/anaconda/envs/snippets/lib/python3.8/site-packages/jax/core.py", line 1516, in divide_shape_sizes
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
File "/cognition/home/RobTLange/anaconda/envs/snippets/lib/python3.8/site-packages/jax/core.py", line 1420, in divide_shape_sizes
raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (64, 16) and (16, 64, 16)
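I have not pinned down the root cause yet, but the shapes suggest a `reshape` where an extra leading device axis is expected. For context, a minimal sketch of the usual `pmap` sharding pattern (illustrative names, not the PR's code):

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()       # 4 on the machine above
popsize, num_dims = 64, 16                   # shapes from the traceback
population = jnp.zeros((popsize, num_dims))  # candidate solutions

# pmap expects an explicit leading device axis, so the population has to be
# reshaped (popsize, dims) -> (devices, popsize // devices, dims). If the
# leading axis the code expects does not match the array it receives, the
# reshape fails with an InconclusiveDimensionOperation like the one above.
shards = population.reshape(num_devices, popsize // num_devices, num_dims)

evaluate = jax.pmap(lambda xs: jnp.sum(xs ** 2, axis=-1))  # toy fitness fn
fitness = evaluate(shards).reshape(popsize)
```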
evosax source code: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/full_iamalgam.py
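For reference, a minimal ask/tell loop against the linked strategy (a sketch — the class and method names reflect my reading of the evosax API and may differ across versions):

```python
import jax
import jax.numpy as jnp
from evosax import Full_iAMaLGaM  # exported class name assumed

rng = jax.random.PRNGKey(0)
strategy = Full_iAMaLGaM(popsize=64, num_dims=16)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

for _ in range(50):
    rng, rng_ask = jax.random.split(rng)
    x, state = strategy.ask(rng_ask, state, es_params)   # sample candidates
    fitness = jnp.sum(x ** 2, axis=-1)                   # toy sphere objective
    state = strategy.tell(x, fitness, state, es_params)  # update mean/covariance
```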