google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Sparsely Connected Layers #9

jlindsey15 opened this issue 4 years ago (status: Open)

jlindsey15 commented 4 years ago

Hi! Thanks for this awesome resource. I was wondering if the code supports (or could support with simple extensions) computing the NTK and/or linearization for sparsely connected (non-convolutional) layers. If not, is the obstacle practical or theoretical?

romanngg commented 4 years ago

1) AFAIK nt.linearize should work with any function, so if you have a sparse model with signature f(params, x), it should work. However, I don't know whether JAX has good sparse matrix support, or whether there are frameworks with sparse layers implemented.
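A minimal sketch of point 1, assuming a sparse layer is emulated by multiplying a dense weight matrix by a fixed binary mask (JAX's own sparse support is not used). `init_params`, `make_model` and `linearize` are hypothetical names; `linearize` hand-rolls the first-order Taylor expansion in the parameters with `jax.jvp`, which is the idea behind `nt.linearize`, not its actual implementation.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in, d_hidden, density=0.1):
    """Hypothetical helper: dense Gaussian weights plus a fixed binary mask."""
    k_w1, k_w2, k_m = jax.random.split(key, 3)
    params = {
        "w1": jax.random.normal(k_w1, (d_in, d_hidden)),
        "w2": jax.random.normal(k_w2, (d_hidden, 1)),
    }
    # Each connection is kept with probability `density`; the mask is not trained.
    mask = jax.random.bernoulli(k_m, density, (d_in, d_hidden)).astype(jnp.float32)
    return params, mask


def make_model(mask):
    """Returns f(params, x) whose first layer only uses the unmasked weights."""
    def f(params, x):
        w1 = params["w1"] * mask  # masked entries never contribute
        h = jnp.tanh(x @ w1 / jnp.sqrt(x.shape[-1]))
        return h @ params["w2"] / jnp.sqrt(h.shape[-1])
    return f


def linearize(f, params0):
    """First-order Taylor expansion of f in its parameters around params0."""
    def f_lin(params, x):
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return f0 + df
    return f_lin


key = jax.random.PRNGKey(0)
params0, mask = init_params(key, d_in=8, d_hidden=64)
f = make_model(mask)
f_lin = linearize(f, params0)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 8))
print(jnp.allclose(f(params0, x), f_lin(params0, x)))  # True: zero displacement gives f_lin == f
```

Because the mask is closed over rather than stored in `params`, perturbing a masked-out weight changes neither `f` nor `f_lin`. The same `f(params, x)` should also be usable with `nt.linearize` in place of the hand-rolled version.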

2) Re NTK, you can also call nt.empirical_kernel_fn or nt.monte_carlo_kernel_fn to compute the empirical NTK (as the outer product of Jacobians) of any function, but I'm not sure there is a meaningful / interesting infinite limit for sparse fully-connected layers. E.g. if you take the infinite limit such that hidden units of the next layer are connected to an infinite number of units in the previous layer, then I imagine the limit should be the same as the dense-layer NTK modulo some rescaling. I'm not sure if this answers your question; let me know if you have a specific sparse model in mind!
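Point 2 in code: a minimal sketch of the empirical NTK for a Bernoulli-masked layer, computed directly as J(x1) J(x2)^T after flattening the parameters with `jax.flatten_util.ravel_pytree`. `sparse_mlp` and `empirical_ntk` are hypothetical names; `nt.empirical_kernel_fn` computes this quantity with more efficient strategies, so treat the sketch as a reference for what is being computed rather than a replacement.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def sparse_mlp(key, d_in, width, density):
    """Hypothetical one-hidden-layer model with a fixed random weight mask."""
    k1, k2, km = jax.random.split(key, 3)
    params = {"w1": jax.random.normal(k1, (d_in, width)),
              "w2": jax.random.normal(k2, (width, 1))}
    mask = jax.random.bernoulli(km, density, (d_in, width)).astype(jnp.float32)

    def f(params, x):
        h = jax.nn.relu(x @ (params["w1"] * mask) / jnp.sqrt(d_in))
        return h @ params["w2"] / jnp.sqrt(width)

    return params, f


def empirical_ntk(f, params, x1, x2):
    """Theta[i, j] = <df(x1[i])/dparams, df(x2[j])/dparams> for a scalar-output f."""
    flat, unravel = ravel_pytree(params)

    def f_flat(p, x):
        return f(unravel(p), x).reshape(-1)  # shape (batch,)

    j1 = jax.jacobian(f_flat)(flat, x1)  # (batch1, n_params)
    j2 = jax.jacobian(f_flat)(flat, x2)  # (batch2, n_params)
    return j1 @ j2.T                     # (batch1, batch2)


params, f = sparse_mlp(jax.random.PRNGKey(0), d_in=8, width=256, density=0.25)
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
print(empirical_ntk(f, params, x, x).shape)  # (5, 5)
```

Re-running this with different `density` values, or averaging over several keys (the Monte Carlo idea behind `nt.monte_carlo_kernel_fn`), is one way to probe the rescaling guess empirically.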

jlindsey15 commented 4 years ago

Thanks a bunch! I had in mind the case of layers in which the number of inputs per node remains constant as the width goes to infinity -- I'm not sure if/how that changes the limit, or whether it is even defined/computable.
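To make this regime concrete, here is a minimal NumPy sketch of a mask in which every unit keeps exactly `fan_in` inputs regardless of width (`fixed_fan_in_mask` is a hypothetical helper). Unlike a Bernoulli mask with a fixed keep probability, the density here falls like `fan_in / width`, so the number of terms in each pre-activation sum does not grow with width.

```python
import numpy as np


def fixed_fan_in_mask(n_in, n_out, fan_in, rng):
    """Hypothetical helper: binary (n_in, n_out) mask in which every output unit
    is connected to exactly `fan_in` inputs, chosen uniformly at random."""
    mask = np.zeros((n_in, n_out), dtype=np.float32)
    for j in range(n_out):
        rows = rng.choice(n_in, size=fan_in, replace=False)
        mask[rows, j] = 1.0
    return mask


rng = np.random.default_rng(0)
for width in (64, 256, 1024):
    m = fixed_fan_in_mask(width, width, fan_in=4, rng=rng)
    # Inputs per node stay at 4 while the fraction of kept weights is 4 / width.
    print(width, int(m.sum(axis=0).max()), m.mean())
```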

romanngg commented 4 years ago

(Sorry for the late reply, I was on vacation.) Ah, I see! If you have a layer that eventually has an infinite number of inputs per node (e.g. a dense layer at the top, after a sequence of sparse intermediate layers with a finite number of inputs per node; I believe you must have such a layer to have an infinite network with a finite output), then I think a meaningful limit might exist, but computing it analytically may be nontrivial / infeasible (perhaps someone else could chime in on this). For now, in this case I'd recommend trying the empirical kernel and seeing how well it converges and whether it gives interesting performance.
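A minimal sketch of the suggested empirical check, assuming the architecture described above: `depth` sparse ReLU layers with a fixed fan-in, followed by a dense readout. `fan_in_mask`, `make_net`, `empirical_ntk` and `kernel_spread` are hypothetical names, and the spread statistic (relative std of the kernel across random seeds) is just one possible diagnostic: if it shrinks as the width grows, the empirical kernel is concentrating, which is evidence for, not proof of, a well-defined limit.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def fan_in_mask(key, n_in, n_out, fan_in):
    """Each column keeps exactly `fan_in` randomly chosen rows."""
    u = jax.random.uniform(key, (n_in, n_out))
    ranks = jnp.argsort(jnp.argsort(u, axis=0), axis=0)  # per-column ranks 0..n_in-1
    return (ranks < fan_in).astype(jnp.float32)


def make_net(key, d_in, width, fan_in, depth):
    """Hypothetical net: `depth` sparse layers with fixed fan-in, then a dense readout."""
    keys = jax.random.split(key, 2 * depth + 1)
    params, masks, n_prev = [], [], d_in
    for i in range(depth):
        params.append(jax.random.normal(keys[2 * i], (n_prev, width)))
        masks.append(fan_in_mask(keys[2 * i + 1], n_prev, width, fan_in))
        n_prev = width
    params.append(jax.random.normal(keys[-1], (width, 1)))  # dense top layer

    def f(params, x):
        h = x
        for w, m in zip(params[:-1], masks):
            h = jax.nn.relu(h @ (w * m) / jnp.sqrt(fan_in))
        return (h @ params[-1] / jnp.sqrt(width)).reshape(-1)

    return params, f


def empirical_ntk(f, params, x):
    """J(x) J(x)^T with all parameters flattened into one vector."""
    flat, unravel = ravel_pytree(params)
    jac = jax.jacobian(lambda p: f(unravel(p), x))(flat)  # (batch, n_params)
    return jac @ jac.T


def kernel_spread(width, x, fan_in=4, depth=2, n_seeds=4):
    """Relative spread of the empirical NTK across random weights and masks."""
    kernels = []
    for seed in range(n_seeds):
        params, f = make_net(jax.random.PRNGKey(seed), x.shape[1], width, fan_in, depth)
        kernels.append(empirical_ntk(f, params, x))
    kernels = jnp.stack(kernels)  # (n_seeds, batch, batch)
    mean_abs = jnp.mean(jnp.abs(jnp.mean(kernels, axis=0)))
    return float(jnp.mean(jnp.std(kernels, axis=0)) / mean_abs)


x = jax.random.normal(jax.random.PRNGKey(42), (4, 8))
for width in (32, 128, 512):
    print(width, kernel_spread(width, x))
```

The masks are built with a double `argsort` so that each column has exactly `fan_in` ones without a Python loop; note that the first layer needs `fan_in` no larger than the input dimension.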

romanngg commented 4 years ago

Minor updates on this:

1) For a while now, NT has supported input masking (which is propagated through the network) via the mask_constant keyword argument passed to apply_fn / kernel_fn - 621cc988d4c1a905c5c5258b02b18a19d5daa52d. This may allow adding a different sort of sparsity (sparsity in activations vs. sparsity in weights - still not exactly what you're asking for).
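To illustrate the activations-vs-weights distinction in plain NumPy (this does not use NT's `mask_constant` machinery; `weight_sparse_layer` and `activation_sparse_layer` are hypothetical names): a weight mask removes the same connections for every input, whereas an activation mask removes inputs and may differ from one example to the next.

```python
import numpy as np


def weight_sparse_layer(x, w, w_mask):
    """Sparsity in weights: one fixed mask on W, shared by every input."""
    return x @ (w * w_mask)


def activation_sparse_layer(x, w, x_mask):
    """Sparsity in activations: a per-example mask on the inputs."""
    return (x * x_mask) @ w


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))          # 3 examples, 5 features
w = rng.normal(size=(5, 4))
w_mask = rng.random(w.shape) < 0.5   # drops connections
x_mask = rng.random(x.shape) < 0.5   # drops features, differently per example

# For a single example, dropping features equals dropping the matching rows of W,
# but which rows are dropped can change from example to example.
for i in range(x.shape[0]):
    row_mask = np.broadcast_to(x_mask[i][:, None], w.shape)
    assert np.allclose(activation_sparse_layer(x[i:i + 1], w, x_mask[i:i + 1]),
                       weight_sparse_layer(x[i:i + 1], w, row_mask))
```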

2) @liutianlin0121 has done some really interesting research on sparsity in the finite-width empirical NTK: https://arxiv.org/pdf/2006.08228.pdf - this may be relevant to thinking about sparsity in the infinite-width limit.