Closed: liutianlin0121 closed this issue 5 years ago.
Hey! Thanks for checking out the library. Don't hesitate to let us know if you run into issues.
The root of the problems you're seeing is that we recently pushed a fairly large change to the library that adds a lot of functionality and does some reorganization. In particular, we now have a set of tools for computing analytic GP kernels in addition to the empirical ones we had before. Unfortunately, we don't have a great way of announcing these changes to users. We did just push a change to GitHub that should get all the Colabs and examples working again, but let us know if we missed anything.
To get the same behavior as in the previous version of NT, the code

```python
from neural_tangents import tangents
# f, params, x_train and y_train are your network, its parameters and your data.
ker_fun = tangents.ntk(f)
predictor = tangents.analytic_mse_predictor(ker_fun(params, x_train, x_train), y_train)
```

would become

```python
from neural_tangents import predict
from neural_tangents.api import get_ntk_fun_empirical
ker_fun = get_ntk_fun_empirical(f)
# Note the new argument order: inputs first, params last.
predictor = predict.gradient_descent_mse(ker_fun(x_train, x_train, params), y_train)
```
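For readers who want to see what the migrated snippet computes, here is a minimal NumPy sketch, not the library's implementation. `empirical_ntk` forms the empirical NTK, J(x1) J(x2)^T, of a toy two-layer ReLU network; `init_params`, `apply_fn`, `jacobian` and `gradient_descent_mse_sketch` are hypothetical illustration helpers, with the Jacobian derived by hand. It keeps the `(x1, x2, params)` argument order used by `get_ntk_fun_empirical` above. The predictor solves linearized gradient flow on the loss 0.5 * ||f(X) - Y||^2 in closed form, assuming the network output at initialization is zero; the library's own loss normalization and defaults may differ.

```python
import numpy as np

# Hypothetical toy network (not a library API):
#   f(x) = v . relu(W x / sqrt(d)) / sqrt(width)
def init_params(seed, d_in, width):
    rng = np.random.default_rng(seed)
    return rng.standard_normal((width, d_in)), rng.standard_normal(width)

def apply_fn(params, x):
    W, v = params
    h = x @ W.T / np.sqrt(x.shape[1])                    # (n, width) pre-activations
    return np.maximum(h, 0.0) @ v / np.sqrt(W.shape[0])  # (n,) outputs

def jacobian(params, x):
    """Hand-derived Jacobian of apply_fn w.r.t. (W, v), flattened to (n, n_params)."""
    W, v = params
    n, d_in = x.shape
    width = W.shape[0]
    h = x @ W.T / np.sqrt(d_in)
    d_v = np.maximum(h, 0.0) / np.sqrt(width)            # df/dv_i
    step = (h > 0).astype(x.dtype)                       # relu'(h)
    d_W = (v * step)[:, :, None] * x[:, None, :] / np.sqrt(width * d_in)  # df/dW_ij
    return np.concatenate([d_W.reshape(n, -1), d_v], axis=1)

def empirical_ntk(x1, x2, params):
    """Empirical NTK at the given params: Theta(x1, x2) = J(x1) J(x2)^T."""
    return jacobian(params, x1) @ jacobian(params, x2).T

def gradient_descent_mse_sketch(k_train_train, y_train, learning_rate=1.0):
    """Closed-form linearized gradient flow on 0.5 * ||f(X) - Y||^2, assuming f_0 = 0."""
    y = np.asarray(y_train).reshape(len(y_train), -1)
    evals, evecs = np.linalg.eigh(k_train_train)

    def predict_fn(k_test_train, t=None):
        # t=None means t -> infinity; otherwise the factor is 1 - exp(-lr * lambda * t).
        decay = np.ones_like(evals) if t is None else -np.expm1(-learning_rate * evals * t)
        # alpha = K^{-1} (I - exp(-lr K t)) Y, computed in the eigenbasis of K.
        alpha = evecs @ ((decay / evals)[:, None] * (evecs.T @ y))
        return k_test_train @ alpha

    return predict_fn

# Usage on random toy data.
rng = np.random.default_rng(0)
x_train, x_test = rng.standard_normal((5, 3)), rng.standard_normal((2, 3))
y_train = rng.standard_normal((5, 1))
params = init_params(0, d_in=3, width=512)
k_train_train = empirical_ntk(x_train, x_train, params)
predict_fn = gradient_descent_mse_sketch(k_train_train, y_train)
y_test_pred = predict_fn(empirical_ntk(x_test, x_train, params))
```

Passing `t=None` gives the infinite-time limit, plain kernel regression `K_test_train @ inv(K_train_train) @ Y`; a finite `t` gives the prediction partway through training.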
The `layers` library has been deprecated and replaced by `neural_tangents.stax`, which is a drop-in replacement for `jax.experimental.stax`. Networks defined using `neural_tangents.stax` come with their infinite-width GP kernels, as in:

```python
from neural_tangents import stax
from neural_tangents import predict
# serial returns the init function, the apply function and the analytic kernel function.
init_fun, f, ker_fun = stax.serial(stax.Dense(1024), stax.Relu(), stax.Dense(1))
predictor = predict.gradient_descent_mse(ker_fun(x_train, x_train), y_train)
```
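To make "infinite-width GP kernels" concrete, the sketch below evaluates, in plain NumPy, the standard closed-form NNGP kernel and NTK of a Dense -> Relu -> Dense stack using the ReLU arc-cosine kernel. It assumes NTK parameterization, weight variance `w_std**2` and bias variance `b_std**2` (1 and 0 by default here) and nonzero inputs. Whether these match the defaults of your installed `stax.Dense` is an assumption to check, and `relu_nngp_ntk` is a hypothetical helper, not part of the library. The hidden width (1024 above) drops out in the infinite-width limit.

```python
import numpy as np

def relu_nngp_ntk(x1, x2, w_std=1.0, b_std=0.0):
    """Closed-form NNGP and NTK of Dense -> Relu -> Dense (NTK parameterization assumed)."""
    d = x1.shape[1]
    # First Dense layer: covariance of the pre-activations (this is also its NTK).
    k12 = w_std**2 * (x1 @ x2.T) / d + b_std**2
    k11 = w_std**2 * np.sum(x1**2, axis=1) / d + b_std**2
    k22 = w_std**2 * np.sum(x2**2, axis=1) / d + b_std**2
    norms = np.sqrt(np.outer(k11, k22))
    cos = np.clip(k12 / norms, -1.0, 1.0)   # clip guards against rounding past +-1
    theta = np.arccos(cos)
    # ReLU: E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for jointly Gaussian (u, v).
    relu_k = norms * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)
    relu_kdot = (np.pi - theta) / (2 * np.pi)
    # Second Dense layer.
    nngp = w_std**2 * relu_k + b_std**2
    ntk = nngp + w_std**2 * relu_kdot * k12
    return nngp, ntk

x = np.array([[1.0, 2.0, -0.5], [0.3, -1.0, 2.0]])
nngp, ntk = relu_nngp_ntk(x, x)
```

With the default `w_std=1` and `b_std=0`, the diagonal reduces to `NNGP(x, x) = |x|^2 / (2d)` and `NTK(x, x) = |x|^2 / d`, which is a handy sanity check.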
Finally, while the code works, we are still actively working on it and I wouldn't consider the API stable yet. I expect it to stabilize significantly in the next couple of weeks (and the changes should be improvements that are easy to adapt to). By the same token, if you have any thoughts on the design as you use the library, please let us know!
Thanks a lot!
Thanks a lot for making this repository public!
When running the notebooks `weight_space_linearization.ipynb` and `function_space_linearization.ipynb` on Google Colab using the links provided in these notebooks, I was unable to import `neural_tangents.tangents`. A screenshot is attached below. The same problem occurs when I try to run the repository locally on my computer.
This issue seems to have started when the repository was updated about a week ago. The old version of the code (currently on the notebook branch) works fine.
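If you hit the same import failure, one quick way to see which API an installed copy exposes (the old `tangents` module versus the newer `predict` and `stax` modules) is to probe the module names with the standard library's `importlib.util.find_spec`. This is a minimal sketch: the module names come from this thread, and `module_available` is a hypothetical helper.

```python
import importlib.util

def module_available(name: str) -> bool:
    """Return True if `name` can be found; parent packages get imported, the module itself does not."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. neural_tangents itself) is missing.
        return False

if __name__ == "__main__":
    for name in ("neural_tangents.tangents", "neural_tangents.predict", "neural_tangents.stax"):
        print(f"{name}: {'found' if module_available(name) else 'missing'}")
```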