google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

colab cookbook issue #127

Closed. frank-ccc closed this issue 3 years ago.

frank-ccc commented 3 years ago

Dear all,

I believe the Colab cookbook needs to be updated: when I ran the code in Colab, I got the error below.

Thank you,
Frank

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-a55d78ef1671> in <module>()
      7 import functools
      8 
----> 9 import neural_tangents as nt
     10 from neural_tangents import stax

1 frames
/usr/local/lib/python3.7/dist-packages/neural_tangents/stax.py in <module>()
     79 from jax import random
     80 from jax import ShapeDtypeStruct, eval_shape, grad, ShapedArray, vmap, custom_jvp
---> 81 import jax.example_libraries.stax as ostax
     82 from jax.lib import xla_bridge
     83 from jax.scipy.special import erf

ModuleNotFoundError: No module named 'jax.example_libraries'

romanngg commented 3 years ago

Thanks for pointing this out! This is caused by https://github.com/google/jax/commit/623c20105420619d40b13dd6851acfe8dd3417f4, which added jax.example_libraries but is not yet included in the latest JAX release (it is coming in v0.2.25).
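To check whether the JAX installed in a given runtime already provides `jax.example_libraries`, a small cell like the one below can help. It is a minimal sketch using only the standard library's `importlib.util.find_spec`, and the helper name `module_available` is hypothetical. Note that for a dotted name, `find_spec` imports the parent package (`jax`) first, but it does not import `jax.example_libraries` itself.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the import system can locate the module `name`.

    For dotted names, find_spec imports the parent package first; a missing
    parent raises ModuleNotFoundError (a subclass of ImportError), which is
    treated here as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


if __name__ == "__main__":
    # Expected to print False with a JAX release that predates
    # jax.example_libraries, or when JAX is not installed at all.
    print("jax.example_libraries available:",
          module_available("jax.example_libraries"))
```

If this prints False, the installed JAX is older than what NT at head expects, which matches the traceback above.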

A short-term fix is to add

!pip install -q git+https://www.github.com/google/jax

at the start of the Colab notebook.

romanngg commented 3 years ago

This should now be fixed: by default, NT (and JAX) are installed from PyPI. If you want to use NT at head, I'd generally recommend using JAX at head as well, to guard against backward-incompatible changes like this one. To do so, replace

!pip install -q neural-tangents

with

!pip install -q git+https://www.github.com/google/jax
!pip install -q git+https://www.github.com/google/neural-tangents

romanngg commented 3 years ago

FYI, I just pushed a temporary fix (https://github.com/google/neural-tangents/commit/9f2ebc88905c46d60b7c4a9da25636924acc9d45), which should make NT at head compatible with JAX both from the PyPI release and at head, so no changes should be needed on your side. If you run into any other issues, please feel free to reopen!
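A common way to stay compatible with two JAX layouts is an import fallback: try the new module path first and fall back to the old one. The sketch below shows that generic pattern; it is an assumption for illustration, not necessarily how the linked commit implements its fix, and the helper name `import_first` is hypothetical. The only assumption about JAX itself is that stax lived at `jax.experimental.stax` before moving to `jax.example_libraries.stax`.

```python
import importlib
from types import ModuleType
from typing import Sequence


def import_first(candidates: Sequence[str]) -> ModuleType:
    """Import and return the first module in `candidates` that exists."""
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError as e:
            # Only fall back when the missing module is the candidate itself
            # (or one of its parents); re-raise if the candidate exists but
            # one of *its own* imports is missing.
            if e.name is not None and not (name == e.name
                                           or name.startswith(e.name + ".")):
                raise
            errors.append(f"{name}: {e}")
    raise ModuleNotFoundError(
        "None of the candidate modules could be imported: " + "; ".join(errors))


# Hypothetical usage for the stax case: prefer the new location, and fall back
# to the old one for JAX releases that predate jax.example_libraries.
# ostax = import_first(["jax.example_libraries.stax", "jax.experimental.stax"])
```

Checking `e.name` keeps the fallback narrow, so a genuinely broken module is reported instead of being silently replaced by the older path.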