Closed: gortizji closed this issue 2 years ago
Thanks for reaching out! Unfortunately, I wasn't able to reproduce the issue you're having. Here's a Colab notebook (note: I had to make the model slightly smaller to fit in public Colab GPU memory).

All of the versions seem similar to your current environment. However, I notice your GPU drivers seem a little out of date (version 10.2 instead of version 11.2); could that be the issue?
Btw, also consider trying `implementation=1`; we often find it faster than `2` for convolutions, especially if you only have one output logit (`2` scales better with the number of logits, but this shouldn't matter if you only have one).
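For reference, a minimal sketch of where this flag goes (the toy model, shapes, and `vmap_axes` choice below are illustrative assumptions, not from the notebook):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Toy stand-in for the user's network, ending in a single output logit.
init_fn, apply_fn, _ = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(1))
_, params = init_fn(jax.random.PRNGKey(0), (-1, 16))

# Per the discussion above: implementation=1 is often faster for convolutions
# with a single output logit; implementation=2 scales better as the number of
# output logits grows.
kernel_fn = nt.empirical_kernel_fn(apply_fn, vmap_axes=0, implementation=1)

x = jax.random.normal(jax.random.PRNGKey(1), (8, 16))
ntk = kernel_fn(x, None, 'ntk', params)  # (8, 8) empirical NTK Gram matrix
```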
Thanks for the quick answers! Curiously, it seems that using `implementation=1` does make things more stable, i.e., an error of around `1e-8`. Still, for small batches of around `10` the error climbs up to `1e-5`, but it is definitely not `0.1`.
On the other hand, I also cannot reproduce this strange behaviour on the Google Colab. I will try to update my CUDA version and get back to you.
After updating to CUDA-11.4, I can confirm that the issue indeed only happens on the old CUDA version. With this new version, both `implementation=1` and `implementation=2` yield an error of the order of `1e-8`, regardless of the `batch_size`.
In fact, it seems that support for CUDA-10.2 will be dropped in the next release of JAX. Even if this happens soon, I would still recommend directly specifying CUDA-11.x as a dependency of `neural-tangents`. I do not know what the root cause of that very strange behaviour was, but I am worried it might silently break other functionality of the library when it is used with the wrong CUDA version.
In any case, thank you very much for all your help!
Thanks a lot for figuring this out! I just pushed a release (https://github.com/google/neural-tangents/releases/tag/v0.3.9) bumping our minimum JAX version up to 0.2.25, which itself should only work with CUDA-11 and higher, so hopefully this is fixed! Please feel free to re-open if the issue remains.
After updating my environment to work with a more recent version of JAX and Flax, I have noticed that the empirical NTK Gram matrices computed using `nt.batch` applied to `nt.empirical_kernel_fn` are significantly different depending on the batch size. The code to reproduce this error is:
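(The original snippet is not preserved in this thread; below is a minimal sketch of the kind of comparison described, where the architecture, data shape, and `batch_size` are illustrative assumptions rather than the original script:)

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (32, 28, 28, 1))

# Small convolutional model with a single output logit.
init_fn, apply_fn, _ = stax.serial(
    stax.Conv(32, (3, 3)), stax.Relu(),
    stax.Flatten(), stax.Dense(1))
_, params = init_fn(key, data.shape)

kernel_fn = nt.empirical_kernel_fn(apply_fn, vmap_axes=0, implementation=2)

# NTK Gram matrix computed in one shot vs. assembled from mini-batches.
full = kernel_fn(data, None, 'ntk', params)
batched = nt.batch(kernel_fn, batch_size=8)(data, None, 'ntk', params)

print(jnp.mean(jnp.abs(full - batched)))  # expected ~1e-8, observed ~1e-1
```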
Surprisingly, if I run this with my old environment I get an average error of the order of `1e-8`, while with the new environment the error is of the order of `1e-1`. Also, this error remains exactly the same as long as `batch_size < data.shape[0]`.

My old environment consisted of:
and my new environment has: