-
Hi,
Regarding `Erf()` in `stax`, I want to confirm its implementation.
When we consider a 2-layer MLP without training the last layer, the NTK is the covariance matrix of the data multiplied by
…
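A minimal Monte Carlo sketch of that setting. It assumes the parameterization `f(x) = n^{-1/2} * sum_i v_i * erf(w_i . x / sqrt(d))` with `w_i, v_i ~ N(0, 1)`, no biases, and only the first-layer weights `w_i` trained. Under these assumptions the NTK factors into the data term `x . x' / d` times `E[erf'(u) erf'(u')]`, which has the closed form used in `analytic_ntk_first_layer`. This is our own derivation for checking purposes, not the `stax.Erf` code, and the helper names are hypothetical.

```python
import numpy as np

def erf_prime(z):
    # Derivative of erf: 2 / sqrt(pi) * exp(-z^2)
    return 2.0 / np.sqrt(np.pi) * np.exp(-z ** 2)

def empirical_ntk_first_layer(x1, x2, n_hidden=200_000, seed=0):
    """Finite-width NTK of f(x) = n^-1/2 sum_i v_i erf(w_i . x / sqrt(d)),
    differentiating only with respect to the first-layer weights w_i."""
    rng = np.random.default_rng(seed)
    d = x1.shape[0]
    w = rng.standard_normal((n_hidden, d))
    v = rng.standard_normal(n_hidden)
    u1 = w @ x1 / np.sqrt(d)
    u2 = w @ x2 / np.sqrt(d)
    # grad_{w_i} f(x) = n^-1/2 * v_i * erf'(u_i) * x / sqrt(d), so the
    # sum over i of gradient inner products is a mean times x1 . x2 / d
    data_term = x1 @ x2 / d
    return np.mean(v ** 2 * erf_prime(u1) * erf_prime(u2)) * data_term

def analytic_ntk_first_layer(x1, x2):
    """Infinite-width limit: (x1 . x2 / d) * (4 / pi) / sqrt(det(I + 2 * Sigma)),
    where Sigma is the covariance of the pre-activations (u1, u2)."""
    d = x1.shape[0]
    a, b, c = x1 @ x1 / d, x2 @ x2 / d, x1 @ x2 / d
    det = (1 + 2 * a) * (1 + 2 * b) - 4 * c ** 2
    return c * (4 / np.pi) / np.sqrt(det)

x1 = np.array([1.0, 0.5, -0.3, 2.0, 0.1])
x2 = np.array([0.2, -1.0, 0.7, 1.5, 0.0])
print(empirical_ntk_first_layer(x1, x2), analytic_ntk_first_layer(x1, x2))
```

At this width the two values should agree to within a fraction of a percent; comparing them against `kernel_fn(x1, x2, 'ntk')` from a matching `stax` model is one way to confirm which parameterization the library uses.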
-
Hey, thanks for the great work!
I'm using BatchNorm in my network, but have set the `use_running_average` parameter of the BatchNorm layers to `True`, which means they will not compute any running mean/std…
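With fixed (running) statistics, batch normalization no longer couples examples within a batch: per feature it is the affine map `y = scale * (x - mean) / sqrt(var + eps) + bias`, which can be folded into a single scale and shift. The numpy sketch below illustrates this; it is not Flax's implementation, the helper names are hypothetical, and `eps=1e-5` is an assumption.

```python
import numpy as np

def batchnorm_inference(x, mean, var, scale, bias, eps=1e-5):
    # Normalize each feature with stored (running) statistics, then scale and shift
    return scale * (x - mean) / np.sqrt(var + eps) + bias

def fold_batchnorm(mean, var, scale, bias, eps=1e-5):
    # Collapse BN with fixed statistics into y = a * x + c, per feature
    a = scale / np.sqrt(var + eps)
    c = bias - a * mean
    return a, c

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))  # batch of 8 examples, 4 features
mean, var = rng.standard_normal(4), rng.uniform(0.5, 2.0, 4)
scale, bias = rng.standard_normal(4), rng.standard_normal(4)

a, c = fold_batchnorm(mean, var, scale, bias)
y_bn = batchnorm_inference(x, mean, var, scale, bias)
y_affine = a * x + c
```

Because each output row depends only on its own input row, the layer behaves like any other elementwise affine transformation in this mode.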
-
Here's a piece of code that creates a Keras network, uses its weights to initialize an equivalent Stax network, and then runs both of them on the same inputs, comparing the outputs.
```python
impo…
```
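One possible source of mismatch when copying weights is the parameterization. Keras `Dense` computes `x @ kernel + bias`, while Stax `Dense` defaults to the NTK parameterization, which, as we understand it, scales the product by `W_std / sqrt(fan_in)` and the bias by `b_std`. The numpy sketch below emulates both layers under that assumption; the helper names are hypothetical, and the rescaling should be checked against the installed `neural_tangents` version.

```python
import numpy as np

def keras_dense(x, kernel, bias):
    # Standard parameterization: y = x @ W + b
    return x @ kernel + bias

def ntk_dense(x, W, b, W_std=1.0, b_std=1.0):
    # Assumed NTK parameterization: y = (W_std / sqrt(fan_in)) * x @ W + b_std * b
    fan_in = x.shape[-1]
    return (W_std / np.sqrt(fan_in)) * (x @ W) + b_std * b

def keras_to_ntk(kernel, bias, W_std=1.0, b_std=1.0):
    # Undo the NTK scaling so both layers compute the same function
    fan_in = kernel.shape[0]
    return kernel * np.sqrt(fan_in) / W_std, bias / b_std
```

Copying `kernel` and `bias` directly would change the outputs by a factor of `W_std / sqrt(fan_in)` in the weight term; converting them first makes the two layers agree.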
-
### Description of the model to be implemented
In many areas such as physics, it is convenient to have convolutional layers with periodic boundary conditions (e.g. see [netket](https://github.com/n…
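Periodic boundary conditions amount to circular padding: wrap the input around before a "valid" convolution, so the output keeps the input size and the kernel sees neighbors across the boundary. A 1D numpy sketch (our illustration, not an existing Stax layer; `periodic_conv1d` is a hypothetical name):

```python
import numpy as np

def periodic_conv1d(x, k):
    """Cross-correlation with periodic boundaries:
    out[i] = sum_j k[j] * x[(i + j - p) % n], with p = len(k) // 2 centring the kernel."""
    n, m = len(x), len(k)
    p = m // 2
    # Wrap-pad so that every index i + j - p (j in [0, m)) maps back into x
    xp = np.pad(x, (p, m - 1 - p), mode="wrap")
    return np.array([xp[i:i + m] @ k for i in range(n)])

x = np.arange(6, dtype=float)
k = np.array([1.0, -2.0, 1.0])  # discrete Laplacian stencil
print(periodic_conv1d(x, k))
```

A useful property of this construction is exact equivariance to cyclic shifts: rolling the input by one site rolls the output by one site, including at the edges.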
-
Hi,
I've been using the neural-tangents library a lot over the past few months; it's been extremely helpful.
I just had a question about calculating the marginal log-likelihood for NNGPs, wh…
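For reference, the standard Gaussian-process marginal log-likelihood with kernel matrix `K` and observation noise variance `s2` is `log p(y | X) = -0.5 * y^T (K + s2 I)^{-1} y - 0.5 * log|K + s2 I| - (n/2) * log(2 pi)`. The numpy sketch below evaluates it with a Cholesky factorization. The RBF Gram matrix is only a stand-in; for an NNGP one would substitute the NNGP kernel evaluated on the training inputs. The noise level and helper name are assumptions.

```python
import numpy as np

def gp_marginal_log_likelihood(K, y, noise_var=1e-2):
    """log N(y | 0, K + noise_var * I) via a Cholesky factorization."""
    n = y.shape[0]
    L = np.linalg.cholesky(K + noise_var * np.eye(n))
    # alpha = (K + noise_var * I)^{-1} y, using two triangular solves
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    data_fit = -0.5 * y @ alpha
    # 0.5 * log|K + noise_var * I| = sum(log(diag(L)))
    complexity = -np.sum(np.log(np.diag(L)))
    return data_fit + complexity - 0.5 * n * np.log(2 * np.pi)

rng = np.random.default_rng(0)
x = rng.standard_normal((20, 3))
sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
K = np.exp(-0.5 * sq_dists)  # stand-in for an NNGP kernel matrix
y = rng.standard_normal(20)
print(gp_marginal_log_likelihood(K, y))
```

The Cholesky route avoids forming an explicit inverse and gives the log-determinant for free from the diagonal of `L`.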
-
Dear All,
I believe the Colab cookbook needs updating: when I ran the code in Colab, I got the message below.
Thank you,
Frank
```
--------------------------------------------…
```
-
After updating my environment to a more recent version of JAX and Flax, I have noticed that the empirical NTK Gram matrices computed using `nt.batch` applied to `nt.empirical_kernel_fn` are …
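Conceptually, batching a kernel function tiles the Gram matrix into blocks of rows and columns and computes each block separately, so the assembled result should match the unbatched one up to floating-point error. The numpy sketch below checks that invariant for a toy random-feature kernel; the helper names are hypothetical and this is not the `nt.batch` implementation.

```python
import numpy as np

def toy_kernel(x1, x2):
    # Stand-in kernel: inner products of a fixed random tanh feature map
    W = np.random.default_rng(42).standard_normal((x1.shape[1], 16))
    return np.tanh(x1 @ W) @ np.tanh(x2 @ W).T

def blocked_kernel(kernel_fn, x1, x2, batch_size):
    # Compute the Gram matrix block by block and assemble the blocks
    rows = []
    for i in range(0, len(x1), batch_size):
        row = [kernel_fn(x1[i:i + batch_size], x2[j:j + batch_size])
               for j in range(0, len(x2), batch_size)]
        rows.append(np.concatenate(row, axis=1))
    return np.concatenate(rows, axis=0)

x = np.random.default_rng(0).standard_normal((12, 5))
full = toy_kernel(x, x)
blocked = blocked_kernel(toy_kernel, x, x, batch_size=4)
print(np.max(np.abs(full - blocked)))
```

The same comparison, run with the real batched and unbatched kernel functions on identical inputs, isolates whether a discrepancy comes from batching or from the model itself.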
-
Hi, thanks a lot for this nice library!
I noticed that enabling the option `do_stabilize` in `ABRelu` leads to the error `TypeError: max requires ndarray or scalar arguments, got at position 0.` a…
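For reference, `ABRelu(a, b)` applies `a * min(x, 0) + b * max(x, 0)` elementwise, so `a=0, b=1` recovers ReLU and `a=b=1` the identity. A numpy sketch of the activation only (not of its kernel or of the `do_stabilize` code path; `ab_relu` is a hypothetical name):

```python
import numpy as np

def ab_relu(x, a, b):
    # Slope a on the negative half-line, slope b on the positive half-line
    return a * np.minimum(x, 0) + b * np.maximum(x, 0)

x = np.linspace(-2.0, 2.0, 9)
print(ab_relu(x, 0.0, 1.0))  # ReLU
print(ab_relu(x, -1.0, 1.0))  # absolute value
```

Note the elementwise `np.minimum`/`np.maximum` here, as opposed to the reductions `np.min`/`np.max`.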
-
After updating to `jax` `v0.2.21`, importing `neural-tangents` gives the following error due to the removal of `jax.api`:
```
[redacted]/python3.9/site-packages/neural_tangents/utils/batch.py in
…
```
-
If I create constant arrays like `x = np.ones((10, 10))`, they are `_FilledConstants`, which do not expose the `device_buffer` attribute, so I cannot detect which device they reside on.