google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

tuple index out of range error in `stax.Conv` #25

Closed g-benton closed 4 years ago

g-benton commented 4 years ago

Any time I call the initialization function on a network containing a convolutional layer, I get the same "tuple index out of range" error.

Here is a minimal example using one of the code snippets provided in the preprint:

from neural_tangents import stax
from jax import random

key = random.PRNGKey(10)

def ConvolutionalNetwork(depth, W_std=1.0, b_std=0.0):
    layers = []
    for _ in range(depth):
        layers += [stax.Conv(1, (3, 3), W_std=W_std, b_std=b_std, padding='SAME'), stax.Relu()]
    layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)]
    return stax.serial(*layers)

init_fn, apply_fn, kernel_fn = ConvolutionalNetwork(4)

x = random.normal(key, (10, 100))
init_fn(key, x.shape)

The same issue arises with the WideResNet code in the preprint, and when using CIFAR-10 data. Does anyone have insight on this?

Thanks!

SiuMath commented 4 years ago

I think the inputs may need to be a 4D tensor of shape (batch, width, height, channels), since you are using convolutions.
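To make the suggestion concrete, here is a minimal sketch of turning the 2D shape from the report, (10, 100), into a 4D shape before it is passed to init_fn. The helper name as_4d_shape and its parameters are hypothetical (they are not part of neural_tangents), and the choice of a 10 x 10 single-channel layout is only an assumption about how the 100 features might be arranged.

```python
def as_4d_shape(flat_shape, spatial, channels=1):
    """Hypothetical helper: map (batch, features) to (batch, s0, s1, channels).

    `spatial` is a pair of spatial sizes; their product times `channels`
    must equal the number of flat features.
    """
    if len(flat_shape) != 2:
        raise ValueError(f"expected a 2D (batch, features) shape, got {flat_shape}")
    if len(spatial) != 2:
        raise ValueError(f"expected two spatial sizes, got {spatial}")
    batch, features = flat_shape
    expected = spatial[0] * spatial[1] * channels
    if features != expected:
        raise ValueError(
            f"{features} features cannot be arranged as {spatial} with {channels} channel(s)"
        )
    return (batch, spatial[0], spatial[1], channels)


# The shape from the report, rearranged under the assumed 10 x 10 x 1 layout.
shape_4d = as_4d_shape((10, 100), spatial=(10, 10), channels=1)
print(shape_4d)  # (10, 10, 10, 1)

# With the network from the report, the calls would then presumably become:
#   x = random.normal(key, shape_4d)
#   init_fn(key, x.shape)
```

The same check would apply to CIFAR-10 if the images had been flattened to 3072 features, since 32 * 32 * 3 = 3072; whether that is how the data was loaded in the report is an assumption.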

g-benton commented 4 years ago

Thanks! Resolved