google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question: Possible to use nt's stax implementation for a slightly less-linear neural net? #26

Open jguhlin opened 4 years ago

jguhlin commented 4 years ago

Hello, I have a (probably basic) question. I was wondering whether it is possible to use NT's stax implementation to build a more basic neural net. I'm attempting to embed some continuous sequences into n-dimensional space. Inputs x1 and x2 are each run through the same two dense layers, and the final output of the network is the Manhattan distance between the two resulting embeddings. The goal is for the embedded representation to mimic the Manhattan distance between the two continuous sequences.

Sorry if that isn't clear; my model is below:

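    import tensorflow as tf
    from tensorflow.keras.layers import Input, Flatten, Dense, Subtract

    k = 17     # sequence (k-mer) length
    dims = 64  # embedding size (hypothetical value)

    input1 = Input(shape=(k, 5), dtype='float32', name="k1")
    input2 = Input(shape=(k, 5), dtype='float32', name="k2")

    input1_flat = Flatten()(input1)
    input2_flat = Flatten()(input2)

    # Shared weights: the same two Dense layers embed both inputs.
    dense1 = Dense(1024, activation="relu", name="Dense1", use_bias=False)
    dense_out = Dense(dims, activation="linear", name="DenseOut", use_bias=False)

    k1m = dense_out(dense1(input1_flat))
    k2m = dense_out(dense1(input2_flat))

    # Manhattan (L1) distance between the two embeddings.
    subtracted = Subtract()([k1m, k2m])
    abs_diff = tf.math.abs(subtracted)  # avoid shadowing the built-in `abs`
    output = tf.keras.backend.sum(abs_diff, axis=1)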

Because at the chosen sequence length there are 5^17 possible inputs, I was hoping Neural Tangents would be a good fit, but I can't quite figure out how to make the network work with the inputs/outputs from the Colab notebook tutorial.

If it's not possible or not a good idea, I'm definitely open to that; I'm just exploring possibilities. If it is possible, I'd appreciate some pointers, as I haven't used JAX/Stax before and I'm not sure how to integrate the Subtract layer or make the network take two different inputs. I'll keep futzing around with it in the meantime.

Cheers, --Joseph

romanngg commented 4 years ago

Hi Joseph, sorry for the very late reply! I think our library is not quite ready for this use case yet, but it might be close. Specific thoughts:

1) To feed input pairs to an nt.stax network, I would suggest concatenating the inputs along a new dimension, so that your inputs have shape (batch_size, 2, k * 5), and then making dense1 a 1D convolutional layer with 1024 channels and filter shape (1,). This way you get the [k1m, k2m] outputs of shape (batch_size, 2, 1024), computed with weight sharing. [PS I'm not sure whether k in your example stands for the batch size or the length of each sequence, but I think the general idea remains the same regardless.]

2) Next we need to apply the Subtract layer, which we indeed do not have yet. Since it's a linear transformation of the (batch_size, 2, 1024) inputs, it should be easy to implement, so perhaps we will have this feature soon. Hope I'm not missing anything here.

3) Once (2) is implemented, you should be able to train finite-width networks built with our library (init_fn, apply_fn) on the L1 loss without issues. However, the kernels computed with kernel_fn would only correspond to infinite-width networks trained with the L2 (MSE) loss.
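To make points (1) and (2) concrete, here is a minimal NumPy sketch (not Neural Tangents code) of the forward pass they describe, assuming random weights and hypothetical sizes `batch` and `dims`. The two flattened inputs are stacked along a new axis of size 2, one shared weight matrix is applied to each slice (which is what a convolution with filter shape (1,) over that axis amounts to), and the subtraction is a contraction of the stacking axis with the vector [1, -1].

```python
import numpy as np

rng = np.random.default_rng(0)
batch, k, dims = 4, 17, 8  # hypothetical sizes; k is the sequence length

# Stand-ins for two batches of (k, 5) inputs, already flattened to k * 5.
x1 = rng.normal(size=(batch, k * 5))
x2 = rng.normal(size=(batch, k * 5))

# Stack along a new axis: shape (batch, 2, k * 5).
x = np.stack([x1, x2], axis=1)

# Shared weights without biases, as in the Keras model. The 1/sqrt(fan_in)
# scaling only keeps values moderate; it is an arbitrary choice here.
W1 = rng.normal(size=(k * 5, 1024)) / np.sqrt(k * 5)
W2 = rng.normal(size=(1024, dims)) / np.sqrt(1024)

def relu(z):
    return np.maximum(z, 0.0)

# A filter-size-1 convolution over the stacked axis applies the same weights
# to each of the two slices: (batch, 2, k * 5) -> (batch, 2, 1024).
h = relu(np.einsum('bsi,io->bso', x, W1))
emb = np.einsum('bsi,io->bso', h, W2)  # (batch, 2, dims)

# Subtraction as a linear map over the stacked axis: contract with [1, -1].
diff = np.tensordot(emb, np.array([1.0, -1.0]), axes=([1], [0]))  # (batch, dims)

# Manhattan (L1) distance between the two embeddings.
l1 = np.abs(diff).sum(axis=1)  # (batch,)
```

Because the stacked network reuses `W1` and `W2` for both slices, `diff` matches what the original two-branch Keras model computes as `k1m - k2m`.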

Hope this makes sense, let me know if you have any questions and I will update once we have some news on (2).

jguhlin commented 4 years ago

No worries about the reply! k is the length of the input sequence (we're working with DNA here, using k-mers).

I use weight sharing because I train this model, then take the weights and freeze them (make them non-trainable) in later models. So for downstream applications I would need to change the layer back to taking a single input (or possibly even more inputs).

  2. I was expecting to be told I could simply do the subtraction within the layer, so it's good to know it needs some additional coding.

  3. Totally fine. I'd be interested to see the results with the L2 loss and whether it is worth adopting into my pipeline. I've primarily been using MAE loss, but I get similar results with MSE.

jl626 commented 4 years ago

Does neural-tangents support element-wise products?

I have a similar problem, but in my case I want an element-wise product instead of a subtraction, something like this:


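    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense

    n, dims = 32, 64  # hypothetical sizes

    input1 = Input(shape=(n, 3), dtype='float32', name="k1")
    input2 = Input(shape=(n, 3), dtype='float32', name="k2")

    # Shared weights: the same Dense layers are applied to both inputs.
    dense1 = Dense(1024, activation="relu", name="Dense1", use_bias=False)
    dense_out = Dense(dims, activation="linear", name="DenseOut", use_bias=False)

    k1m = dense_out(dense1(input1))  # (batch, n, dims)
    k2m = dense_out(dense1(input2))  # (batch, n, dims)

    k = k1m * k2m  # element-wise product

    output = tf.keras.backend.sum(k, axis=1)  # sum over the n positions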

Basically, I want to take the element-wise product of two NTK/NNGP covariance matrices (each of size num_batch x n x n) and then apply GlobalAvgPool or GlobalSumPool.

romanngg commented 4 years ago

Not yet, and off the top of my head I'm not sure which finite-width operation the elementwise product of covariances corresponds to (do you have something in mind?). If you only want to work in the kernel limit (kernel_fn) and need this operation, though, you may consider:

1) Use a 1D/2D CNN as described in point (1) of my earlier comment to effectively work on the concatenation of input1 and input2, with a single input tensor of shape (batch, 2, n, 3) (or (batch, 2, n * 3)) and [co]variance matrices of shape (batch1[, batch2], 2, 2, n, n) (or (batch1[, batch2], 2, n) if you don't use pooling). [Or (batch1[, batch2], 2, 2) / (batch1[, batch2], 2) if you flatten the (n, 3) dimensions together. I'm not sure exactly how you interpret applying Dense to inputs of shape (batch_size, n, 3) currently, and whether perhaps you mean a 1D convolution there. In any case, the general idea is to add an additional spatial dimension along which the two inputs are concatenated, and to share weights between the parameters applied to them.]

2) Add your own custom layer, something like the one below (a very rough sketch, but hopefully it conveys the idea; happy to elaborate if you have questions!), which computes the product of covariance matrices along the concatenation dimension:

    import jax.numpy as np
    from neural_tangents import stax

    @stax.layer
    def Prod():
      def init_fn(rng, input_shape):
        # Parameter-free layer (rough sketch: output shape is not adjusted).
        return input_shape, ()

      def apply_fn(params, inputs, **kwargs):
        # Only the infinite-width kernel is defined in this sketch.
        raise NotImplementedError()

      def kernel_fn(k):  # `k` is a neural_tangents `Kernel` object.
        def prod(mat, batch_ndim):
          if mat is None or mat.ndim == 0:
            return mat

          if k.diagonal_spatial:
            # Assuming `mat.shape == (N1[, N2], 2, n)`.
            return np.take(mat, 0, batch_ndim) * np.take(mat, 1, batch_ndim)

          # Assuming `mat.shape == (N1[, N2], 2, 2, n, n)`.
          concat_dim = batch_ndim if not k.is_reversed else -1
          return (np.take(np.take(mat, 0, concat_dim), 0, concat_dim) *
                  np.take(np.take(mat, 1, concat_dim), 1, concat_dim))

        # Output matrices are `(N1[, N2], n[, n])`.
        return k.replace(nngp=prod(k.nngp, 2),
                         cov1=prod(k.cov1, 1 if k.diagonal_batch else 2),
                         cov2=prod(k.cov2, 1 if k.diagonal_batch else 2),
                         ntk=prod(k.ntk, 2))

      return init_fn, apply_fn, kernel_fn
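The indexing in `prod` can be checked with plain NumPy on random stand-ins for real kernels, assuming the non-reversed layout. In the full case `(N1, N2, 2, 2, n, n)`, taking index 0 twice along the concatenation axes picks the input1/input1 block, taking index 1 twice picks the input2/input2 block, and their product has shape `(N1, N2, n, n)`. In the diagonal-spatial case `(N1, N2, 2, n)` a single index suffices. The final line sketches global average pooling of the resulting kernel: since pooling is linear, the pooled covariance is the mean over all pairs of positions.

```python
import numpy as np

rng = np.random.default_rng(0)
N1, N2, n = 3, 4, 5
batch_ndim = 2

# Stand-in for a full NNGP/NTK matrix of shape (N1, N2, 2, 2, n, n).
mat = rng.normal(size=(N1, N2, 2, 2, n, n))

# Same indexing as the non-diagonal branch of `prod` (non-reversed layout).
concat_dim = batch_ndim
prod_full = (np.take(np.take(mat, 0, concat_dim), 0, concat_dim) *
             np.take(np.take(mat, 1, concat_dim), 1, concat_dim))

# Equivalent explicit slicing: the [0, 0] block times the [1, 1] block.
prod_explicit = mat[:, :, 0, 0] * mat[:, :, 1, 1]

# Stand-in for a diagonal-spatial matrix of shape (N1, N2, 2, n).
mat_diag = rng.normal(size=(N1, N2, 2, n))
prod_diag = np.take(mat_diag, 0, batch_ndim) * np.take(mat_diag, 1, batch_ndim)

# Global average pooling of a full kernel averages over all position pairs.
pooled = prod_full.mean(axis=(-2, -1))  # (N1, N2)
```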
jl626 commented 4 years ago

Thanks for your help! That's exactly what I want to do.

Perhaps it doesn't make much sense for standard neural networks, but the NTK (thanks to its covariance form) might be able to aggregate features from different modalities.

Yes, if I concatenate input1 and input2, the covariance has shape num_batch x num_batch x 2 x 2 x num_pixels x num_pixels. Your code works perfectly.

Thanks again!

romanngg commented 3 years ago

Some updates:

@jguhlin we have finally added a stax.DotGeneral layer in https://github.com/google/neural-tangents/commit/b582a89600860d331cb91064e3b0075a9e898c89 that allows you to perform arbitrary linear transformations on your inputs, including subtraction. I imagine something like this would be appropriate for your use case:

    from jax import random
    import jax.numpy as np
    from neural_tangents import stax

    # Two time series stacked along the second (H) dimension.
    x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC

    # Subtract the second time series from the first one:
    nn = stax.serial(
        stax.Conv(128, (1, 3)),                           # (5, 2, 30, 128)
        stax.Relu(),                                      # (5, 2, 30, 128)
        stax.DotGeneral(
            rhs=np.array([1., -1.]),
            dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
        stax.GlobalAvgPool()                              # (5, 128)
    )

(see more examples in https://neural-tangents.readthedocs.io/en/latest/neural_tangents.stax.html#neural_tangents.stax.DotGeneral)
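To see why `rhs=np.array([1., -1.])` with `dimension_numbers=(((1,), (0,)), ((), ()))` performs the subtraction: it contracts axis 1 of the layer input (the stacking axis of size 2) against the only axis of `rhs`, with no batch dimensions. NumPy's `tensordot` performs the same contraction. The sketch below uses a random stand-in for the `(5, 2, 30, 128)` activations after `stax.Relu()` rather than running the network itself.

```python
import numpy as np

rng = np.random.default_rng(1)
# Stand-in for activations after Conv + Relu: (batch, 2, length, channels).
h = rng.normal(size=(5, 2, 30, 128))

rhs = np.array([1.0, -1.0])
# Contract axis 1 of `h` with axis 0 of `rhs`, mirroring
# dimension_numbers=(((1,), (0,)), ((), ())).
out = np.tensordot(h, rhs, axes=([1], [0]))  # (5, 30, 128)

# Global average pooling over the remaining spatial axis.
pooled = out.mean(axis=1)  # (5, 128)
```

The result equals `h[:, 0] - h[:, 1]`, i.e. the second time series subtracted from the first.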

@jl626 we have also added a stax.FanInProd layer in https://github.com/google/neural-tangents/commit/484dfa24b95722206134bbc8a2c457988c794822 (note, though, that it is not equivalent to an elementwise product of NTKs).
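Why a product layer does not simply multiply NTKs can be seen from a short product-rule calculation. This is a general sketch, not a description of the library's source, assuming two branches f and g with independent parameters, where K denotes the NNGP kernel and Θ the NTK, and using the fact that in the infinite-width limit the NTK of each branch becomes deterministic. The NNGP kernel of h = f·g factorizes, but the NTK picks up one term per branch:

```latex
\begin{aligned}
h(x) &= f(x)\, g(x), \\
K_h(x, x') &= \mathbb{E}\big[f(x) f(x')\big]\, \mathbb{E}\big[g(x) g(x')\big]
            = K_f(x, x')\, K_g(x, x'), \\
\Theta_h(x, x') &= \sum_{\theta \in f} \frac{\partial h(x)}{\partial \theta} \frac{\partial h(x')}{\partial \theta}
  + \sum_{\theta \in g} \frac{\partial h(x)}{\partial \theta} \frac{\partial h(x')}{\partial \theta} \\
&= g(x)\, g(x')\, \Theta_f(x, x') + f(x)\, f(x')\, \Theta_g(x, x') \\
&\to K_g(x, x')\, \Theta_f(x, x') + K_f(x, x')\, \Theta_g(x, x'),
\end{aligned}
```

so in general the NTK of the product is not the elementwise product of the two NTKs.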