mekty2012 closed this issue 3 years ago.
Hey, so this is a bug related to https://github.com/google/neural-tangents/issues/127.
The short-term fix is to install both JAX and NT from head:
!pip install -q git+https://www.github.com/google/jax
!pip install -q git+https://www.github.com/google/neural-tangents
The issue is that the documentation is built from head, which is ahead of our PyPI release, and that release doesn't yet include the new nonlinearities like Exp. We'll need to at least add versions to our docs to avoid this kind of issue. I will also take a look and see if we can cut a new PyPI release that is still compatible with JAX at head (https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-unreleased introduces a breaking change).
It works perfectly. Thanks a lot.
Awesome! FYI, I just pushed a temporary fix https://github.com/google/neural-tangents/commit/9f2ebc88905c46d60b7c4a9da25636924acc9d45 that makes NT at head compatible with JAX both at head and at the PyPI release, so there is no need to install JAX from head now.
Hello! I was trying to implement some networks involving Exp, but I get an error saying there is no stax.Exp(). I was running this in a Google Colab CPU environment with the latest version, 0.3.8. However, the stax documentation still lists the Exp function along with its implementation. Is this just a bug, or is there a reason Exp was removed? And if there is a reason, will using ElementwiseNumerical instead cause issues?
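On the ElementwiseNumerical question, one rough way to gauge what the numerical route gives up is to compare it against the closed form on a single-layer NNGP covariance entry, E[exp(x) exp(y)] for jointly Gaussian (x, y). The sketch below uses only NumPy and does not call Neural Tangents. It assumes, as the stax docs describe, that ElementwiseNumerical integrates the nonlinearity with Gauss-Hermite quadrature using `deg` points, and it assumes Exp with unit scale, i.e. f(x) = exp(x). The helper names `exp_kernel_exact` and `exp_kernel_quadrature` are hypothetical and are not part of the NT API.

```python
import numpy as np
from numpy.polynomial.hermite import hermgauss


def exp_kernel_exact(s11, s12, s22):
    """Closed form of E[exp(x) * exp(y)] for (x, y) ~ N(0, [[s11, s12], [s12, s22]])."""
    # x + y is Gaussian with variance s11 + s22 + 2*s12, and E[exp(u)] = exp(var / 2).
    return np.exp(0.5 * (s11 + s22 + 2.0 * s12))


def exp_kernel_quadrature(s11, s12, s22, deg):
    """Approximate the same expectation on a deg x deg Gauss-Hermite grid.

    Hypothetical helper: a stand-in for what a quadrature-based layer such as
    ElementwiseNumerical is assumed to do, not NT's actual implementation.
    """
    nodes, weights = hermgauss(deg)  # nodes/weights for integrals against exp(-t**2)
    cov = np.array([[s11, s12], [s12, s22]])
    chol = np.linalg.cholesky(cov)  # requires a positive-definite covariance
    # Change of variables z = sqrt(2) * t turns exp(-t**2) into a standard normal density.
    t1, t2 = np.meshgrid(nodes, nodes, indexing="ij")
    z = np.sqrt(2.0) * np.stack([t1.ravel(), t2.ravel()])  # shape (2, deg**2)
    xy = chol @ z  # correlated points with covariance `cov`
    w = np.outer(weights, weights).ravel() / np.pi  # normalized so weights sum to 1
    return np.sum(w * np.exp(xy[0]) * np.exp(xy[1]))


if __name__ == "__main__":
    exact = exp_kernel_exact(1.0, 0.5, 2.0)
    for deg in (3, 10, 25):
        approx = exp_kernel_quadrature(1.0, 0.5, 2.0, deg)
        print(f"deg={deg:2d}  approx={approx:.12f}  abs_err={abs(approx - exact):.2e}")
```

For a smooth nonlinearity like exp at moderate variances, the quadrature error shrinks quickly as `deg` grows, so under these assumptions the main costs of the numerical route are extra compute and possible loss of accuracy when pre-activation variances get large (exp grows fast, so more points may be needed).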