chenwydj closed this issue 3 years ago.
Hey @chenwydj, thanks for checking out NT!
There are two different ways of answering your question. First, it is already possible to combine Flax and NT to compute the empirical neural tangent kernel, or to estimate the neural tangent kernel via Monte Carlo sampling. Basically, the functions `nt.empirical_kernel_fn` and `nt.monte_carlo_kernel_fn` take NT-style model functions: `nt.empirical_kernel_fn` takes an `apply_fn`, and `nt.monte_carlo_kernel_fn` takes an `(init_fn, apply_fn)` pair. These are more-or-less the equivalent of `model.init` and `model.apply` in Flax. There are a few small issues of convention that must be hacked around. First, Flax's initialization function expects example inputs, whereas we call initialization functions with shape information only. Second, we pass an rng key to the model as an `rng` keyword argument, so the `__call__` method of the Flax model must either accept an `rng` key explicitly or otherwise accept unused keyword arguments.
Here is a colab notebook that I hacked together to get this working, using a slight variation on the ResNet you posted in your question.
On the other hand, you might have meant "can you compute the analytic kernel function from a Flax network?" This is a significantly more difficult problem. We are working on something that can partially compute these kernels, but I don't have a firm timeframe for when it might be ready for testing. As we make progress here, I'll try to remember to bump this issue.
Just in case, a brief comment on `stax.serial`: note that we also have `stax.parallel`, so in principle you can implement arbitrary topologies, e.g. the ResNet in the example at https://github.com/google/neural-tangents#infinitely-wideresnet. (This is likely not what you were asking for, though; otherwise, +1 to Sam: supporting Flax would be a great, but also very ambitious, feature.)
Thank you very much @sschoenholz and @romanngg for your great answers! I think I will stick to `stax.serial` for now! :)
Currently, most Neural Tangents examples get `kernel_fn` from `stax.serial`. Is there a more advanced way to get `kernel_fn` from complex models? For example, can we get `kernel_fn` from `flax`'s `nn.Module`, like this one (from here)? Thank you!