google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question about Standard Parameterization #132

Open bwnjnOEI opened 2 years ago

bwnjnOEI commented 2 years ago

Hi, NTK folks! I'm new to this area. I read "On the infinite width limit of neural networks with a standard parameterization" and want to find the factor s of the improved standard parameterization in the project's code, but I couldn't locate it. It looks like it should be somewhere between lines 1088 and 1105 of stax.py, yet I don't see the factor s there. Please help!

romanngg commented 2 years ago

IIUC in this parameterization s is taken to infinity to produce the NTK (and is 1 for finite networks), but the parameter that alters the result is the base network width N; in the codebase it's referred to as input_shape[channel_axis] / k.shape1[channel_axis] / inputs.shape[channel_axis] etc. Lmk if this helps!
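
For a concrete picture, here is a minimal sketch (mine, not from the docs; the shapes and widths are arbitrary) of where that base width N shows up when building a kernel with the standard parameterization in nt.stax - it is just the input channel count plus the out_dim of each Dense layer:

```python
import jax.random as random
from neural_tangents import stax

# The base width N here is the input channel count (64) and the Dense out_dims (512, 1).
x = random.normal(random.PRNGKey(0), (4, 64))

_, _, kernel_fn = stax.serial(
    stax.Dense(512, parameterization='standard'),
    stax.Relu(),
    stax.Dense(1, parameterization='standard'),
)

# Under the standard parameterization these widths enter the infinite-width NTK.
k = kernel_fn(x, x, ('nngp', 'ntk'))
print(k.nngp.shape, k.ntk.shape)  # (4, 4) (4, 4)
```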

bwnjnOEI commented 2 years ago

> IIUC in this parameterization s is taken to infinity to produce the NTK (and is 1 for finite networks), but the parameter that alters the result is the base network width N; in the codebase it's referred to as input_shape[channel_axis] / k.shape1[channel_axis] / inputs.shape[channel_axis] etc. Lmk if this helps!

Thanks a lot! Based on what you said, my understanding is as follows: the improved standard parameterization keeps the weight initialization standard (i.e., variance proportional to 1/fan_in) by multiplying the weights by 1/sqrt{width}?

It corresponds to this code in stax.py:

def standard_init_fn(rng, input_shape):
    # Draw unit-variance (NTK-parameterized) weights, then rescale by
    # W_std / sqrt(fan_in) so the initialization matches the standard one.
    output_shape, (W, b) = ntk_init_fn(rng, input_shape)
    return output_shape, (W * W_std / np.sqrt(input_shape[channel_axis]),
                          b * b_std if b is not None else None)
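
To sanity-check that reading, here is a small self-contained sketch (my own, with made-up W_std and fan_in values; I draw the unit-variance weights directly with jax.random rather than calling the library's ntk_init_fn):

```python
import jax.numpy as np
from jax import random

W_std, fan_in, fan_out = 1.5, 128, 512
key = random.PRNGKey(0)

# NTK parameterization: draw unit-variance weights; the 1/sqrt(fan_in)
# factor is applied in the forward pass rather than at initialization.
W_ntk = random.normal(key, (fan_in, fan_out))

# Standard parameterization: fold the scale into the weights at init,
# giving W ~ N(0, W_std**2 / fan_in), as in the snippet above.
W_standard = W_ntk * W_std / np.sqrt(fan_in)

print(np.var(W_standard), W_std**2 / fan_in)  # approximately equal
```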
romanngg commented 2 years ago

Sorry, I didn't notice the question last time - it's correct, although to clarify a bit: here width = fan_in = N, and in the finite-width realm everything is the same for "Standard (naive)" and "Standard (improved)" (but different from NTK) - these are just regular NN weights. What changes is how you define the infinite-width limit and, respectively, what we return in kernel_fn.

If you define it as the limit of networks with diminishing weight variance 1 / width, then the NTK diverges (and the NNGP is unchanged). But if you define it as the limit of networks with layers defined as in column 3 of Table 1/2 (https://arxiv.org/pdf/2001.07301.pdf) with s taken to infinity, you get a well-defined NNGP (the same in all parameterizations) and a well-defined NTK (different). Notably, this NTK depends on the individual width N of each layer, so you can consider taking infinite-width limits in which the relative layer widths differ.
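
To illustrate that last point, a minimal sketch (assuming a recent nt.stax; the widths are arbitrary) comparing how the infinite-width NTK reacts to a change of hidden width under the two parameterizations:

```python
import jax.random as random
from neural_tangents import stax

x = random.normal(random.PRNGKey(1), (3, 64))

def ntk(width, parameterization):
    # Two-layer network whose hidden width is `width`.
    _, _, kernel_fn = stax.serial(
        stax.Dense(width, parameterization=parameterization),
        stax.Relu(),
        stax.Dense(1, parameterization=parameterization),
    )
    return kernel_fn(x, x, 'ntk')

# 'ntk' parameterization: the infinite-width NTK is independent of the hidden width.
print(ntk(512, 'ntk') - ntk(1024, 'ntk'))            # all zeros
# 'standard' parameterization: the NTK depends on the individual layer widths.
print(ntk(512, 'standard') - ntk(1024, 'standard'))  # nonzero
```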

romanngg commented 2 years ago

FYI, in https://github.com/google/neural-tangents/commit/239cc849cf55d672018bce0e3539e56b1a50870f we have exposed the s parameter as well (along with N, which corresponds to out_dim / out_chan).
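
A minimal sketch of how that might look (an assumption on my part - I haven't verified the exact semantics of the new keyword beyond the commit message; see the stax.Dense docstring):

```python
from neural_tangents import stax

# Assumed post-commit API: `s` is a pair scaling the (input, output) widths
# relative to the base width N, and is only meaningful with
# parameterization='standard'; out_dim / out_chan supplies N itself.
_, _, kernel_fn = stax.serial(
    stax.Dense(512, parameterization='standard', s=(1, 2)),
    stax.Relu(),
    stax.Dense(1, parameterization='standard', s=(2, 1)),
)
```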