google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.27k stars 226 forks

Shared weights in Neural Nets #84

Open jiahai-feng opened 3 years ago

jiahai-feng commented 3 years ago

I'm wondering to what extent shared weights are supported (such as in RNNs). I understand that I can hack something together by manually composing apply_fn with shared params. However, is there a way to compute the kernel when there are shared weights?
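For concreteness, the hack I have in mind looks something like this (the additive cell, widths, and time-major layout are placeholders I chose, not anything the library prescribes):

```python
import jax.numpy as jnp
from jax import random
from neural_tangents import stax

# One cell whose params are reused at every step.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu())

key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 512))  # a single set of shared weights

def rnn_apply(params, xs):
  """Unrolls the cell over time-major inputs xs of shape (time, batch, 512)."""
  h = jnp.zeros((xs.shape[1], 512))
  for x_t in xs:
    h = apply_fn(params, h + x_t)  # same params applied at every step
  return h
```

This gives me a finite shared-weight network, but of course the kernel_fn of the serial stack knows nothing about the sharing.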

sschoenholz commented 3 years ago

Great question! Unfortunately, at the moment we don't have a mechanism for weight sharing. Right now, the best you can do is, as you describe, use the kernel function to compute an "untied"-weight approximation of the original architecture. Sometimes, for example in the case of vanilla RNNs, this approximation is very good, whereas for other architectures (such as LSTMs) the GP limit is relatively intractable.
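For instance, a rough sketch of what the untied approximation means (per-step inputs are dropped here and only the depth-T composition with fresh weights per step is kept, so this is a simplification rather than a faithful RNN):

```python
from neural_tangents import stax

T = 5  # number of unrolled steps, chosen arbitrarily for the sketch
layers = []
for _ in range(T):
  layers += [stax.Dense(512), stax.Relu()]  # independent weights per step

init_fn, apply_fn, kernel_fn = stax.serial(*layers)

# kernel_fn gives the exact GP/NTK of this untied network, e.g.:
# k = kernel_fn(x1, x2, 'ntk')
```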

SiuMath commented 3 years ago

Hi, here is an example using the library to compute kernels for a simple RNN: https://colab.sandbox.google.com/gist/SiuMath/41174917df2b359bc0e6c3bf7ff3990d/rnn-combinator-example.ipynb

The RNN_combinator takes an RNN_CELL and outputs the recurrent version of it. This is a preliminary version and only supports simple RNN cells. Note that in the current version the first axis of the inputs is the time dimension rather than the batch dimension.
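Roughly, usage looks like the following; see the notebook for the exact signatures, and note that the cell and widths here are just for illustration:

```python
from neural_tangents import stax

# Only simple cells like this one are supported at the moment.
cell = stax.serial(stax.Dense(512), stax.Erf())

# RNN_combinator (defined in the notebook) wraps the cell into its
# recurrent version and returns the usual triple:
# init_fn, apply_fn, kernel_fn = RNN_combinator(cell)

# Inputs are time-major: xs has shape (time, batch, features).
# k = kernel_fn(xs1, xs2, 'nngp')
```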

jiahai-feng commented 3 years ago

Thanks a lot for getting back so quickly! I'm wondering if you could kindly point me to a reference explaining how the kernel function in the RNN snippet works. My current intuition for RNN NTKs comes from reading Greg Yang's paper on his NETSOR programs, and it seems like I should expect a lot of cross terms between vectors corresponding to different time steps. However, the snippet provided above looks so clean that I feel that I must be missing something.
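To make my confusion concrete, here is the recursion I would write down for a vanilla cell h_t = phi(W h_{t-1} + U x_t), in my own notation (so please correct anything that differs from the notebook):

```latex
\Sigma^{(t)}(x, x') =
  \sigma_u^2 \, \frac{\langle x_t, x'_t \rangle}{d}
  + \sigma_w^2 \, \mathbb{E}_{(u, v) \sim \mathcal{N}\left(0, \Lambda^{(t-1)}\right)}
      \bigl[ \phi(u) \, \phi(v) \bigr],
\qquad
\Lambda^{(t-1)} =
\begin{pmatrix}
  \Sigma^{(t-1)}(x, x)  & \Sigma^{(t-1)}(x, x') \\
  \Sigma^{(t-1)}(x, x') & \Sigma^{(t-1)}(x', x')
\end{pmatrix}
```

If both sequences have the same length and the output is read only at the last step, this recursion closes over same-time covariances, so no t ≠ t' cross terms would be needed, which might be why the snippet can look so clean. Is that the right reading?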

RuihongQiu commented 2 years ago

A quick question. Thanks for the RNN implementation, but it seems like it does not support sequences with different lengths. Is it possible to extend the implementation?
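In case it helps, the workaround I am considering is to zero-pad everything to a common length before calling the kernel; whether the padded steps are then handled correctly depends on how the combinator treats them, so this is untested:

```python
import jax.numpy as jnp

def pad_time_major(xs, max_len):
  """Zero-pads a (time, batch, features) array along the time axis."""
  pad = max_len - xs.shape[0]
  return jnp.pad(xs, ((0, pad), (0, 0), (0, 0)))

# seqs = [pad_time_major(x, max_len) for x in ragged_inputs]
```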