yswi closed this issue 3 years ago.
This is likely a compatibility issue with the newest version of JAX.
Try pinning jax==0.1.73 and jaxlib==0.1.51.
In Colab, run the following in a cell and then restart the runtime:
!pip install jax==0.1.73
!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
If you are using different CUDA and Python versions, substitute them into the wheel URL:
export CUDA="102"
export PYTHON="37"
pip install jax==0.1.73
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda${CUDA}/jaxlib-0.1.51-cp${PYTHON}-none-manylinux2010_x86_64.whl
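As a minimal sketch, the wheel URL used in this thread can be assembled from the CUDA and Python tags. `jaxlib_wheel_url` and `python_tag` are hypothetical helpers written for illustration, not part of JAX. The template is copied from the commands above, and whether a wheel actually exists on the bucket for a given CUDA/Python combination is an assumption you should check.

```python
import sys

# URL template copied from the pip commands in this thread; only the CUDA tag,
# jaxlib version, and CPython tag vary. These helpers are hypothetical.
WHEEL_TEMPLATE = (
    "https://storage.googleapis.com/jax-releases/"
    "cuda{cuda}/jaxlib-{jaxlib}-cp{py}-none-manylinux2010_x86_64.whl"
)


def python_tag(version_info=sys.version_info):
    """Return the CPython tag digits, e.g. '37' for Python 3.7."""
    return f"{version_info[0]}{version_info[1]}"


def jaxlib_wheel_url(cuda, py=None, jaxlib="0.1.51"):
    """Build the pre-built jaxlib wheel URL.

    cuda: CUDA version with the dot removed, e.g. '101', '102', '110'.
    py:   CPython tag digits; defaults to the running interpreter.
    """
    if py is None:
        py = python_tag()
    return WHEEL_TEMPLATE.format(cuda=cuda, jaxlib=jaxlib, py=py)


if __name__ == "__main__":
    # URL for CUDA 11.0 and the current interpreter's Python version.
    print(jaxlib_wheel_url("110"))
```

In a Colab cell, IPython expands `{...}` expressions in shell escapes, so the result could be passed directly as `!pip install --upgrade {jaxlib_wheel_url("110")}`.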
It is finally working!! @tancik
I put these lines at the top of the Colab notebook:
!pip install -qq jax==0.1.73
!pip install -qq --upgrade https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.51-cp37-none-manylinux2010_x86_64.whl
!pip install -qq neural_tangents==0.2.2
Note that the default Colab runtime uses Python 3.7 and CUDA 11.0; these versions are reflected in the pre-built wheel's URL (cuda110 and cp37).
Your work is really impressive. I visited your GitHub repository to reproduce your results, hoping they could serve as a baseline for my research, but I ran into an issue: the error "'ShapedArray' object has no attribute 'val'" appears when I run '1d_ntk_opt.ipynb' and '1d_regression.ipynb'. Can you help me with this problem?
Thank you