google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question: difference between kernel_fn from stax.serial and nt.empirical_ntk_fn(apply_fn)? #137

Open jecampagne opened 2 years ago

jecampagne commented 2 years ago

Hello, I'm a newbie with your library, which looks really nice, and I would like to use it to prepare some exercises illustrating a private lecture for some colleagues. So I wouldn't like to make rough mistakes (note that I have also posted another question in the same spirit). Let me know if there is a forum dedicated to this kind of user exchange somewhere.

So, after

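    import neural_tangents as nt
    from neural_tangents import stax

    N = 512  # hidden width; an arbitrary example value

    # Build the network
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(N, W_std=1., parameterization='standard'),
        stax.Relu(),
        stax.Dense(1, W_std=1., parameterization='standard')
    )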

one can do

    emp_ntk_kernel_fn = nt.empirical_ntk_fn(apply_fn)

If I am right, emp_ntk_kernel_fn is the finite-size NTK based on the network architecture, but then what is the difference from kernel_fn, i.e. the third value returned by stax.serial?

Thanks

romanngg commented 2 years ago

Welcome! This is a good channel; we also have https://github.com/google/neural-tangents/discussions, and either place is OK.

emp_ntk_kernel_fn is the finite-width kernel function: it returns the product of the Jacobians of apply_fn with respect to the parameters theta, contracted over theta, i.e. J(x1) J(x2)^T. Since it depends on a specific value of theta, you call it like emp_ntk_kernel_fn(x1, x2, theta). In short, this kernel describes the behavior of the linearization of apply_fn around theta. This is also \hat{\Theta} from https://arxiv.org/pdf/1902.06720.pdf.
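To make the finite-width picture concrete, here is a minimal NumPy sketch that does not call neural-tangents. It assumes a bias-free version of the network from the question in NTK parameterization, f(x) = W2 · relu(W1 x / sqrt(d)) / sqrt(n), and hand-codes the Jacobian; `init_params`, `jacobian`, `empirical_ntk` and `change_after_gd_step` are hypothetical helpers, while `apply_fn(params, x)` only mirrors the argument order of the stax `apply_fn` for a single input. It illustrates the three points above: the kernel is J(x1)·J(x2), it changes when theta changes, and it predicts the first-order effect of a gradient step.

```python
import numpy as np

# Hypothetical stand-in for stax.serial(Dense(n), Relu(), Dense(1)) in NTK
# parameterization, without biases: f(x) = W2 . relu(W1 x / sqrt(d)) / sqrt(n).
def init_params(d, n, seed):
    """Draw theta with i.i.d. N(0, 1) entries."""
    rng = np.random.default_rng(seed)
    return {"W1": rng.standard_normal((n, d)), "W2": rng.standard_normal(n)}

def apply_fn(params, x):
    """Scalar network output for a single input vector x of shape (d,)."""
    n, d = params["W1"].shape
    h = params["W1"] @ x / np.sqrt(d)
    return params["W2"] @ np.maximum(h, 0.0) / np.sqrt(n)

def jacobian(params, x):
    """Gradient of f(x) wrt every parameter, flattened as [W1.ravel(), W2]."""
    n, d = params["W1"].shape
    h = params["W1"] @ x / np.sqrt(d)
    d_w2 = np.maximum(h, 0.0) / np.sqrt(n)
    # df/dW1[i, j] = W2[i] * relu'(h[i]) * x[j] / (sqrt(n) * sqrt(d))
    d_w1 = np.outer(params["W2"] * (h > 0) / np.sqrt(n), x / np.sqrt(d))
    return np.concatenate([d_w1.ravel(), d_w2])

def empirical_ntk(x1, x2, params):
    """Finite-width NTK: J(x1) and J(x2) contracted over all parameters."""
    return jacobian(params, x1) @ jacobian(params, x2)

def change_after_gd_step(params, x1, x2, y2, lr):
    """Change in f(x1) after one gradient step on the loss 0.5 * (f(x2) - y2)**2."""
    n, d = params["W1"].shape
    theta = np.concatenate([params["W1"].ravel(), params["W2"]])
    theta = theta - lr * (apply_fn(params, x2) - y2) * jacobian(params, x2)
    new_params = {"W1": theta[: n * d].reshape(n, d), "W2": theta[n * d:]}
    return apply_fn(new_params, x1) - apply_fn(params, x1)

params = init_params(d=3, n=512, seed=0)
x1, x2 = np.array([1.0, 0.5, 2.0]), np.array([0.8, 1.0, 1.5])
print(empirical_ntk(x1, x2, params))
print(empirical_ntk(x1, x2, init_params(d=3, n=512, seed=1)))  # new theta, new value

# To first order, the step changes f(x1) by -lr * K(x1, x2) * (f(x2) - y2);
# here y2 is chosen so that the residual f(x2) - y2 equals 1.
y2, lr = apply_fn(params, x2) - 1.0, 1e-3
print(change_after_gd_step(params, x1, x2, y2, lr), -lr * empirical_ntk(x1, x2, params))
```

For a small step the two numbers on the last line should agree closely, since the linearization is exact to first order in the step size. With neural-tangents itself you would call emp_ntk_kernel_fn(x1, x2, params) on batches of inputs instead of these hand-written helpers.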

kernel_fn returned by stax is the infinite-width kernel function. Namely, for the same architecture, kernel_fn(x1, x2, 'ntk') = plim_{n->infty} emp_ntk_kernel_fn(x1, x2, theta), where n is the width of the hidden layers and theta ~ N(0, 1), i.e. the weights are i.i.d. Gaussians. (Minor note: this assumes parameterization='ntk'; 'standard' is slightly different, per https://arxiv.org/pdf/2001.07301.pdf.) In other words, the random variable emp_ntk_kernel_fn(x1, x2, theta) (given random normal theta) converges in probability to the deterministic kernel kernel_fn(x1, x2, 'ntk'). In short, this kernel describes the behavior of an infinite ensemble of infinitely wide apply_fn networks. This is also \Theta from https://arxiv.org/pdf/1902.06720.pdf.
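The convergence in probability can be checked numerically for a small network. This is a NumPy sketch under stated assumptions (one hidden ReLU layer, NTK parameterization, W_std=1, no biases); `analytic_ntk` and `empirical_ntk_at_init` are hypothetical helpers, and `analytic_ntk` uses the standard arc-cosine formulas for Gaussian ReLU expectations rather than calling kernel_fn. We would expect it to match kernel_fn(x1, x2, 'ntk') for stax.serial(stax.Dense(n, W_std=1., parameterization='ntk'), stax.Relu(), stax.Dense(1, W_std=1., parameterization='ntk')), but that correspondence is not verified here.

```python
import numpy as np

def analytic_ntk(x1, x2):
    """Infinite-width NTK of f(x) = W2 . relu(W1 x / sqrt(d)) / sqrt(n), no biases."""
    d = x1.size
    # Covariance of the first-layer pre-activations (u, v) = (h(x1), h(x2)) under theta ~ N(0, 1).
    s1, s2, c = x1 @ x1 / d, x2 @ x2 / d, x1 @ x2 / d
    cos_t = np.clip(c / np.sqrt(s1 * s2), -1.0, 1.0)
    t = np.arccos(cos_t)
    relu_cov = np.sqrt(s1 * s2) / (2 * np.pi) * (np.sin(t) + (np.pi - t) * cos_t)  # E[relu(u) relu(v)]
    step_cov = (np.pi - t) / (2 * np.pi)  # E[relu'(u) relu'(v)]
    # Readout-layer gradients contribute relu_cov; hidden-layer gradients contribute c * step_cov.
    return relu_cov + c * step_cov

def empirical_ntk_at_init(x1, x2, n, seed):
    """Finite-width NTK for one random draw of theta ~ N(0, 1), as an average over hidden units."""
    d = x1.size
    rng = np.random.default_rng(seed)
    w1, w2 = rng.standard_normal((n, d)), rng.standard_normal(n)
    h1, h2 = w1 @ x1 / np.sqrt(d), w1 @ x2 / np.sqrt(d)
    readout = np.maximum(h1, 0.0) @ np.maximum(h2, 0.0) / n
    hidden = np.mean(w2**2 * (h1 > 0) * (h2 > 0)) * (x1 @ x2 / d)
    return readout + hidden

x1, x2 = np.array([1.0, 0.5, 2.0]), np.array([0.8, 1.0, 1.5])
print("infinite width:", analytic_ntk(x1, x2))
for n in (10, 1_000, 100_000):
    draws = [empirical_ntk_at_init(x1, x2, n, seed) for seed in range(20)]
    # For this one-hidden-layer net each draw averages n i.i.d. terms whose mean is
    # the analytic value, so the spread across draws shrinks roughly like 1 / sqrt(n).
    print(n, np.mean(draws), np.std(draws))
```

Each printed row uses 20 independent draws of theta; the standard deviation should drop by roughly a factor of 10 for every factor of 100 in width, which is what converging in probability to a constant kernel looks like in practice.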

Hope this helps!

jecampagne commented 2 years ago

Great! Have a look at my last post. Thanks