google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.28k stars 226 forks

Does neural-tangents work for custom layer? #165

Open Shuhul24 opened 2 years ago

Shuhul24 commented 2 years ago

I have built a custom layer (KerasLayer) as a Python class (say class NewLayer). Can I use something like stax.NewLayer to apply neural-tangents to this custom layer?

romanngg commented 2 years ago

I'm afraid not; you would need to write your own stax layer, defining init_fn, apply_fn, and kernel_fn, e.g. as in https://github.com/google/neural-tangents/blob/9f21e6e4f21a279ebbb033ff924e1ebc4723e077/neural_tangents/_src/stax/linear.py#L749

To what extent you'll be able to reuse your existing code will depend on the specifics.
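For reference, the rough shape of such a layer is a function returning the `(init_fn, apply_fn, kernel_fn)` triple. Below is a minimal sketch with a made-up parameterless `ConstScale` layer (x -> c * x); it assumes the `kernel_fn` receives the library's `Kernel` dataclass, whose `replace` method returns an updated copy. It is an illustration, not library code:

```python
import jax.numpy as jnp

def ConstScale(c=2.0):
    """Hypothetical stax-style layer computing x -> c * x (a sketch, not part of the library)."""

    def init_fn(rng, input_shape):
        # No trainable parameters; the input shape is preserved.
        return input_shape, ()

    def apply_fn(params, inputs, **kwargs):
        return c * inputs

    def kernel_fn(k, **kwargs):
        # `k` is assumed to be a neural_tangents Kernel dataclass; under
        # x -> c * x, all second moments (nngp, ntk, cov1, cov2) scale by c**2.
        return k.replace(
            nngp=c**2 * k.nngp,
            ntk=None if k.ntk is None else c**2 * k.ntk,
            cov1=c**2 * k.cov1,
            cov2=None if k.cov2 is None else c**2 * k.cov2,
        )

    return init_fn, apply_fn, kernel_fn
```

In the actual library, layers are additionally wrapped by internal decorators that track input requirements and masking, so the complete examples in the linked linear.py are the authoritative reference for composing with stax.serial.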

We have tools that make some layers easier to implement than writing them from scratch, such as pointwise nonlinearities https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Elementwise.html. If your layer is an affine non-parametric transformation (similar to https://neural-tangents.readthedocs.io/en/latest/stax.html#linear-nonparametric), it would also be easy to translate it into a stax layer automatically (something I just haven't gotten to yet). In general, if you could tell us what your layer does, we may be able to help implement it.
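As an illustration of the Elementwise route: a pointwise sin nonlinearity has the closed-form NNGP E[sin(u)sin(v)] = sinh(cov12) * exp(-(var1 + var2) / 2) for (u, v) jointly Gaussian. The sketch below defines that scalar function; the commented-out line shows how it would plug into stax.Elementwise (assuming the nngp_fn argument order (cov12, var1, var2) -- check the linked docs):

```python
import jax.numpy as jnp

def sin_nngp(cov12, var1, var2):
    # E[sin(u) sin(v)] for (u, v) ~ N(0, [[var1, cov12], [cov12, var2]]):
    # 0.5 * (E[cos(u - v)] - E[cos(u + v)]) = sinh(cov12) * exp(-(var1 + var2) / 2).
    return jnp.sinh(cov12) * jnp.exp(-(var1 + var2) / 2)

# With neural-tangents installed, the layer would be built roughly as
# (assumption: argument order matches the Elementwise docs; derivatives
# needed for the NTK are then taken automatically):
# from neural_tangents import stax
# layer = stax.Elementwise(fn=jnp.sin, nngp_fn=sin_nngp)
```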

Finally, note that empirical kernels (https://neural-tangents.readthedocs.io/en/latest/empirical.html) work with any JAX functions, and don't require them to be written in stax.
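To make that concrete: the library's empirical kernel functions accept any JAX-differentiable `f(params, x)`. The sketch below (made-up function and sizes) computes an empirical NTK by hand with `jax.jacobian`, contracting the Jacobians over parameters and summing over the output dimension, to show that no stax structure is involved; the library's empirical API wraps this kind of computation more efficiently:

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Any JAX-differentiable function works; no stax structure is needed.
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2  # shape (batch, out)

def empirical_ntk(f, params, x1, x2):
    # NTK[i, j] = sum over parameters of <df(x1_i)/dtheta, df(x2_j)/dtheta>,
    # here summed over the output dimension (one common convention).
    j1 = jax.jacobian(f)(params, x1)  # pytree of (batch1, out, *param_shape)
    j2 = jax.jacobian(f)(params, x2)  # pytree of (batch2, out, *param_shape)

    def contract(a, b):
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum('iap,jap->ij', a, b)

    leaves1 = jax.tree_util.tree_leaves(j1)
    leaves2 = jax.tree_util.tree_leaves(j2)
    return sum(contract(a, b) for a, b in zip(leaves1, leaves2))

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = (jax.random.normal(k1, (3, 8)), jax.random.normal(k2, (8, 2)))
x1 = jax.random.normal(k3, (4, 3))
x2 = jax.random.normal(k4, (5, 3))

ntk = empirical_ntk(f, params, x1, x2)  # shape (4, 5)
```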