google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Jax v0.2.21 compatibility #125

Closed: PythonNut closed this issue 3 years ago

PythonNut commented 3 years ago

After updating to JAX v0.2.21, importing neural-tangents fails with the following error because jax.api was removed:

[redacted]/python3.9/site-packages/neural_tangents/utils/batch.py in <module>
     46 from functools import partial
     47 import warnings
---> 48 from jax.api import device_put, devices
     49 from jax.api import jit
     50 from jax.api import pmap

ModuleNotFoundError: No module named 'jax.api'
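
The failing lines all have the form `from jax.api import ...`. The JAX changelog for v0.2.21 describes the `jax.api` functions as aliases for the functions exported from `jax` itself (`jax.device_put`, `jax.devices`, `jax.jit`, `jax.pmap`), so downstream code can import them from the top level instead. The sketch below shows that mechanical rewrite on source text. `rewrite_jax_api_imports` is a hypothetical helper, not part of JAX or neural-tangents, and this thread does not show the exact change neural-tangents made.

```python
import re

# Matches `from jax.api import ` at the start of a line, keeping any
# indentation (spaces/tabs only). The required trailing space means
# modules such as `jax.api_util` are not matched.
_JAX_API_IMPORT = re.compile(r"^([ \t]*)from jax\.api import ", re.MULTILINE)


def rewrite_jax_api_imports(source: str) -> str:
    """Point `from jax.api import ...` lines at the top-level `jax` package.

    Only the `from jax.api import` form seen in the traceback is handled;
    attribute access such as `jax.api.jit` is left unchanged.
    """
    return _JAX_API_IMPORT.sub(r"\1from jax import ", source)
```

Applied to lines 48-50 of batch.py, this turns `from jax.api import jit` into `from jax import jit`, and does the same for the other two lines.
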
romanngg commented 3 years ago

Thanks for pointing this out! We haven't updated our latest release yet but will do so soon. In the meantime, please use NT from the GitHub head, i.e. do:

git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .

or

pip install -q git+https://www.github.com/google/neural-tangents

romanngg commented 3 years ago

Pushed version 0.3.8 to https://pypi.org/project/neural-tangents/; it should work with the latest JAX.
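
To catch this mismatch before an import fails, an environment can compare the installed versions. This is a minimal sketch that encodes only the constraint reported in this thread: JAX v0.2.21 or later needs neural-tangents 0.3.8 or later. The helpers `parse_release`, `needs_nt_upgrade`, and `installed_version` are hypothetical and are not part of either library. The parser reads only the leading numeric release segment, so it does not implement full PEP 440 ordering.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Versions taken from this thread: JAX v0.2.21 removed `jax.api`, and
# neural-tangents 0.3.8 is the release pushed to PyPI to work with it.
JAX_API_REMOVED = (0, 2, 21)
NT_FIXED = (0, 3, 8)


def parse_release(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '0.3.8' -> (0, 3, 8)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def needs_nt_upgrade(jax_version: str, nt_version: str) -> bool:
    """True if JAX no longer ships `jax.api` but NT predates the fix."""
    return (parse_release(jax_version) >= JAX_API_REMOVED
            and parse_release(nt_version) < NT_FIXED)


def installed_version(dist: str) -> Optional[str]:
    """Version of an installed distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None
```

For example, `needs_nt_upgrade("0.2.21", "0.3.7")` is `True` and `needs_nt_upgrade("0.2.21", "0.3.8")` is `False`. To check the current environment, pass `installed_version("jax")` and `installed_version("neural-tangents")`. Handle a `None` result first, because it means the package is missing.
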

PythonNut commented 3 years ago

Great! Sorry, I should have tried the latest main first; I forgot my checkout had fallen behind.