-
I'm having fun playing with the Neural Tangents Cookbook.ipynb and I'd like to try extending it to multivariate regression. However, when I changed the output dimension of the last layer in `stax.serial`,…
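For context, since the infinite-width NTK of a network with an i.i.d. readout layer is shared across output channels, exact MSE inference with a multi-dimensional output reduces to solving the same kernel system against each target column. A minimal NumPy sketch of that reduction (the RBF kernel and all names here are illustrative stand-ins, not the library's API):

```python
import numpy as np

def gp_mean_multi(k_train_train, k_test_train, y_train, reg=1e-6):
    """Posterior mean for kernel regression with a multi-column target.

    The same (n_train, n_train) kernel is reused for every output
    dimension: one linear solve against the whole (n_train, out_dim)
    target matrix gives all output columns at once.
    """
    n = k_train_train.shape[0]
    solved = np.linalg.solve(k_train_train + reg * np.eye(n), y_train)
    return k_test_train @ solved  # (n_test, out_dim)

# Toy data: 1-D inputs, 2-D targets, an RBF kernel standing in for an NTK.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(8, 1))
x_test = rng.normal(size=(3, 1))
y_train = np.hstack([np.sin(x_train), np.cos(x_train)])  # (8, 2)

def rbf(a, b):
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

mean = gp_mean_multi(rbf(x_train, x_train), rbf(x_test, x_train), y_train)
print(mean.shape)  # (3, 2): one prediction column per output dimension
```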
-
I am aware that the default inference implemented is based on mean-squared-error (MSE) loss. Is there an implemented example, or a way, to obtain aleatoric uncertainty instead (either homoscedastic or h…
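One workaround I'm aware of (not something I can confirm the library ships): treat aleatoric noise as a per-example observation-noise variance added to the train-train kernel diagonal; homoscedastic noise is then the special case of a constant diagonal. A NumPy sketch under that assumption, with all names hypothetical:

```python
import numpy as np

def gp_predict_heteroscedastic(k_tt, k_xt, k_xx, y, noise_var):
    """GP posterior with per-example (heteroscedastic) noise.

    noise_var: shape (n_train,), the observation-noise variance of each
    training point, folded into the train-train kernel diagonal.
    """
    k_noisy = k_tt + np.diag(noise_var)
    mean = k_xt @ np.linalg.solve(k_noisy, y)
    cov = k_xx - k_xt @ np.linalg.solve(k_noisy, k_xt.T)
    return mean, cov

rng = np.random.default_rng(1)
x = rng.normal(size=(10, 1))
xs = rng.normal(size=(4, 1))
y = np.sin(x).ravel()
noise = np.linspace(0.01, 0.5, 10)  # a different noise level per point

def rbf(a, b):
    return np.exp(-0.5 * ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))

mean, cov = gp_predict_heteroscedastic(rbf(x, x), rbf(xs, x), rbf(xs, xs), y, noise)
print(mean.shape, cov.shape)  # (4,) (4, 4)
```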
-
Example:
```python
from jax import jacobian
from jax.config import config
import jax.numpy as jnp
config.update('jax_enable_x64', True)
def f(x, y):
    return x * y
x = jnp.ones((), jnp.flo…
```
-
I am working on a simple MNIST example. I found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:
```python
import neural_tan…
```
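The usual fix is to assemble the kernel from tiles so that only one block's worth of intermediates is resident at a time; if I recall the API correctly, `nt.batch(kernel_fn, batch_size=...)` wraps a kernel function to do exactly this. A NumPy sketch of the tiling idea (the linear `kernel_fn` here is a stand-in):

```python
import numpy as np

def kernel_in_tiles(kernel_fn, x1, x2, tile=32):
    """Assemble a large kernel matrix from (tile x tile) blocks so the
    peak memory is bounded by one block, not the whole matrix."""
    n1, n2 = len(x1), len(x2)
    out = np.empty((n1, n2))
    for i in range(0, n1, tile):
        for j in range(0, n2, tile):
            out[i:i + tile, j:j + tile] = kernel_fn(x1[i:i + tile], x2[j:j + tile])
    return out

rng = np.random.default_rng(2)
x = rng.normal(size=(100, 5))

def lin_kernel(a, b):  # stand-in for an NTK kernel_fn
    return a @ b.T

k_tiled = kernel_in_tiles(lin_kernel, x, x, tile=32)
print(np.allclose(k_tiled, x @ x.T))  # True
```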
-
Dear team, great package! I'm very excited to use it.
However, I tried a simple case and failed miserably to get decent performance.
I generate a multi-dimensional dataset with a relativel…
-
Hi, I've run into a confusing problem:
```
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(512, W_std=1.5, b_std=…
```
-
Transductive learning is very common, e.g., node classification on Cora, Citeseer, and Pubmed. I intend to analyze the GNN models, e.g., 2-layer GCN, in the NTK regime.
As I have utilized `neural_ta…
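For the transductive setting, one standard recipe is to compute the kernel once over all N nodes and then index out the labelled/unlabelled blocks for inference. A NumPy sketch of that recipe (the RBF kernel over node features is a stand-in for a GCN kernel, and all names are illustrative):

```python
import numpy as np

# Suppose k_full is the kernel over all n nodes of the graph, computed
# once on the full input, and only `train_idx` nodes carry labels.
rng = np.random.default_rng(3)
n = 30
feats = rng.normal(size=(n, 4))
k_full = np.exp(-0.5 * ((feats[:, None] - feats[None]) ** 2).sum(-1))

train_idx = np.arange(0, 20)   # labelled nodes
test_idx = np.arange(20, 30)   # unlabelled nodes to predict
y_train = rng.normal(size=(20, 3))  # e.g. one-hot class scores

# Transductive prediction: slice the precomputed full-graph kernel.
k_tt = k_full[np.ix_(train_idx, train_idx)]
k_st = k_full[np.ix_(test_idx, train_idx)]
pred = k_st @ np.linalg.solve(k_tt + 1e-6 * np.eye(len(train_idx)), y_train)
print(pred.shape)  # (10, 3)
```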
-
Hey! Thanks for all the amazing work.
I'm trying to compute the NTK for some data on a shallow version of WideResNet, and I'm encountering a non-PSD matrix, which results in `NaN`s in predictions re…
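A common repair when a kernel is non-PSD only by a small numerical margin: symmetrize it, clip the negative eigenvalues, and add a small diagonal jitter before solving (computing the kernel in float64 via `jax_enable_x64` also helps). A hedged NumPy sketch of that repair:

```python
import numpy as np

def make_psd(k, jitter=1e-6):
    """Symmetrize, clip negative eigenvalues that come from
    floating-point error, and add a small diagonal jitter so that
    Cholesky factorization and linear solves succeed."""
    k = 0.5 * (k + k.T)
    w, v = np.linalg.eigh(k)
    w = np.clip(w, 0.0, None)
    return (v * w) @ v.T + jitter * np.eye(k.shape[0])

# Build a matrix that just barely fails to be PSD.
rng = np.random.default_rng(4)
a = rng.normal(size=(6, 6))
k = a @ a.T
k_bad = k - (np.linalg.eigvalsh(k).min() + 1e-3) * np.eye(6)

k_fixed = make_psd(k_bad)
np.linalg.cholesky(k_fixed)  # would raise LinAlgError on k_bad
print(np.linalg.eigvalsh(k_fixed).min() > 0)  # True
```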
-
I'm looking to use the library to compute the after kernel for a model trained with the Flax library. I followed this Colab: https://colab.research.google.com/github/google/neural-tangents/blob/main/n…
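For reference, the "after kernel" is the empirical NTK evaluated at the trained parameters, Theta(x1, x2) = J(x1) J(x2)^T with J the Jacobian of the outputs w.r.t. the flattened parameters; if I'm reading the API right, `nt.empirical_ntk_fn(apply_fn)` computes this for any `apply_fn`. A NumPy sketch of the definition, checked on a linear model whose parameter Jacobian is known in closed form:

```python
import numpy as np

def empirical_ntk(jac_fn, params, x1, x2):
    """Empirical ("after") kernel at the given parameters:
    Theta(x1, x2) = J(x1) @ J(x2).T, with J the per-example Jacobian of
    the network output w.r.t. the flattened parameters."""
    j1 = jac_fn(params, x1)  # (n1, n_params)
    j2 = jac_fn(params, x2)  # (n2, n_params)
    return j1 @ j2.T

# Toy check with a linear model f(w, x) = x @ w, whose parameter
# Jacobian is just x, so the empirical NTK is x1 @ x2.T for any w.
rng = np.random.default_rng(5)
w = rng.normal(size=(4,))
x1 = rng.normal(size=(5, 4))
x2 = rng.normal(size=(3, 4))
jac_fn = lambda params, x: x  # analytic Jacobian of the linear model

theta = empirical_ntk(jac_fn, w, x1, x2)
print(np.allclose(theta, x1 @ x2.T))  # True
```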
-
Hi,
I need to compute the empirical NTK kernel (J@J.T) for an NN with ~2.5M parameters, including convolution, pooling, and dense layers, for up to ~30000 examples of size…
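Rough sizes first: a 30000 x 30000 kernel is about 3.6 GB in float32 and is storable, but the full 30000 x 2.5M Jacobian would be about 300 GB and cannot be materialized, so the contraction J @ J.T has to be done over row blocks of J. A NumPy sketch of that blocking (the `jac_rows` callback, which would produce one batch of Jacobian rows at a time, is hypothetical):

```python
import numpy as np

# Back-of-envelope sizes for n = 30000 examples, p = 2.5M parameters.
n_full, p_full = 30_000, 2_500_000
kernel_gb = n_full * n_full * 4 / 1e9    # float32 kernel: ~3.6 GB, storable
jacobian_gb = n_full * p_full * 4 / 1e9  # full Jacobian: ~300 GB, not storable

def ntk_blockwise(jac_rows, n, block):
    """Theta = J @ J.T computed one pair of row blocks at a time, so the
    full (n, p) Jacobian is never held in memory at once."""
    theta = np.empty((n, n))
    for i in range(0, n, block):
        ji = jac_rows(i, min(i + block, n))      # one batch of J's rows
        for j in range(0, n, block):
            jj = jac_rows(j, min(j + block, n))
            theta[i:i + block, j:j + block] = ji @ jj.T
    return theta

# Small correctness check against the dense computation.
rng = np.random.default_rng(6)
jac = rng.normal(size=(50, 20))
theta = ntk_blockwise(lambda a, b: jac[a:b], 50, 16)
print(np.allclose(theta, jac @ jac.T))  # True
```

Since Theta is symmetric, computing only the blocks with j >= i and mirroring them would roughly halve the work.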