google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question: kernel_fn implementation in Erf() #113

Closed ryuichi0704 closed 3 years ago

ryuichi0704 commented 3 years ago

Hi,

I'd like to confirm the implementation of Erf() in stax. When we consider a 2-layer MLP whose last layer is not trained, the NTK is the covariance matrix of the data multiplied by

[screenshot of the formula: \dot{T}(\Sigma) \propto 1 / \sqrt{(1 + 2 \Sigma_{xx})(1 + 2 \Sigma_{\hat{x}\hat{x}}) - 4 \Sigma_{x\hat{x}}^2}]

ref: Lee et al. 2019, Jiang et al. 2020.

Here, the unit matrix is added to the covariance matrix. However, I cannot find that part in the current stax implementation.
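
To be concrete, here is my understanding of the formula as a minimal NumPy sketch (my own helper, not stax code; assuming the default a=1, b=1, c=0):

import numpy as np

# My reading of the erf derivative kernel (not the stax implementation):
# the 1 + 2*... factors below are where the identity is added to the covariance.
def erf_dot_kernel(cov_xx, cov_hh, cov_xh):
    prod = (1 + 2 * cov_xx) * (1 + 2 * cov_hh)
    return 4 / np.pi / np.sqrt(prod - 4 * cov_xh**2)

# 2-layer MLP with the second layer untrained:
# NTK(x, x_hat) = cov_xh * erf_dot_kernel(cov_xx, cov_hh, cov_xh)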


Using neural-tangents, I tried to visualize the NTK with the default Erf (a=1.0, b=1.0, c=0.0), but the values seem to diverge when the inner product of the inputs is 0 or 1. The input vectors are normalized to unit length. Am I missing something?

Code

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from neural_tangents import stax

def get_kernel(i):
    # Rotate the unit vector u by i degrees and return (NTK(u, Ru), <u, Ru>).
    def rotation_o(u, t):
        R = np.array([[np.cos(t), -np.sin(t)],
                      [np.sin(t),  np.cos(t)]])
        return np.dot(R, u)

    u = (1, 0)
    Ru = rotation_o(u, i * np.pi / 180)

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(2, 100),
        stax.Erf(),
        # stax.Dense(1)  # do not train second layer
    )

    return kernel_fn(jnp.array([u]), jnp.array([Ru]), "ntk")[0][0], np.inner(u, Ru)

kernels_erf = []
inner_products = []
for i in range(360):
    kernel, inner_product = get_kernel(i)
    kernels_erf.append(kernel)
    inner_products.append(inner_product)    
plt.plot(inner_products, kernels_erf) 
plt.title("Erf")
plt.grid(linestyle="dotted")

[plot: NTK value against the inner product of the inputs, titled "Erf"]

romanngg commented 3 years ago

1) Regarding the addition - it happens on this line: https://github.com/google/neural-tangents/blob/fd1611660c87edcb0c2e50403f691b60d2cc252b/neural_tangents/stax.py#L2652. This ensures that prod in https://github.com/google/neural-tangents/blob/fd1611660c87edcb0c2e50403f691b60d2cc252b/neural_tangents/stax.py#L2667 is the first term under the square root in your formula, i.e. (1 + 2 \Sigma_{x x})(1 + 2 \Sigma_{\hat{x} \hat{x}}).

2) Re the divergence - note that stax.Dense(2, 100) is the same as stax.Dense(width=2, W_std=100), so your weight variance is very high, and I believe in this case it makes sense for the NTK to become large with the weight variance when x \approx \hat{x}. I.e. in this case NTK = \Sigma * \Tau, where \Sigma = W_std**2 x @ x.T is quadratic in W_std, but \Tau ~ 1 / W_std (i.e. only inverse-linear in W_std when x \approx \hat{x}), so their product should be proportional to W_std.
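
To sanity-check the scaling, you can sweep W_std at x = \hat{x}; a quick sketch using the same kernel_fn API as your snippet (variable names are mine):

import jax.numpy as jnp
from neural_tangents import stax

x = jnp.array([[1.0, 0.0]])  # x == x_hat, unit norm

for w_std in [1.0, 10.0, 100.0]:
    _, _, kernel_fn = stax.serial(stax.Dense(2, W_std=w_std), stax.Erf())
    # the NTK value should grow roughly linearly with W_std
    print(w_std, kernel_fn(x, x, "ntk")[0][0])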

Lmk if this helps!

ryuichi0704 commented 3 years ago

Thank you for your reply. I understand 👍

Anthony2018 commented 2 years ago

Hi, I'm just wondering whether the erf is still nonlinear when calculating the NTK. How do we convert a nonlinear activation function to something linear? Each time you calculate NTK(x, X) there is a nonlinearity, so how is the final closed form linear?