google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

How to mask an NN layer with nt.stax.DotGeneral? #80

Open Kangfei opened 4 years ago

Kangfei commented 4 years ago

I want to mask a fully connected layer in an NN with a specified mask vector. I define the NN like this, but I get a dimension error: "ValueError: Batch or contracting dimension 1 cannot be equal to channel_axis."

import jax.numpy as np
from neural_tangents import stax

mask = np.ones(shape=(512,))

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1),
    stax.Relu(),
    stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1),
    stax.Relu(),
    stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1),
    stax.Relu(),
    stax.DotGeneral(rhs=mask, dimension_numbers=(((), ()), ((1,), (0,)))),
    stax.Dense(1)
)

I'm very confused about how this 'DotGeneral' operator works and how to set the correct batch/contracting dimensions for the masking. I hope someone who knows can help, and that the documentation could cover more details and examples.

Best, Kangfei
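As background for the question above: as far as I can tell, nt.stax.DotGeneral uses the same dimension_numbers convention as jax.lax.dot_general, so calling the primitive directly is a quick way to build intuition. A minimal sketch, with illustrative shapes and a mask that are not taken from the thread:

```python
import jax.numpy as np
from jax import lax, random

x = random.normal(random.PRNGKey(0), (4, 512))  # (batch, channel)
mask = np.ones((512,))

# dimension_numbers = ((lhs_contracting, rhs_contracting),
#                      (lhs_batch,       rhs_batch))

# Pairing lhs axis 1 with rhs axis 0 as *batch* axes multiplies x and mask
# elementwise along that axis (no summation), but dot_general moves batch
# axes to the front, so the result comes out transposed to (512, 4):
y = lax.dot_general(x, mask, dimension_numbers=(((), ()), ((1,), (0,))))
print(y.shape)  # (512, 4); y[c, n] == x[n, c] * mask[c]

# Pairing the same axes as *contracting* axes instead sums over channels,
# i.e. computes x @ mask, with shape (4,):
z = lax.dot_general(x, mask, dimension_numbers=(((1,), (0,)), ((), ())))
print(z.shape)  # (4,)
```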

romanngg commented 4 years ago

Your code sample looks good to me for finite-width masking (init_fn, apply_fn should work), but in the infinite width (kernel_fn) activations along the channel_axis are considered infinite, so you can't mask an infinite vector with a finite rhs, and we raise an error. Are you assuming some kind of interpretation of rhs in the infinite width, i.e. that it is tiled infinitely many times, or something like that?

(FYI, we also support masking via a mask_constant argument - see https://github.com/google/neural-tangents/blob/master/examples/imdb.py for an example, but this is still masking along a finite, spatial dimension (not the infinite channel dimension), and it's a bit of a different use case, namely one where you mask certain tokens in the input and want them to remain masked as they propagate through the network, so it's only tangentially related)
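For readers following up on the mask_constant mechanism: below is a hedged sketch along the lines of examples/imdb.py, as I understand it. The network, shapes, and the mask_constant value of 100. are illustrative assumptions, not taken from that example; the key point is that entries of the input equal to mask_constant are treated as padding along the spatial axis, which is different from masking the infinite channel axis.

```python
import jax.numpy as np
from neural_tangents import stax

mask_constant = 100.

# 4 sequences of length 16 with one input channel; the last 6 positions
# of each sequence are padding, marked with the mask constant.
x = np.concatenate([np.ones((4, 10, 1)),
                    np.full((4, 6, 1), mask_constant)], axis=1)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(out_chan=256, filter_shape=(3,), padding='SAME'),
    stax.Relu(),
    stax.GlobalAvgPool())

# Padded spatial positions are ignored when computing the infinite-width kernel.
k = kernel_fn(x, x, 'nngp', mask_constant=mask_constant)
print(k.shape)  # (4, 4)
```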

Kangfei commented 4 years ago

Thanks for your reply. I was thinking of trying an infinitely wide MADE (Masked Autoencoder for Distribution Estimation, https://arxiv.org/abs/1502.03509), but now I'm not sure it is feasible given your explanation. Is it that NNGP/NTK does not naturally support masked layers, or that it just hasn't been implemented yet?

romanngg commented 4 years ago

Sorry for the delay - I think this should definitely be possible in principle, but indeed we haven't implemented it yet, and it really depends on your exact interpretation of masking in the infinite width. For example, in the code above, if you interpret rhs as just being tiled infinitely many times, then I think that layer will not do anything interesting, but might just rescale your kernel by a constant. Under another interpretation of masking, like nt.stax.Dropout (which we have implemented), you effectively get a random sparse mask that depends on the inputs, and this results in rescaling the diagonal of the kernel, which is arguably also not particularly interesting behavior.
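To illustrate the "tiled mask just rescales the kernel" point, here is a rough finite-width sketch (my own construction, not from the library): with a fixed 0/1 mask that keeps a fraction p of the hidden units and is independent of the activations, the empirical NNGP covariance of the masked layer comes out approximately p times the unmasked one.

```python
import jax.numpy as np
from jax import random

key_x1, key_x2, key_w = random.split(random.PRNGKey(0), 3)
width = 1 << 16

x1 = random.normal(key_x1, (8,))
x2 = random.normal(key_x2, (8,))

# One wide ReLU layer with He-style weight scaling (illustrative).
W = random.normal(key_w, (8, width)) * np.sqrt(2. / 8)
z1, z2 = np.maximum(x1 @ W, 0.), np.maximum(x2 @ W, 0.)

# Fixed 0/1 mask keeping 3/4 of the hidden units.
mask = (np.arange(width) % 4 != 0).astype(np.float32)
p = mask.mean()

k_unmasked = np.dot(z1, z2) / width
k_masked = np.dot(z1 * mask, z2 * mask) / width
print(k_masked / k_unmasked, p)  # both ~0.75: the mask only rescales the kernel
```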

Now, from glancing at the paper, it looks like those masks are more structured and might indeed give you something interesting - maybe some sort of linear combination of NNGPs, as if they were applied to inputs masked progressively from left to right? (Super vague, I might be totally wrong here.) Sadly this is definitely not supported yet, but I'm happy to discuss more if you have a precise limiting behavior in mind!
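For anyone landing here later, a tiny sketch of the kind of structured MADE-style masks being discussed; the dimensions, widths, and degree assignment below are hypothetical, just to show the autoregressive constraint that output d only sees inputs with index less than d.

```python
import numpy as np

d_in, d_hidden = 4, 8
rng = np.random.default_rng(0)

# Each hidden unit gets a "degree" in {1, ..., d_in - 1}, as in the MADE paper.
degrees = rng.integers(1, d_in, size=d_hidden)

# Input-to-hidden mask: hidden unit j may see input i iff i <= degrees[j].
mask_in = (np.arange(1, d_in + 1)[:, None] <= degrees[None, :])   # (d_in, d_hidden)

# Hidden-to-output mask: output d may see hidden unit j iff degrees[j] < d.
mask_out = (degrees[:, None] < np.arange(1, d_in + 1)[None, :])   # (d_hidden, d_in)

# These 0/1 masks multiply the corresponding weight matrices elementwise -
# the structured analogue of the single mask vector in the question above.
print(mask_in.astype(int))
print(mask_out.astype(int))
```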