-
From the demo and the wiki, the input data to the Aggregate layer is 3-channel images. But if the input data has shape `(Bs, N, d)` and the adjacency matrix has shape `(Bs, N, N)`, I would like to know how to use this Aggr…
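Below is a minimal plain-JAX sketch of what dense aggregation over a batch of graphs computes. It assumes `A[b, i, j]` is the weight of the edge from node `i` to node `j` in graph `b`, and that each node sums the features of its in-neighbours. The helper `aggregate_dense` is hypothetical. As far as we can tell from its docstring, `stax.Aggregate` in neural-tangents performs a comparable dense contraction but receives the adjacency tensor through a separate `pattern` keyword argument rather than as part of `inputs`. Check the docstring for the exact convention.

```python
import jax
import jax.numpy as jnp


def aggregate_dense(x, adjacency):
    """Hypothetical helper: sum node features over incoming edges.

    x:         (Bs, N, d) node features, one graph per batch element.
    adjacency: (Bs, N, N) weighted adjacency; adjacency[b, i, j] is the
               weight of the edge i -> j in graph b.
    Returns    (Bs, N, d) with out[b, j] = sum_i adjacency[b, i, j] * x[b, i].
    """
    return jnp.einsum("bij,bic->bjc", adjacency, x)


key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
Bs, N, d = 2, 5, 3
x = jax.random.normal(key_x, (Bs, N, d))
# Random symmetric 0/1 adjacency, one matrix per graph in the batch.
a = (jax.random.uniform(key_a, (Bs, N, N)) > 0.5).astype(jnp.float32)
a = jnp.maximum(a, jnp.swapaxes(a, 1, 2))

out = aggregate_dense(x, a)
print(out.shape)  # (2, 5, 3)
```

For an undirected graph `A` is symmetric, so the edge orientation does not matter. Adding `jnp.eye(N)` to each `A[b]` also includes each node's own features.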
-
Hey!
I was reading through the code and noticed that you're using an element-wise exponential of the matrix here:
https://github.com/google/neural-tangents/blob/5f286b7696364217aa4a2d92378aabd0203a791e/n…
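The truncated link doesn't show whether `exp` is applied to the matrix entries or to its eigenvalues, and the two cases behave very differently. The sketch below uses a toy symmetric matrix `A`, not the library's code. It contrasts `jnp.exp`, which is element-wise, with `jax.scipy.linalg.expm`. It also shows that exponentiating the eigenvalues of a symmetric matrix reproduces the matrix exponential.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# A small symmetric matrix, e.g. a (scaled) kernel matrix.
A = jnp.array([[1.0, 0.5],
               [0.5, 2.0]])

elementwise = jnp.exp(A)   # exp applied to each entry independently
matrix_exp = expm(A)       # the matrix exponential sum_k A^k / k!

# For symmetric A = V diag(w) V^T, the matrix exponential equals
# V diag(exp(w)) V^T, i.e. an element-wise exp of the *eigenvalues*.
w, V = jnp.linalg.eigh(A)
via_eigh = (V * jnp.exp(w)) @ V.T

print(jnp.allclose(matrix_exp, via_eigh, atol=1e-3))     # True
print(jnp.allclose(matrix_exp, elementwise, atol=1e-3))  # False
```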
-
Hi,
I'm not sure this is the right place to ask this kind of question, but I see the issues are full of questions, so I hope this is OK. Let me know if I should post this on SO or something (but I …
-
Hey, I would like to calculate the mentioned jacobians. Right now I'm trying this:
```python
func, params, buffers = make_functional_with_buffers(model)
J = jacrev(lambda p: func(p, buffers, inpu…
```
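For comparison, here is a JAX sketch of the same pattern. `jax.jacrev(..., argnums=0)` differentiates only the parameter pytree while the buffers and inputs stay fixed, much like the `jacrev(lambda p: func(p, buffers, …))` call above. The toy model, its `shift` buffer, and all shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp


def model(params, buffers, x):
    """Toy two-layer MLP; `buffers` holds non-trainable state (here a fixed
    input shift), mirroring functorch's (func, params, buffers) split."""
    h = jnp.tanh((x - buffers["shift"]) @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (4, 8)),
    "b1": jnp.zeros(8),
    "w2": jax.random.normal(k2, (8, 3)),
    "b2": jnp.zeros(3),
}
buffers = {"shift": jnp.zeros(4)}
inputs = jax.random.normal(k3, (5, 4))  # batch of 5 examples

# Differentiate with respect to the first argument (params) only; buffers
# and inputs are treated as constants.
J = jax.jacrev(model, argnums=0)(params, buffers, inputs)

# J mirrors the params pytree; each leaf has shape (batch, outputs, *param.shape).
for name, leaf in J.items():
    print(name, leaf.shape)
```

Flattening the trailing axes of each leaf and concatenating over leaves gives the `(batch * outputs, num_params)` matrix usually wanted for NTK computations.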
-
I want to mask a fully connected layer in an NN with a specified mask vector. I define the NN like this, but I get a dimension error: "ValueError: Batch or contracting dimension 1 cannot be equal to `c…
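One way to do this assumes the mask vector has one entry per output unit of the layer. Multiply the weight matrix by the mask before the matrix product, so broadcasting zeroes whole columns. The helper `masked_dense` and all shapes are hypothetical. Because the error message is truncated, this does not diagnose the reported `ValueError`. A mask or weight built with the wrong length is one possible cause of contraction-dimension errors.

```python
import jax
import jax.numpy as jnp


def masked_dense(params, x, mask):
    """Dense layer whose output units are switched off by `mask`.

    params["W"]: (in_dim, out_dim), params["b"]: (out_dim,)
    mask:        (out_dim,) of 0/1 values, one entry per output unit.
    """
    # Broadcasting (in_dim, out_dim) * (out_dim,) zeroes whole columns of W,
    # i.e. every weight feeding a masked output unit.
    W = params["W"] * mask
    return x @ W + params["b"] * mask


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
in_dim, out_dim = 4, 6
params = {"W": jax.random.normal(key_w, (in_dim, out_dim)),
          "b": jnp.ones(out_dim)}
mask = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
x = jax.random.normal(key_x, (3, in_dim))  # batch of 3

y = masked_dense(params, x, mask)
print(y.shape)  # (3, 6); columns 1 and 4 are exactly zero
```

To mask input features instead, use a mask of shape `(in_dim,)` and `params["W"] * mask[:, None]`.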
-
Hello,
Here is a snippet:
```python
import jax
import jax.numpy as jnp
from jax import jit
from jax import grad
from jax.example_libraries import optimizers
from jax.config import con…
```
-
I am trying to improve on this paper, https://arxiv.org/pdf/2011.00050.pdf, which optimizes some subset using the NTK. They optimize their loss in batches. Smaller batches for more complex architectures…
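Here is a minimal sketch of a mini-batched objective of this kind. It assumes the setup is kernel ridge regression fitted on a learnable support set and evaluated on batches of the full data, with gradients taken with respect to the support inputs. An RBF kernel stands in for the NTK; a neural-tangents `kernel_fn` would replace `rbf_kernel`. `krr_loss`, the regularizer `reg`, and the toy data are hypothetical. This is not the paper's code.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(a, b, lengthscale=1.0):
    # Stand-in kernel; an NTK would be plugged in here instead.
    sq = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq / (2.0 * lengthscale**2))


def krr_loss(x_support, y_support, x_target, y_target, reg=1e-2):
    """MSE of kernel ridge regression fit on the support set, evaluated on a
    batch of target points."""
    k_ss = rbf_kernel(x_support, x_support)
    k_ts = rbf_kernel(x_target, x_support)
    m = x_support.shape[0]
    alpha = jnp.linalg.solve(k_ss + reg * jnp.eye(m), y_support)
    pred = k_ts @ alpha
    return 0.5 * jnp.mean((pred - y_target) ** 2)


# Gradient with respect to the support inputs only (labels kept fixed here).
grad_fn = jax.jit(jax.grad(krr_loss, argnums=0))

k_s, k_t = jax.random.split(jax.random.PRNGKey(0))
x_support = jax.random.normal(k_s, (10, 2))
y_support = jnp.sign(x_support[:, :1])
x_all = jax.random.normal(k_t, (200, 2))
y_all = jnp.sign(x_all[:, :1])

lr, batch_size = 0.05, 50
loss_before = krr_loss(x_support, y_support, x_all, y_all)
for step in range(100):
    start = (step * batch_size) % x_all.shape[0]
    xb = x_all[start:start + batch_size]
    yb = y_all[start:start + batch_size]
    x_support = x_support - lr * grad_fn(x_support, y_support, xb, yb)
loss_after = krr_loss(x_support, y_support, x_all, y_all)
print(loss_before, loss_after)
```

Batch size mainly trades gradient noise against the cost of `k_ts`, which grows linearly with the target batch. The `solve` on the support kernel does not depend on the batch size.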
-
For a classification model with `k` classes and `n` data points, the NTK should have size `nk × nk`. How would we get that with neural-tangents?
Currently, I'm able to get a `…
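Here is a sketch of the full kernel in plain JAX. Take the Jacobian of the `(n, k)` logits with respect to all parameters, flatten it to an `(n*k, P)` matrix `J`, and form `J @ J.T`. The toy two-layer `f` and the helper `full_empirical_ntk` are hypothetical. In neural-tangents itself, we believe `empirical_ntk_fn` traces over the logit axis by default (`trace_axes=(-1,)`), which yields the `n × n` kernel, and that passing `trace_axes=()` keeps the class axes. Check its docstring for the exact output axis order.

```python
import jax
import jax.numpy as jnp


def f(params, x):
    # Toy classifier: n inputs -> k logits.
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def full_empirical_ntk(f, params, x):
    """Return the (n*k, n*k) empirical NTK J J^T, where J is the Jacobian of
    the flattened (n*k,) logits with respect to all parameters."""
    jac = jax.jacrev(f)(params, x)  # leaves: (n, k, *param.shape)
    n, k = f(params, x).shape
    # Flatten every leaf to (n*k, P_leaf) and stack parameters side by side.
    J = jnp.concatenate(
        [leaf.reshape(n * k, -1) for leaf in jax.tree_util.tree_leaves(jac)],
        axis=1,
    )
    return J @ J.T  # row/column index = example * k + class


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n, d, h, k = 5, 3, 16, 4
params = {"w1": jax.random.normal(k1, (d, h)) / jnp.sqrt(d),
          "w2": jax.random.normal(k2, (h, k)) / jnp.sqrt(h)}
x = jax.random.normal(k3, (n, d))

ntk = full_empirical_ntk(f, params, x)
print(ntk.shape)  # (20, 20)

# Reshape to (n, k, n, k) to index pairs of (example, class).
ntk4 = ntk.reshape(n, k, n, k)
```

`ntk4[i, a, j, b]` is the kernel between logit `a` of example `i` and logit `b` of example `j`. `jnp.trace(ntk4, axis1=1, axis2=3)` recovers the class-traced `n × n` kernel.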
-
We want some more test cases. Here are some cool things that work with JAX and would be interesting to port over to functorch:
- [ ] [Neural tangents](https://github.com/google/neural-tange…
-
Hi all,
I would like to (analytically) compute the evolution of the weights under the linearized dynamics (i.e., Eqn. (8) in https://arxiv.org/pdf/1902.06720.pdf) and use the resulting weights afte…
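Here is a minimal sketch under three assumptions: squared loss `L = 0.5 * ||f(X) - Y||^2` (no `1/n` factor), gradient flow with learning rate `lr`, and a scalar-output toy network. Under those assumptions the linearized weights are `theta_t = theta_0 + omega_t` with `omega_t = -J0^T Theta0^{-1} (I - exp(-lr * Theta0 * t)) (f0(X) - Y)` and `Theta0 = J0 J0^T`. We believe this is the closed form Eqn. (8) describes; check the loss normalization against the paper. The helper `linearized_weights` is hypothetical. It applies the matrix function through an eigendecomposition of `Theta0` instead of an explicit inverse.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params, x):
    # Toy scalar-output network.
    return (jnp.tanh(x @ params["w1"]) @ params["w2"]).squeeze(-1)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n, d, h = 6, 3, 32
params0 = {"w1": jax.random.normal(k1, (d, h)) / jnp.sqrt(d),
           "w2": jax.random.normal(k2, (h, 1)) / jnp.sqrt(h)}
x_train = jax.random.normal(k3, (n, d))
y_train = jax.random.normal(k4, (n,))

flat0, unravel = ravel_pytree(params0)  # all weights as one (P,) vector


def f_flat(p, x):
    return f(unravel(p), x)


f0 = f_flat(flat0, x_train)              # (n,)
J0 = jax.jacrev(f_flat)(flat0, x_train)  # (n, P)
theta0 = J0 @ J0.T                       # (n, n) empirical NTK
evals, evecs = jnp.linalg.eigh(theta0)


def linearized_weights(t, lr):
    """theta_t = theta_0 + omega_t, with omega_t from the closed form above."""
    # Theta0^{-1} (I - exp(-lr * Theta0 * t)) applied in the eigenbasis.
    coef = (1.0 - jnp.exp(-lr * evals * t)) / evals
    residual = evecs @ (coef * (evecs.T @ (f0 - y_train)))
    return flat0 - J0.T @ residual


lr, t = 1.0, 5.0
flat_t = linearized_weights(t, lr)
params_t = unravel(flat_t)  # same pytree structure as params0

# Sanity check: the linearized model f0 + J0 (theta_t - theta_0) should match
# the function-space solution y + exp(-lr * Theta0 * t) (f0 - y).
f_lin_t = f0 + J0 @ (flat_t - flat0)
expected = y_train + evecs @ (jnp.exp(-lr * evals * t) * (evecs.T @ (f0 - y_train)))
print(jnp.max(jnp.abs(f_lin_t - expected)))
```

`params_t` has the same structure as `params0`, so it can be passed back into `f`. Evaluating the original (non-linearized) network at these weights only approximately matches the linearized model. The paper's result is that this gap shrinks as the width grows.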