google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Question - custom layer (Dense layer without bias and with custom initialization) #67

Open celidos opened 4 years ago

celidos commented 4 years ago

Hello!

I want to implement my custom layer (Dense layer without bias term (y = Wx) and with custom weight initialization).

How can I do this? I've looked at https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py, but there is no clear example. What must a layer class contain in order for it to function properly?

I've tried to write something like this (using jax numpy only):

import jax
import jax.numpy as jnp

def apply_fn(params, x):
    # 3-layer bias-free MLP: each layer computes y = W x (rows of x are samples)
    [W1, W2, W3] = params
    layer1 = jax.nn.relu(jnp.matmul(x, jnp.transpose(W1)))
    layer2 = jax.nn.relu(jnp.matmul(layer1, jnp.transpose(W2)))
    layer3 = jnp.matmul(layer2, jnp.transpose(W3)).squeeze(axis=1)
    return layer3

(for a 3-layer network). But this code behaves very strangely. For example, I took a training dataset and made the test dataset just a shuffled copy of it. A model written using only jax.numpy, after training with the NTK kernel, gives different quality on the very same data samples! What is my error?

Maybe there is an easier way to implement this dense layer without bias?

romanngg commented 4 years ago

Hey Eduard,

1) To not have a bias term, you could simply set stax.Dense(..., b_std=0.) when constructing the dense layer.

2) For custom weight init, I suggest adapting the stax.Dense code https://github.com/google/neural-tangents/blob/780ad0ce22d482bcefd12f4d3390090de7206da5/neural_tangents/stax.py#L469 and replacing the random.normal calls with your own init; but watch out for the infinite-width correspondence (kernel_fn), if it matters to you (we use iid Gaussian init since this is where the theory is simplest and most developed).
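For point 2, here is a rough sketch of a bias-free dense layer with a non-Gaussian init, written as an (init_fn, apply_fn) pair in plain JAX rather than by editing stax.Dense. The name dense_no_bias_uniform and the choice of a unit-variance uniform distribution U[-sqrt(3), sqrt(3)] are hypothetical. It mirrors the W_std / sqrt(fan_in) scaling of stax.Dense but provides no kernel_fn. One might expect iid, zero-mean, unit-variance weights to share the Gaussian case's infinite-width kernel, but that is an assumption to verify, per the caveat above.

```python
import jax
import jax.numpy as jnp


def dense_no_bias_uniform(out_dim, W_std=1.0):
    """Hypothetical bias-free Dense layer with a uniform weight init.

    Uses the scaling outputs = W_std / sqrt(fan_in) * inputs @ W,
    but draws W from U[-sqrt(3), sqrt(3)] (zero mean, unit variance)
    instead of a standard normal.
    """
    limit = 3.0 ** 0.5  # U[-a, a] has variance a**2 / 3, so a = sqrt(3) gives 1

    def init_fn(rng, input_shape):
        fan_in = input_shape[-1]
        W = jax.random.uniform(rng, (fan_in, out_dim), minval=-limit, maxval=limit)
        output_shape = tuple(input_shape[:-1]) + (out_dim,)
        return output_shape, (W,)

    def apply_fn(params, inputs, **kwargs):
        (W,) = params
        return W_std / jnp.sqrt(inputs.shape[-1]) * (inputs @ W)  # no bias term

    return init_fn, apply_fn


# Usage: a 1000 -> 10 layer applied to a batch of 4 inputs.
init_fn, apply_fn = dense_no_bias_uniform(10, W_std=2.0)
out_shape, params = init_fn(jax.random.PRNGKey(0), (4, 1000))
y = apply_fn(params, jnp.ones((4, 1000)))
```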

3) Indeed, we don't have a well-documented guide on implementing new layers yet, but I'd say that every stax layer in https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py is, to some extent, an example of how to add custom layers. Very briefly, a layer is defined by a triple of functions, init_fn, apply_fn and kernel_fn, which respectively initialize finite-width parameters, forward-propagate finite-width activations, and forward-propagate infinite-width covariances. You can see their specific signatures by looking at examples, e.g. Identity, the simplest one (https://github.com/google/neural-tangents/blob/780ad0ce22d482bcefd12f4d3390090de7206da5/neural_tangents/stax.py#L1674). If your layer satisfies this API, it can be combined with other layers via stax.serial and stax.parallel. Finally, layers must have a @layer decorator and optionally @_supports_masking, which take care of some boilerplate code.
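To make the kernel_fn part of the triple more concrete, the toy sketch below propagates a plain (nngp, ntk) tuple through bias-free Dense layers in the NTK parameterization. It does not use the library's Kernel object or the @layer decorator; the function names are hypothetical, and the input covariance convention x1 @ x2.T / d is an assumption.

```python
import jax.numpy as jnp


def dense_no_bias_kernel(W_std=1.0):
    """Toy kernel_fn for a bias-free Dense layer (NTK parameterization).

    Works on a plain (nngp, ntk) tuple, a simplified stand-in for the
    library's Kernel object.
    """
    def kernel_fn(kernel):
        nngp, ntk = kernel
        nngp_out = W_std**2 * nngp            # no + b_std**2 term: there is no bias
        ntk_out = nngp_out + W_std**2 * ntk   # this layer's W gradient + propagated NTK
        return nngp_out, ntk_out
    return kernel_fn


def input_kernel(x1, x2):
    """Input covariances: nngp = x1 @ x2.T / d, ntk = 0 (no parameters yet)."""
    nngp = x1 @ x2.T / x1.shape[-1]
    return nngp, jnp.zeros_like(nngp)


# Two bias-free linear layers with W_std=2 on a 2-point toy input.
x = jnp.array([[1.0, 0.0], [0.0, 2.0]])
kernel = input_kernel(x, x)
for layer_kernel_fn in (dense_no_bias_kernel(2.0), dense_no_bias_kernel(2.0)):
    kernel = layer_kernel_fn(kernel)
nngp, ntk = kernel
```

Here nngp ends up as [[8, 0], [0, 32]] and ntk as twice that. For a purely linear, bias-free stack of depth L this recursion gives ntk = L * nngp, which is a handy sanity check. A real layer would also need nonlinearities and the library's Kernel bookkeeping, both omitted here.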

4) Your code looks good to me, it's hard to say what's wrong without seeing the inference code as well (also, I'm not sure what you mean by training an ntk here, since your code sample is describing a finite width forward prop only, so more context would be helpful to diagnose).

Lmk if this helps!

celidos commented 4 years ago

Thank you!

1) I intentionally want to run a finite-width net with the NTK kernel (to see how it differs from the "ideal" infinite-width case).

2, 3) I'll try it.

4) This is how I tried to implement it without stax (a finite net, intentionally). Should this code work?

import jax
import jax.numpy as jnp
from jax import jit
import neural_tangents as nt

# ... generating cifar2 in such a way that
# test_x == train_x[::-1], test_y == train_y[::-1]
# (train_y has shape (N, 1), matching the network output below)

def apply_fn(params, x):
    [W1, W2, W3] = params
    layer1 = jax.nn.relu(jnp.matmul(x, jnp.transpose(W1)))
    layer2 = jax.nn.relu(jnp.matmul(layer1, jnp.transpose(W2)))
    # Keep the (N, 1) output: neural_tangents traces over the last output axis
    # by default (trace_axes=(-1,)), so squeezing it away would leave the
    # batch axis as the last axis.
    layer3 = jnp.matmul(layer2, jnp.transpose(W3))
    return layer3

W1 = jnp.array(...) # init in some way, e.g. just uniform [-1, 1]
W2 = jnp.array(...)
W3 = jnp.array(...)
params = (W1, W2, W3)

apply_fn = jit(apply_fn)
# empirical_ntk_fn returns ntk_fn(x1, x2, params); it takes no 'ntk' argument
kernel_fn = nt.empirical_ntk_fn(apply_fn)
k_train_train = kernel_fn(train_x, None, params)
k_test_train = kernel_fn(test_x, train_x, params)

# Binary cross-entropy on logits fx (jax.nn.sigmoid replaces the
# jax.experimental.stax.sigmoid alias used originally)
cross_entropy = lambda fx, y_true: -jnp.mean(y_true * jnp.log(jax.nn.sigmoid(fx)) +
                                             (1 - y_true) * jnp.log(1.0 - jax.nn.sigmoid(fx)))

# learning_rate and momentum are set elsewhere
predict_fn = nt.predict.gradient_descent(cross_entropy, k_train_train,
                                         train_y, learning_rate, momentum)

fx_train_0 = apply_fn(params, train_x)
fx_test_0 = apply_fn(params, test_x)

# eval ------------
fx_train_t, fx_test_t = predict_fn(t=1e+6, fx_train_0=fx_train_0, fx_test_0=fx_test_0,
                                   k_test_train=k_test_train)

# computing accuracy from fx_train_t gives 1.0,
# but doing the same for fx_test_t gives ~0.5,
# even though test_x is just a reordered train_x

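Because test_x is only a reordering of train_x here, k_test_train should be the same reordering of the rows of k_train_train, and fx_test_0 the same reordering of fx_train_0. Since the predictor evolves the test outputs through k_test_train, the final test predictions should then be reordered training predictions too. The sketch below checks that invariant with a hand-rolled empirical NTK via jax.jacobian (an assumption, standing in for nt.empirical_ntk_fn) on random toy data and weights.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    """Bias-free 3-layer ReLU MLP that keeps an (N, 1) output."""
    W1, W2, W3 = params
    h = jax.nn.relu(x @ W1.T)
    h = jax.nn.relu(h @ W2.T)
    return h @ W3.T


def empirical_ntk(params, x1, x2):
    """NTK at these params: sum over weight matrices of J(x1) @ J(x2).T."""
    j1 = jax.jacobian(apply_fn)(params, x1)  # tuple of (N1, 1, *W.shape) arrays
    j2 = jax.jacobian(apply_fn)(params, x2)
    return sum(a.reshape(x1.shape[0], -1) @ b.reshape(x2.shape[0], -1).T
               for a, b in zip(j1, j2))


keys = jax.random.split(jax.random.PRNGKey(0), 4)
train_x = jax.random.normal(keys[0], (5, 3))
test_x = train_x[::-1]  # test set = reversed train set
params = (jax.random.uniform(keys[1], (4, 3), minval=-1.0, maxval=1.0),
          jax.random.uniform(keys[2], (4, 4), minval=-1.0, maxval=1.0),
          jax.random.uniform(keys[3], (1, 4), minval=-1.0, maxval=1.0))

k_train_train = empirical_ntk(params, train_x, train_x)
k_test_train = empirical_ntk(params, test_x, train_x)

# Test-side kernel rows and initial outputs should be reordered training ones.
kernel_consistent = jnp.allclose(k_test_train, k_train_train[::-1])
outputs_consistent = jnp.allclose(apply_fn(params, test_x),
                                  apply_fn(params, train_x)[::-1])
```

If both checks hold but test accuracy still sits near chance, the mismatch is more likely in how test_y is ordered or shaped relative to fx_test_t than in the kernel itself.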
romanngg commented 4 years ago

It looks OK to me, but again, without a fully runnable code snippet it's hard to say for sure. A few comments:

1) you could try jax.nn.log_sigmoid for better numerical stability;

2) random-chance accuracy could come from silent broadcasting when computing the loss between targets and outputs (or possibly also inputs, since np.matmul is very permissive about input shapes), or from NaNs or infinities somewhere. Could you check that your test outputs have the same shape as the targets, and that they are numerically reasonable? What happens if you train for a smaller t?
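Both suggestions can be illustrated in a few lines of plain JAX. The helper name binary_cross_entropy and its explicit shape check are hypothetical additions, not part of neural_tangents; the toy logits are deliberately large so that the naive formula breaks. It uses log(1 - sigmoid(z)) = log_sigmoid(-z).

```python
import jax
import jax.numpy as jnp


def binary_cross_entropy(fx, y_true):
    """Stable BCE on logits fx, rejecting outputs/targets with different shapes."""
    if fx.shape != y_true.shape:
        raise ValueError(f"shape mismatch: outputs {fx.shape} vs targets {y_true.shape}")
    return -jnp.mean(y_true * jax.nn.log_sigmoid(fx)
                     + (1 - y_true) * jax.nn.log_sigmoid(-fx))


fx = jnp.array([[40.0], [-40.0]])  # large logits, shape (2, 1)
y = jnp.array([[1.0], [0.0]])      # targets with the same (2, 1) shape
loss = binary_cross_entropy(fx, y)

# Naive version with sigmoid written out: 1 / (1 + exp(-40.)) rounds to 1.0,
# so log(1 - 1.0) = -inf and 0 * -inf gives NaN.
sig = 1.0 / (1.0 + jnp.exp(-fx))
naive = -jnp.mean(y * jnp.log(sig) + (1 - y) * jnp.log(1.0 - sig))

# Silent broadcasting: (2, 1) outputs against (2,) targets give a (2, 2) array.
broadcast_shape = (fx - y.squeeze(axis=1)).shape
```

The stable loss stays finite while the naive one is NaN, and the last line shows how a squeezed target vector quietly turns an elementwise loss into an all-pairs one.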