frank-ccc closed this issue 3 years ago
Thanks for pointing this out! This happens because https://github.com/google/jax/commit/623c20105420619d40b13dd6851acfe8dd3417f4 is not yet included in the latest JAX release (it is coming in v0.2.25).
A short-term fix is to add
!pip install -q git+https://www.github.com/google/jax
at the start of the colab.
This should now be fixed: by default, NT (and JAX) are installed from PyPI. If you want to use NT at head, I'd generally recommend using JAX at head as well, in case of backward-incompatible changes like this one. To do that, replace
!pip install -q neural-tangents
with
!pip install -q git+https://www.github.com/google/jax
!pip install -q git+https://www.github.com/google/neural-tangents
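If it is unclear which combination a Colab runtime ended up with, a quick check along these lines may help. This is only a sketch: the helper names (`version_tuple`, `installed_version`, `has_jax_fix`) are made up for illustration, and the `0.2.25` threshold is simply the release number mentioned above, not something the code verifies against the commit itself.

```python
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Turn '0.2.25' into (0, 2, 25); parsing stops at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def has_jax_fix(jax_version, first_fixed: str = "0.2.25") -> bool:
    """Assumption: the fix ships in JAX v0.2.25, as stated in this thread."""
    if jax_version is None:
        return False
    return version_tuple(jax_version) >= version_tuple(first_fixed)


# Report what the current runtime actually has installed.
for name in ("jax", "neural-tangents"):
    print(name, installed_version(name))
print("JAX includes the fix:", has_jax_fix(installed_version("jax")))
```

Pre-release or development suffixes are ignored by this simple parser, so a build labelled as a 0.2.25 pre-release is treated the same as 0.2.25.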
FYI, I just pushed a temporary fix (https://github.com/google/neural-tangents/commit/9f2ebc88905c46d60b7c4a9da25636924acc9d45) which should make NT at head compatible with JAX both at the PyPI release and at head, so no changes should be needed on your side. If there are any other issues, please feel free to reopen!
Dear All, I believe that the colab cookbook needs modifications because when I ran the code in colab, I got the message below. Thank you, Frank