-
https://colab.research.google.com/drive/1VHzY55vHtMPsvXR302WoYYAJTj74jy1S?usp=sharing
```
from neural_tangents import stax
from jax import random
init_fn, apply_fn, kernel_fn = stax.Dense(out_di…
```
-
Thanks for the awesome library! It would be great if PyTorch could support forward-mode automatic differentiation. The main use case is to compute a Jacobian-vector product. I tried using [this trick]…
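For comparison, JAX (the framework used elsewhere in this list) exposes forward mode directly as `jax.jvp`. The sketch below also shows the "double-vjp" workaround that reverse-mode-only frameworks sometimes rely on: a VJP is linear in its cotangent, so taking the VJP of the VJP recovers `J @ v`. Whether this is the trick linked above is unknown (the link is truncated), and the function `f` is a made-up example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function R^3 -> R^3.
    return jnp.sin(x) * x ** 2


x = jnp.array([0.5, 1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

# Forward mode: one pass returns f(x) and J(x) @ v without materializing J.
y, jv_forward = jax.jvp(f, (x,), (v,))

# Double-vjp trick: vjp_fn(u) = J(x)^T u is linear in u, so its own VJP
# applied to v returns (J(x)^T)^T v = J(x) v. The value of u does not matter.
_, vjp_fn = jax.vjp(f, x)
u = jnp.zeros_like(y)
_, vjp_of_vjp = jax.vjp(vjp_fn, u)
(jv_double,) = vjp_of_vjp((v,))

# Dense reference: build the full Jacobian and multiply.
jv_dense = jax.jacfwd(f)(x) @ v
```

The double-vjp route costs roughly two reverse passes, while `jax.jvp` needs a single forward pass.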
-
I am truly confused about the difference between `gradient_descent_mse` and `gradient_descent_mse_ensemble`.
In the original NNGP and NTK papers, the authors mentioned that we can use Gaussi…
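As background for both predictors, here is a minimal from-scratch sketch (core JAX only, not the library's code) of the closed-form training-set outputs of a linearized network trained by gradient flow on MSE with a fixed kernel `K`: `df/dt = -lr * K (f - y)`, so `f_t = y + exp(-lr * K * t) (f_0 - y)`. The loss scaling (`0.5 * ||f - y||^2`, no `1/n` factor), the function name, and the toy data are assumptions; the library's own normalization may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def mse_train_outputs(ntk_train_train, y_train, f0_train, learning_rate, t):
    """Train-set outputs at time t under gradient flow df/dt = -lr * K (f - y).

    Closed form: f_t = y + exp(-lr * K * t) (f_0 - y), with the matrix
    exponential built from the eigendecomposition of the symmetric kernel K.
    """
    evals, evecs = jnp.linalg.eigh(ntk_train_train)
    decay = jnp.exp(-learning_rate * evals * t)
    # V diag(decay) V^T equals exp(-lr * K * t) for symmetric K.
    propagator = (evecs * decay) @ evecs.T
    return y_train + propagator @ (f0_train - y_train)


# Hypothetical toy data: a random positive-definite "NTK" on 5 training points.
key_a, key_y, key_f = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.normal(key_a, (5, 5))
ntk = a @ a.T / 5 + 0.1 * jnp.eye(5)
y = jax.random.normal(key_y, (5, 1))
f0 = jax.random.normal(key_f, (5, 1))

f_half = mse_train_outputs(ntk, y, f0, learning_rate=1.0, t=0.5)
# Cross-check against a direct matrix exponential.
f_half_ref = y + expm(-1.0 * ntk * 0.5) @ (f0 - y)
```

Here `f0` plays the role of one particular network's initial outputs; averaging over many random initializations is what an "ensemble" prediction would add on top of this.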
-
Suppose I have a regular old neural network with its weights set to some values. Then the NTK k(x, y) is well-defined as the dot product of df/dw at each input, that is, the dot product of the gradien…
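That definition translates almost directly into JAX. The sketch below uses core JAX only; the two-layer tanh network, its parameter names, and its sizes are hypothetical. It computes per-example gradients with `jax.vmap(jax.grad(...))` and sums, over every parameter array, the dot products of the flattened gradients. Neural Tangents has its own empirical-kernel utilities; this is only meant to make the definition concrete.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in, width):
    # Hypothetical two-layer network with a scalar output.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (width,)) / jnp.sqrt(width),
    }


def f(params, x):
    # x: a single input of shape (d_in,); returns a scalar.
    return jnp.tanh(x @ params["W1"]) @ params["w2"]


def empirical_ntk(f, params, x1, x2):
    """k(x, y) = <df/dw(x), df/dw(y)>, summed over all parameter arrays."""
    per_example_grad = jax.vmap(jax.grad(f), in_axes=(None, 0))
    g1 = per_example_grad(params, x1)  # pytree; leaves shaped (n1, *param_shape)
    g2 = per_example_grad(params, x2)  # pytree; leaves shaped (n2, *param_shape)

    def contract(a, b):
        # Flatten each example's gradient for this parameter, then take dot products.
        return a.reshape(a.shape[0], -1) @ b.reshape(b.shape[0], -1).T

    return sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(contract, g1, g2)))


params = init_params(jax.random.PRNGKey(0), d_in=3, width=16)
xs = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
k = empirical_ntk(f, params, xs, xs)  # shape (4, 4)
```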
-
I want to use neural_tangents.predict.gp_inference to predict a distribution (mean and variance) on a test set. The documentation says that it returns a function `predict_fn(get, k_test_train, nngp_te…
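For reference, exact GP inference comes down to two formulas: `mean = K_test_train (K_train_train + reg * I)^{-1} y` and `cov = K_test_test - K_test_train (K_train_train + reg * I)^{-1} K_train_test`. The sketch below computes them from scratch with a Cholesky factorization. It is not the library's `gp_inference` implementation; the RBF kernel is a stand-in for an NNGP or NTK matrix, and the function names and `diag_reg` default are assumptions.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def gp_posterior(k_train_train, y_train, k_test_train, k_test_test, diag_reg=1e-4):
    """Exact GP posterior mean and covariance on the test points."""
    n = k_train_train.shape[0]
    chol = jnp.linalg.cholesky(k_train_train + diag_reg * jnp.eye(n))  # lower-triangular L
    alpha = cho_solve((chol, True), y_train)  # (K + reg I)^{-1} y
    mean = k_test_train @ alpha
    v = solve_triangular(chol, k_test_train.T, lower=True)  # L^{-1} K_train_test
    cov = k_test_test - v.T @ v
    return mean, cov


def rbf(a, b, lengthscale=0.5):
    # Stand-in kernel on 1-D inputs; a: (n,), b: (m,) -> (n, m).
    return jnp.exp(-((a[:, None] - b[None, :]) ** 2) / (2 * lengthscale ** 2))


x_train = jnp.linspace(-2.0, 2.0, 8)
y_train = jnp.sin(x_train)[:, None]
x_test = jnp.array([0.0, 10.0])  # one point inside the data, one far outside

mean, cov = gp_posterior(rbf(x_train, x_train), y_train,
                         rbf(x_test, x_train), rbf(x_test, x_test))
variance = jnp.diag(cov)
```

Far from the training data the posterior falls back to the prior (mean near 0, variance near `k(x, x) = 1`), which is a quick sanity check for any kernel plugged in.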
-
Hi,
I am currently using Neural Tangents to compute the kernel for CIFAR-10 images. I need to compute the 10,000 × 10,000 kernel matrix, and each image has 3×32×32 pixels. …
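A 10,000 × 10,000 kernel has 10^8 entries (about 400 MB in float32). One generic way to keep device memory bounded is to compute the matrix tile by tile and assemble it in host memory. The sketch below assumes `kernel` is any function mapping two batches of inputs to their kernel block; the flattened linear kernel is a hypothetical stand-in for a Neural Tangents `kernel_fn`. The library also ships its own batching wrapper, `nt.batch`, whose options are worth checking in its documentation first.

```python
import numpy as np
import jax


@jax.jit
def toy_kernel(x1, x2):
    # Hypothetical stand-in kernel: normalized linear kernel on flattened images.
    x1 = x1.reshape(x1.shape[0], -1)
    x2 = x2.reshape(x2.shape[0], -1)
    return x1 @ x2.T / x1.shape[1]


def blockwise_kernel(kernel, x1, x2, block_size):
    """Assemble the full (n1, n2) kernel from block_size x block_size tiles."""
    n1, n2 = x1.shape[0], x2.shape[0]
    out = np.empty((n1, n2), dtype=np.float32)  # lives in host memory
    for i in range(0, n1, block_size):
        for j in range(0, n2, block_size):
            tile = kernel(x1[i:i + block_size], x2[j:j + block_size])
            out[i:i + block_size, j:j + block_size] = np.asarray(tile)
    return out


# Small hypothetical batch shaped like the images described above (N, 3, 32, 32).
images = np.random.default_rng(0).standard_normal((10, 3, 32, 32)).astype(np.float32)
k_tiled = blockwise_kernel(toy_kernel, images, images, block_size=4)
k_full = np.asarray(toy_kernel(images, images))
```

For a symmetric train-train matrix, only tiles with `j >= i` strictly need computing; the rest can be filled by transposition.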
-
Hi,
Really neat lib! I was wondering whether it is possible to compute the gradient of the kernel with respect to the input data using autograd. I have made several unsuccessful attempts and was wonderi…
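In JAX, a kernel written as a pure function of its inputs can be differentiated with `jax.grad` like any other function. Below, a hand-written closed-form kernel stands in for `kernel_fn`: it is the standard expression for `E[relu(w.x) relu(w.y)]` with `w ~ N(0, I)` (the degree-1 arc-cosine kernel). Whether a particular Neural Tangents `kernel_fn` traces cleanly under `jax.grad` is an assumption to verify. A central finite difference serves as a check.

```python
import jax
import jax.numpy as jnp


def relu_kernel(x, y):
    """E[relu(w.x) relu(w.y)] for w ~ N(0, I): the degree-1 arc-cosine kernel."""
    nx, ny = jnp.linalg.norm(x), jnp.linalg.norm(y)
    cos_theta = jnp.dot(x, y) / (nx * ny)
    theta = jnp.arccos(cos_theta)
    return nx * ny / (2 * jnp.pi) * (jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta))


x = jnp.array([1.0, 2.0, -0.5])
y = jnp.array([0.3, -1.0, 2.0])

# Gradient of k(x, y) with respect to the first input.
dk_dx = jax.grad(relu_kernel, argnums=0)(x, y)

# Central finite differences, one coordinate at a time, as a sanity check.
h = 1e-2
fd = jnp.array([
    (relu_kernel(x + h * e, y) - relu_kernel(x - h * e, y)) / (2 * h)
    for e in jnp.eye(3)
])
```

For a scalar loss built from a whole kernel matrix, the same `jax.grad(loss, argnums=0)` call applies; `jax.jacrev` gives the full Jacobian of the matrix if that is needed instead.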
-
Hello,
I have encountered a problem when computing the gradient of a loss function with respect to an input variable x: I get a NaN after several iterations. This only appears when the N…
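The cause in this report is cut off, so the following shows only one common, unconfirmed mechanism. Kernels built from `arccos` of a cosine similarity have derivative `-1/sqrt(1 - c^2)`, which is infinite when the two inputs are parallel (`c = ±1`), and the chain rule then produces `inf * 0 = nan`. The sketch reproduces the NaN with hypothetical `angle_*` helpers and shows one workaround: clipping `c` slightly inside `(-1, 1)`.

```python
import jax
import jax.numpy as jnp


def angle_naive(x, y):
    # arccos of the cosine similarity; d/dc arccos(c) = -1/sqrt(1 - c^2) -> -inf at c = 1.
    c = jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y))
    return jnp.arccos(c)


def angle_clipped(x, y, eps=1e-6):
    c = jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y))
    # Keep c strictly inside (-1, 1) so the derivative of arccos stays finite.
    c = jnp.clip(c, -1.0 + eps, 1.0 - eps)
    return jnp.arccos(c)


x = jnp.array([1.0, 0.0, 0.0])  # gradient taken at x == y, where c == 1 exactly
grad_naive = jax.grad(angle_naive)(x, x)
grad_clipped = jax.grad(angle_clipped)(x, x)
```

Clipping makes the gradient exactly zero wherever the clip is active and shifts the value slightly (`arccos(1 - 1e-6)` is about `1.4e-3` rather than `0`). When the first NaN is hard to locate, `jax.config.update("jax_debug_nans", True)` makes JAX raise at the operation that produced it.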
-
Currently, most Neural Tangents examples get `kernel_fn` from `stax.serial`.
Is there a more advanced way to get `kernel_fn` from complex models?
For example, can we get `kernel_fn` from `flax`…