PythonNut closed this 3 years ago.
Thanks for pointing this out! We haven't updated our latest release yet and will do so soon. In the meantime, please install NT from the GitHub head, i.e. run:
git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .
or
pip install -q git+https://www.github.com/google/neural-tangents
Pushed version 0.3.8 to https://pypi.org/project/neural-tangents/; it should work with the latest JAX.
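To check which versions of JAX and Neural Tangents are actually installed (for example, whether the 0.3.8 release was picked up), here is a minimal sketch that uses only Python's standard-library `importlib.metadata`. The helper names `installed_version` and `parse_release` are hypothetical. The simple parser assumes plain numeric release strings such as `0.3.8`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Hypothetical helper: return the installed version of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def parse_release(v: str) -> Tuple[int, ...]:
    """Hypothetical helper: turn '0.3.8' into (0, 3, 8).

    Assumes purely numeric, dot-separated components; pre-release tags
    such as 'rc1' are not handled (packaging.version covers those).
    """
    return tuple(int(part) for part in v.split("."))


if __name__ == "__main__":
    # Report what pip has installed for each package (None means not found).
    for pkg in ("jax", "neural-tangents"):
        print(pkg, installed_version(pkg))

    # Flag a Neural Tangents install older than the 0.3.8 release.
    nt = installed_version("neural-tangents")
    if nt is not None and parse_release(nt) < (0, 3, 8):
        print("neural-tangents is older than 0.3.8; try `pip install -U neural-tangents`")
```

Running this before `import neural_tangents` shows whether the environment has an older release or a stale checkout.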
Great! Sorry, I should have tried the latest main first; I forgot my checkout had fallen behind.
After updating to jax v0.2.21, importing neural-tangents gives the following error due to the removal of jax.api: