google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.29k stars, 227 forks

Mismatch between implementation and documentation #129

Closed: mekty2012 closed this issue 3 years ago.

mekty2012 commented 3 years ago

Hello! I was trying to implement some networks involving Exp, but I get an error saying there is no stax.Exp(). I'm running this in a Google Colab CPU environment with the latest version, 0.3.8. However, the stax documentation still lists the Exp function along with its implementation. Is this just a bug, or was Exp removed for some reason? And if it was removed for a reason, will using ElementwiseNumerical instead cause any issues?
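For context on the ElementwiseNumerical question: as we understand it, ElementwiseNumerical approximates the kernel integrals of an arbitrary elementwise nonlinearity with Gauss-Hermite quadrature, whereas a dedicated Exp layer can use a closed form. The sketch below is a standalone NumPy illustration of that idea, not NT's actual implementation; the helper names exp_kernel_closed_form and exp_kernel_quadrature are hypothetical. It compares a quadrature estimate of E[exp(u) exp(v)] for a zero-mean bivariate Gaussian (u, v) against the exact value exp((q11 + q22 + 2 q12) / 2).

```python
import numpy as np


def exp_kernel_closed_form(q11, q22, q12):
    # For (u, v) ~ N(0, [[q11, q12], [q12, q22]]), u + v is Gaussian with
    # variance q11 + q22 + 2 * q12, so E[exp(u) * exp(v)] = exp(Var(u + v) / 2).
    return np.exp((q11 + q22 + 2.0 * q12) / 2.0)


def exp_kernel_quadrature(q11, q22, q12, deg=25):
    # Probabilists' Gauss-Hermite nodes/weights integrate against exp(-x^2 / 2);
    # dividing the weights by sqrt(2 * pi) turns them into standard-normal weights.
    x, w = np.polynomial.hermite_e.hermegauss(deg)
    w = w / np.sqrt(2.0 * np.pi)

    # Correlate two independent standard normals via the Cholesky factor
    # of the 2x2 covariance matrix (hypothetical illustration, not NT code).
    chol = np.linalg.cholesky(np.array([[q11, q12], [q12, q22]]))
    z1, z2 = np.meshgrid(x, x, indexing="ij")
    u = chol[0, 0] * z1
    v = chol[1, 0] * z1 + chol[1, 1] * z2

    # Tensor-product quadrature over the 2D grid of nodes.
    weights = np.outer(w, w)
    return np.sum(weights * np.exp(u) * np.exp(v))


if __name__ == "__main__":
    exact = exp_kernel_closed_form(1.0, 1.0, 0.5)
    approx = exp_kernel_quadrature(1.0, 1.0, 0.5, deg=25)
    print(exact, approx, abs(exact - approx))
```

With unit variances the two values agree closely. Because exp grows quickly, it may be worth rerunning the comparison with the input variances your network actually produces and a few values of deg (which plays a role similar to the deg argument of ElementwiseNumerical, as far as we know) before relying on the numerical version.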

romanngg commented 3 years ago

Hey, so this is a bug related to https://github.com/google/neural-tangents/issues/127.

The short-term fix is to install both JAX and NT from head:

!pip install -q git+https://www.github.com/google/jax
!pip install -q git+https://www.github.com/google/neural-tangents
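After reinstalling, one way to confirm which versions the environment actually picked up, and whether the installed stax module exposes Exp, is a small check like the one below. It is a sketch using only the standard library; the helper names installed_version and has_stax_exp are hypothetical, and it assumes the PyPI distribution names jax and neural-tangents.

```python
import importlib
import importlib.metadata


def installed_version(dist_name):
    # Return the installed version string of a distribution, or None if absent.
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def has_stax_exp():
    # Feature-detect stax.Exp instead of relying on the docs, which may track head.
    try:
        stax = importlib.import_module("neural_tangents.stax")
    except Exception:
        # Broad catch on purpose: besides ImportError, a mismatched JAX/NT pair
        # may raise other errors at import time.
        return False
    return hasattr(stax, "Exp")


if __name__ == "__main__":
    print("jax:", installed_version("jax"))
    print("neural-tangents:", installed_version("neural-tangents"))
    print("stax.Exp available:", has_stax_exp())
```

Checking for the attribute directly sidesteps the question of which release first shipped Exp.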

The issue is that the documentation is built from head, which is ahead of our PyPI release; that release doesn't yet include the new nonlinearities such as Exp. At a minimum, we'll need to add versioning to our docs to avoid this kind of issue. I'll also look into whether we can publish a new PyPI release that is still compatible with JAX at head (the unreleased JAX 0.2.25 introduces a breaking change: https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-unreleased).

mekty2012 commented 3 years ago

It works perfectly. Thanks a lot.

romanngg commented 3 years ago

Awesome! FYI, I just pushed a temporary fix (https://github.com/google/neural-tangents/commit/9f2ebc88905c46d60b7c4a9da25636924acc9d45) that makes NT at head compatible with both JAX at head and the JAX PyPI release, so there's no longer any need to install JAX from head.