kim-hyunsu opened this issue (closed 3 years ago):

I wrote a simple script with the monte_carlo_kernel_fn and gradient_descent modules, but it raised an unidentifiable type error, even though I never manipulated any dtypes in the code. I basically followed the examples shown in the source code, except that I used jax.experimental.stax.Tanh to build a two-layer neural network with a hyperbolic-tangent activation. Is there any way to address this problem? The code I ran was along the following lines.
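(A minimal sketch of the described setup; the widths, sample counts, and the exact predictor call are assumptions rather than the verbatim code:)

import jax.numpy as jnp
from jax import random
from jax.experimental import stax  # jax.example_libraries.stax in newer JAX
import neural_tangents as nt

key = random.PRNGKey(0)
x_train = random.normal(key, (20, 4))   # float64 when JAX_ENABLE_X64 is set
y_train = random.normal(key, (20, 1))
x_test = random.normal(key, (5, 4))

# Two-layer network with a Tanh nonlinearity, as described above.
init_fn, apply_fn = stax.serial(stax.Dense(512), stax.Tanh, stax.Dense(1))
_, params = init_fn(key, x_train.shape)

# Monte Carlo estimate of the empirical NTK.
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples=100)
k_train_train = kernel_fn(x_train, None, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')

fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)

# Gradient-descent prediction with the empirical kernel (using the MSE
# variant for concreteness); this is where the type error surfaced.
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = predict_fn(1., fx_train_0, fx_test_0, k_test_train)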
Thanks for the report! I can't seem to reproduce it in this Colab: https://colab.research.google.com/gist/romanngg/693d46b3c4a89649ab23a37542319eef/https-github-com-google-neural-tangents-issues-112.ipynb
Could you double-check that you're using the latest versions of NT/JAX, as in the Colab above? Also, do you by any chance run this on a machine with multiple GPUs? If so, how many? (There could be some issues with parallel execution, although I'd expect different error messages.)
Thank you for the quick reply. I'm using the following package versions:

jax 0.2.12
jaxlib 0.1.65+cuda110
neural-tangents 0.3.6
numpy 1.19.4

with 4 GPUs, CUDA 11.0, and NVIDIA driver 450.51.05. Just in case: Python is 3.6 and the OS is Ubuntu 18.04 LTS. These look like the latest versions to me.
Hm, so the issue seems to be that you have 64-bit precision enabled, but fx_train_0 and fx_test_0 have different types from k_train_train and k_test_train. A quick fix should be to cast them all to the same type before creating the predictor function predict_fn, e.g.

k_test_train = k_test_train.astype(fx_train_0.dtype)
k_train_train = k_train_train.astype(fx_test_0.dtype)
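(To see the mismatch directly, assuming the variables above:)

print(fx_train_0.dtype, fx_test_0.dtype)        # float64 float64 under x64
print(k_train_train.dtype, k_test_train.dtype)  # float32 float32, per the JAX bug below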
This appears to stem from the fact that for x64 x_train and x32 params, apply_fn(params, x) is x64, but jacobian(apply_fn)(params, x) is x32. Filed https://github.com/google/jax/issues/6638

IIUC this might also mean that k_train_train and k_test_train are computed in low, 32-bit precision, in which case you might as well disable x64 precision altogether for faster x32 performance.
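(A minimal sketch of that asymmetry, assuming the pre-fix jax.experimental.stax of that era:)

from jax.config import config
config.update("jax_enable_x64", True)  # reproduce the x64 setting

import jax
import jax.numpy as jnp
from jax import random, jacobian
from jax.experimental import stax

init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Tanh, stax.Dense(1))
_, params = init_fn(random.PRNGKey(0), (-1, 4))  # float32 params, pre-fix
x = jnp.ones((2, 4))                             # float64 under x64

print(apply_fn(params, x).dtype)  # float64: the x64 input promotes the output
# Per the bug above, the Jacobian w.r.t. params follows the params' dtype:
print(jax.tree_util.tree_leaves(jacobian(apply_fn)(params, x))[0].dtype)  # float32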
I tried each of

k_test_train = k_test_train.astype(fx_train_0.dtype)
k_train_train = k_train_train.astype(fx_test_0.dtype)

and

from jax.config import config
config.update("jax_enable_x64", False)

and both ways work now. Thank you for the kind answer. I didn't know 64-bit precision was enabled. Is that a default setting?
I don't think so. Perhaps you enabled it once in an IPython/Colab runtime, and then ran this code in that runtime? Restarting the runtime should also reset this flag to False, AFAIK.
Ah, it turns out that I had enabled it via an environment variable:

export JAX_ENABLE_X64=True

Now everything makes sense. I appreciate your help.
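(For reference, the equivalent ways to toggle this; the environment variable has to be set before Python starts, while the config update should run at the top of the script:)

# In the shell, before launching Python:
#   export JAX_ENABLE_X64=True
# Or programmatically, before creating any arrays:
from jax.config import config
config.update("jax_enable_x64", True)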
I believe after https://github.com/google/jax/commit/693d2e20cf40e17b567c4a252f37a4d6b9366e5d there should be no further type mismatches like this one - stax networks will initialize weights with the type corresponding to JAX_ENABLE_X64 vs always defaulting to jnp.float32.
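(A quick way to check which behavior an installed JAX has; this assumes the old import path, which moved to jax.example_libraries.stax in later releases:)

from jax.config import config
config.update("jax_enable_x64", True)

from jax import random
from jax.experimental import stax

init_fn, _ = stax.serial(stax.Dense(3))
_, params = init_fn(random.PRNGKey(0), (-1, 2))
print(params[0][0].dtype)  # float64 after the commit above, float32 before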
Here's a repro with JAX installed from head where it now works: https://colab.research.google.com/gist/romanngg/38f635cc20ba1ba667d34408728c1512/issue_112_fixed.ipynb
Thanks for noticing this!