Closed g-benton closed 4 years ago
I think you may need the inputs to be a 4D (batch, width, height, channels) tensor, since you are using convolution.
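A minimal sketch of that suggestion, assuming the 100 features per example in the snippet quoted below are meant to be a 10x10 single-channel image. `as_image_shape` is a hypothetical helper written for illustration, not part of neural_tangents. The sketch also assumes the error comes from layer code indexing the shape tuple beyond its second entry, which a 2D shape cannot satisfy.

```python
def as_image_shape(flat_shape, spatial, channels=1):
    """Hypothetical helper: map (batch, features) to (batch, *spatial, channels)."""
    batch, features = flat_shape
    expected = channels
    for size in spatial:
        expected *= size
    if expected != features:
        raise ValueError(
            f"cannot view {features} features as spatial={spatial}, channels={channels}"
        )
    return (batch, *spatial, channels)


flat_shape = (10, 100)  # the 2D shape passed to init_fn in the issue

# Indexing a 2-tuple at a spatial or channel position fails with the
# message reported in the issue.
try:
    flat_shape[3]
except IndexError as err:
    message = str(err)

# Assumption: each example is a 10x10 image with one channel.
image_shape = as_image_shape(flat_shape, spatial=(10, 10))

# With the network from the issue, the calls would then become
# (not executed here, since they need jax and neural_tangents):
#   x = random.normal(key, image_shape)
#   init_fn(key, x.shape)
```

For CIFAR-10 the same idea applies: keep the images as 4D arrays rather than flattening them before they reach the convolutional layers.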
On Thu, Mar 19, 2020 at 5:19 PM Greg Benton notifications@github.com wrote:
Any time I try to call the initialization function on a network containing a convolutional layer, I get the same "tuple index out of range" error.
Here is a minimal example using one of the code snippets provided in the preprint:

    from neural_tangents import stax
    from jax import random

    key = random.PRNGKey(10)

    def ConvolutionalNetwork(depth, W_std=1.0, b_std=0.0):
        layers = []
        for _ in range(depth):
            layers += [stax.Conv(1, (3, 3), W_std=W_std, b_std=b_std, padding='SAME'), stax.Relu()]
        layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)]
        return stax.serial(*layers)

    init_fn, apply_fn, kernel_fn = ConvolutionalNetwork(4)

    x = random.normal(key, (10, 100))
    init_fn(key, x.shape)

The same issue arises when using the WideResNet code in the preprint, and when using CIFAR-10 data. Does anyone have insight on this?
Thanks!
Thanks! Resolved