zincware / ZnNL

A Python package for studying neural learning
Eclipse Public License 2.0

Jitting the NTK function #91

Open SamTov opened 1 year ago

SamTov commented 1 year ago

It looks like we need a pmap check before jitting the NTK function. Jitting it unconditionally produces the following warning:

0%|                                                                    | 0/500 [00:00<?, ?batch/s]
/tikhome/stovey/miniconda3/envs/zincware/lib/python3.10/site-packages/jax/_src/dispatch.py:289: UserWarning: The jitted function ntk_fn includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926.
  warnings.warn(

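Assuming the pmap is only added when more than one device is available, we just need to skip the jit in that case:

import jax

devices = jax.devices()

# jit-of-pmap is inefficient (see the warning above), so only jit the
# NTK function when there is a single device and therefore no pmap.
if len(devices) == 1:
    ntk_fn = jax.jit(ntk_fn)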
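One way to make the device-count check reusable wherever an NTK function is built is a small wrapper. This is only a sketch: `maybe_jit` and `toy_ntk_fn` are hypothetical names, not ZnNL APIs, and it assumes the NTK function contains a pmap exactly when more than one device is visible.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def maybe_jit(fn: Callable, num_devices: Optional[int] = None) -> Callable:
    """Jit ``fn`` only when a single device is in use.

    Hypothetical helper: assumes ``fn`` contains a pmap exactly when more
    than one device is visible, so jitting it then would give jit-of-pmap.
    """
    if num_devices is None:
        num_devices = len(jax.devices())
    if num_devices == 1:
        # Single device: no pmap inside, so jitting is safe and beneficial.
        return jax.jit(fn)
    # Multiple devices: leave the pmapped function un-jitted so the outer
    # jit does not gather sharded inputs/outputs onto one device.
    return fn


def toy_ntk_fn(x1, x2):
    # Stand-in for a real NTK function: a linear-kernel Gram matrix.
    return x1 @ x2.T


x = jnp.ones((4, 3))
ntk_fn = maybe_jit(toy_ntk_fn)
print(ntk_fn(x, x).shape)  # (4, 4)
```

Passing `num_devices` explicitly makes both branches easy to exercise on a single-device machine.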