google / evojax

Apache License 2.0

Removed jit around pmap #7

Closed: MaximilienLC closed this 2 years ago

MaximilienLC commented 2 years ago

As stated in the JAX docs, there is no need to wrap a `pmap` in `jit`, because `pmap` already compiles the mapped function with XLA: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#pmap-and-jit

Beyond being unnecessary, the outer `jit` actually appears to be problematic. JAX emits the following warning:

UserWarning: The jitted function <unnamed function> includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See [https://github.com/google/jax/issues/2926].
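A minimal sketch of the kind of change this PR makes, assuming a hypothetical per-device function `score_shard` and a hypothetical `population` array (neither is EvoJAX's actual code). The outer `jax.jit` is dropped and `jax.pmap` is called directly, since `pmap` compiles the function for each device on its own.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device workload (illustration only, not EvoJAX code):
# score each candidate parameter vector with a simple quadratic fitness.
def score_shard(params):
    return -jnp.sum(params ** 2, axis=-1)

# Before (discouraged): the outer jit may gather sharded inputs and outputs
# onto a single device.
#   score = jax.jit(jax.pmap(score_shard))

# After: pmap alone already XLA-compiles score_shard for every device.
score = jax.pmap(score_shard)

# pmap maps over the leading axis, whose size must not exceed the number of
# local devices; here it is exactly one shard per device.
n_devices = jax.local_device_count()
population = jnp.arange(n_devices * 4 * 3, dtype=jnp.float32).reshape(n_devices, 4, 3)

fitness = score(population)  # shape (n_devices, 4)
```

The result stays sharded across devices, which is the representation the outer `jit` would otherwise discard.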