Shuhul24 opened this issue 2 years ago
I'm afraid not; you would need to write your own `stax` layer, defining `init_fn`, `apply_fn`, and `kernel_fn`, e.g. as in https://github.com/google/neural-tangents/blob/9f21e6e4f21a279ebbb033ff924e1ebc4723e077/neural_tangents/_src/stax/linear.py#L749
To what extent you'll be able to reuse your existing code will depend on the specifics.
We have tools that make some layers easier to implement than writing them from scratch, such as pointwise nonlinearities (https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Elementwise.html). If your layer is an affine non-parametric transformation (similar to https://neural-tangents.readthedocs.io/en/latest/stax.html#linear-nonparametric), it would also be easy to translate it into a `stax` layer automatically (I just haven't gotten around to implementing that yet). In general, if you can tell us what your layer does, we may be able to help implement it.
Finally, note that empirical kernels (https://neural-tangents.readthedocs.io/en/latest/empirical.html) work with any JAX function and don't require it to be written in `stax`.
I have built a custom layer (`KerasLayer`) using a `class` in Python (say `class NewLayer`). Can I use something like `stax.NewLayer` to apply `neural-tangents` to this custom layer?