Open celidos opened 4 years ago
Hey Eduard,
1) To drop the bias term, you can simply set stax.Dense(..., b_std=0.) when constructing the dense layer.
2) For custom weight init, I suggest adapting the stax.Dense code (https://github.com/google/neural-tangents/blob/780ad0ce22d482bcefd12f4d3390090de7206da5/neural_tangents/stax.py#L469) and replacing the random.normal calls with your own sampler; but watch out for the infinite-width correspondence (kernel_fn), if it matters to you (we use iid Gaussian init since this is where the theory is simplest and most developed).
3) Indeed, we don't have a well-documented guide on implementing new layers yet, but I'd say that every stax layer in https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py is to some extent an example of how to add custom layers. Very briefly, a layer is defined by a triple of functions init_fn, apply_fn, kernel_fn that initialize finite-width parameters, forward-prop finite-width activations, and forward-prop infinite-width covariances respectively; you can see their specific signatures by looking at examples, e.g. Identity as the simplest one: https://github.com/google/neural-tangents/blob/780ad0ce22d482bcefd12f4d3390090de7206da5/neural_tangents/stax.py#L1674
If your layer satisfies this API, then it can be combined with other layers via stax.serial and stax.parallel. Finally, layers must have a @layer decorator and optionally @_supports_masking, which take care of some boilerplate code.
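To make the triple described above a bit more concrete, here is a minimal sketch in plain JAX of a bias-free dense layer. It is not the library's implementation: the name DenseNoBias, the w_init argument, and the simplified kernel_fn signature (plain nngp/ntk arrays instead of the library's kernel object) are assumptions for illustration, and the @layer decorator is left out. The covariance update is the standard NTK-parameterization rule with b_std = 0 and is only meant to match iid Gaussian weights.

```python
import jax.numpy as jnp
from jax import random


def DenseNoBias(out_dim, W_std=1.0, w_init=random.normal):
    """Hypothetical bias-free dense layer, NTK parameterization.

    Returns an (init_fn, apply_fn, kernel_fn) triple with simplified
    signatures; these do not match the library's exactly.
    """

    def init_fn(rng, input_shape):
        # Finite-width parameters: a single weight matrix, no bias.
        # Swap w_init for a custom sampler; kernel_fn below assumes iid Gaussian.
        output_shape = input_shape[:-1] + (out_dim,)
        W = w_init(rng, (input_shape[-1], out_dim))
        return output_shape, (W,)

    def apply_fn(params, inputs, **kwargs):
        # Finite-width forward pass: y = W_std / sqrt(n_in) * x @ W.
        (W,) = params
        return W_std / jnp.sqrt(inputs.shape[-1]) * inputs @ W

    def kernel_fn(nngp, ntk):
        # Infinite-width covariance update for a dense layer with b_std = 0.
        new_nngp = W_std**2 * nngp
        new_ntk = new_nngp + W_std**2 * ntk
        return new_nngp, new_ntk

    return init_fn, apply_fn, kernel_fn


# Usage on toy data (shapes and keys are arbitrary).
init_fn, apply_fn, kernel_fn = DenseNoBias(out_dim=512, W_std=1.5)
x = random.normal(random.PRNGKey(1), (8, 32))
out_shape, params = init_fn(random.PRNGKey(0), x.shape)
y = apply_fn(params, x)

nngp0 = x @ x.T / x.shape[-1]  # input covariance
nngp, ntk = kernel_fn(nngp0, jnp.zeros_like(nngp0))
```

With a non-Gaussian w_init the finite-width functions still run, but whether kernel_fn still describes the infinite-width limit is exactly the caveat from point 2.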
4) Your code looks good to me, it's hard to say what's wrong without seeing the inference code as well (also, I'm not sure what you mean by training an ntk here, since your code sample is describing a finite width forward prop only, so more context would be helpful to diagnose).
Lmk if this helps!
Thank you!
1) I want to run a finite-width net with the NTK kernel intentionally (to check the differences from the "ideal" case).
2, 3) I'll try it.
4) That's how I tried to implement it without stax (a finite net, intentionally). Should this code work?
# ... generating cifar2 in such way that
# x_test == x_train[::-1], y_test == y_train[::-1]
def apply_fn(params, x):
    [W1, W2, W3] = params
    layer1 = jax.nn.relu(jnp.matmul(x, jnp.transpose(W1)))
    layer2 = jax.nn.relu(jnp.matmul(layer1, jnp.transpose(W2)))
    layer3 = jnp.matmul(layer2, jnp.transpose(W3)).squeeze(axis=1)
    return layer3
W1 = jnp.array(...) # init in some way, e.g., just uniform on [-1, 1]
W2 = jnp.array(...)
W3 = jnp.array(...)
params = (W1, W2, W3)
apply_fn = jit(apply_fn)
kernel_fn = nt.empirical_ntk_fn(apply_fn)
k_train_train = kernel_fn(train_x, None, 'ntk', params)
k_test_train = kernel_fn(test_x, train_x, 'ntk', params)
cross_entropy = lambda fx, y_true: -jnp.mean(
    y_true * jnp.log(jax.experimental.stax.sigmoid(fx)) +
    (1 - y_true) * jnp.log(1.0 - jax.experimental.stax.sigmoid(fx)))
predict_fn = nt.predict.gradient_descent(cross_entropy, k_train_train,
                                         train_y, learning_rate, momentum)
fx_train_0 = apply_fn(params, train_x)
fx_test_0 = apply_fn(params, test_x)
# eval ------------
fx_train_t, fx_test_t = predict_fn(t=1e+6, fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
# calculating accuracies by fx_train gives 1.0 accuracy
# doing the same for fx_test gives ~ 0.5 accuracy
# with the idea that test x is just shuffled train x
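One way to narrow this down is to check the kernels themselves. The sketch below rebuilds the same bias-free network on toy data and computes the empirical NTK by hand with jax.jacobian (it does not use the nt functions from the post; the layer sizes, random keys, and the helper name empirical_ntk are assumptions). If the test set is the reversed train set, the rows of k_test_train should be the rows of k_train_train in reverse order; if that holds on the real data too, the mismatch more likely comes from the loss or the prediction step than from the kernel.

```python
import jax
import jax.numpy as jnp
from jax import random


def apply_fn(params, x):
    # Same bias-free three-layer ReLU network as in the post.
    W1, W2, W3 = params
    h = jax.nn.relu(x @ W1.T)
    h = jax.nn.relu(h @ W2.T)
    return (h @ W3.T).squeeze(axis=1)


def empirical_ntk(params, x1, x2):
    # K[i, j] = <df(x1[i])/dparams, df(x2[j])/dparams>, summed over all weights.
    j1 = jax.jacobian(apply_fn)(params, x1)  # tuple of (n1, *W.shape) arrays
    j2 = jax.jacobian(apply_fn)(params, x2)
    return sum(a.reshape(len(x1), -1) @ b.reshape(len(x2), -1).T
               for a, b in zip(j1, j2))


# Toy stand-ins for the real data and weights (sizes are arbitrary).
k1, k2, k3, kx = random.split(random.PRNGKey(0), 4)
params = (random.uniform(k1, (16, 5), minval=-1.0, maxval=1.0),
          random.uniform(k2, (16, 16), minval=-1.0, maxval=1.0),
          random.uniform(k3, (1, 16), minval=-1.0, maxval=1.0))
x_train = random.normal(kx, (6, 5))
x_test = x_train[::-1]  # test set is the reversed train set

k_train_train = empirical_ntk(params, x_train, x_train)
k_test_train = empirical_ntk(params, x_test, x_train)

# Expect the test/train kernel to be the train/train kernel with rows reversed.
rows_match = bool(jnp.allclose(k_test_train, k_train_train[::-1],
                               rtol=1e-4, atol=1e-4))
```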
It looks OK to me, but again without seeing the fully-runnable code snippet it's hard to say for sure. A few comments:
1) you could try jax.nn.log_sigmoid for better numerical stability;
2) random-chance accuracy could be caused by silent broadcasting when computing the loss between targets and outputs (or maybe also inputs: np.matmul is very permissive in terms of input shapes), or maybe by NaNs or infinities somewhere. Could you check that your test outputs have the same shape as the targets, and that they are numerically reasonable? What if you train for a smaller t?
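Both points can be illustrated with a small sketch on made-up numbers (the helper name bce_with_logits and the toy arrays are assumptions, not taken from the code above). It shows how a trailing axis on the outputs silently turns an elementwise product into an n-by-n array, and how the log(1 - sigmoid(fx)) form overflows for large logits while jax.nn.log_sigmoid stays finite.

```python
import jax
import jax.numpy as jnp


def bce_with_logits(fx, y_true):
    # Refuse mismatched shapes instead of letting them broadcast.
    assert fx.shape == y_true.shape, (fx.shape, y_true.shape)
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log_sigmoid(-z), computed stably.
    return -jnp.mean(y_true * jax.nn.log_sigmoid(fx)
                     + (1.0 - y_true) * jax.nn.log_sigmoid(-fx))


fx = jnp.array([[20.0], [-20.0], [3.0]])  # logits with a trailing axis, shape (3, 1)
y = jnp.array([1.0, 0.0, 1.0])            # targets, shape (3,)

# Silent broadcasting: (3,) * (3, 1) gives a (3, 3) array that pairs
# every output with every target.
broadcast_shape = (y * fx).shape

# With matching shapes the loss is computed elementwise, as intended.
loss = bce_with_logits(fx.squeeze(axis=1), y)

# For a large logit, sigmoid rounds to exactly 1 in float32, so the naive
# formula takes log(0); the log_sigmoid form does not.
big = jnp.array([100.0])
naive = -jnp.log(1.0 - jax.nn.sigmoid(big))
stable = -jax.nn.log_sigmoid(-big)
```

Printing broadcast_shape, naive, and stable for the real train_y and fx_train_0 would be a quick way to rule these two causes in or out.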
Hello!
I want to implement my custom layer (Dense layer without bias term (y = Wx) and with custom weight initialization).
How can I do this? I've looked at https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py, but there is no clear example. What must a layer class contain in order for it to function properly?
I've tried to write something like this (using jax numpy only):
(for a 3-layer network) But this code behaves very strangely. For example, I took a train dataset, and the test dataset is just the shuffled train dataset. A model written using only jax.numpy gives different quality on the same data samples after training with the NTK kernel! What is my error? Maybe there is an easier way to implement this dense layer without bias?