It looks like we need to check for a pmap before jitting the NTK function. Otherwise JAX emits this warning:
0%| | 0/500 [00:00<?, ?batch/s]
/tikhome/stovey/miniconda3/envs/zincware/lib/python3.10/site-packages/jax/_src/dispatch.py:289: UserWarning: The jitted function ntk_fn includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926.
  warnings.warn(
We just need to run the following:
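One possible way to do such a check is sketched below. It is not necessarily the snippet meant here. It assumes the NTK function can be traced once with example inputs. It traces the function with jax.make_jaxpr and walks the resulting jaxpr, including nested sub-jaxprs, looking for the pmap primitive. That primitive is assumed to be named "xla_pmap", as in the JAX versions I am aware of. The sketch applies jax.jit only when no pmap is found. The names maybe_jit, contains_pmap, plain_ntk and sharded_ntk are hypothetical and only serve as illustrations.

```python
import jax
import jax.numpy as jnp


def _sub_jaxprs(param):
    """Yield any Jaxprs nested inside an equation parameter (duck-typed)."""
    if isinstance(param, (tuple, list)):
        for item in param:
            yield from _sub_jaxprs(item)
    elif hasattr(param, "jaxpr") and hasattr(param.jaxpr, "eqns"):
        # ClosedJaxpr, e.g. the body of a nested jax.jit call
        yield param.jaxpr
    elif hasattr(param, "eqns"):
        # Plain Jaxpr, e.g. the call_jaxpr of a pmap
        yield param


def contains_pmap(jaxpr) -> bool:
    """Recursively check whether a Jaxpr contains a pmap call."""
    for eqn in jaxpr.eqns:
        # Assumption: pmap shows up as the "xla_pmap" primitive.
        if eqn.primitive.name == "xla_pmap":
            return True
        for param in eqn.params.values():
            if any(contains_pmap(sub) for sub in _sub_jaxprs(param)):
                return True
    return False


def maybe_jit(fn, *example_args):
    """Return jax.jit(fn) unless tracing fn reveals a pmap inside it."""
    closed_jaxpr = jax.make_jaxpr(fn)(*example_args)
    if contains_pmap(closed_jaxpr.jaxpr):
        return fn  # leave pmap-containing functions un-jitted
    return jax.jit(fn)


# Hypothetical stand-ins for an NTK function, with and without a pmap.
def plain_ntk(x):
    return x @ x.T


_pmapped_ntk = jax.pmap(lambda x: x @ x.T)


def sharded_ntk(x):
    # Leading axis of size 1 so this runs on any number of devices.
    return _pmapped_ntk(x[None])[0]


x = jnp.ones((4, 3))
ntk_fn = maybe_jit(plain_ntk, x)        # wrapped in jax.jit
sharded_fn = maybe_jit(sharded_ntk, x)  # returned as-is, no jit-of-pmap
```

The check costs one extra trace. In exchange, it avoids the jit-of-pmap data movement that the warning describes. If the caller already knows whether the NTK function was built with pmap, passing that as a flag and skipping the trace would be simpler.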