google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

WideResNet outputs NaN on MNIST dataset #71

Closed: ybj14 closed this issue 4 years ago

ybj14 commented 4 years ago

Hi, thanks for your effort in developing this library. While running some experiments, I found that the WideResNet example code outputs NaN on the MNIST dataset but seems normal on the CIFAR-10 dataset. My test file is as follows:

from neural_tangents import stax
from examples import datasets
from jax import random

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    Main = stax.serial(
        stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2),
                       stax.parallel(Main, Shortcut),
                       stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)

x_train, y_train, x_test, y_test = datasets.get_dataset("cifar10", 10, 1)
input_shape = (-1, 32, 32, 3)
x_train = x_train.reshape(input_shape)
x_test = x_test.reshape(input_shape)

key = random.PRNGKey(1)
key, net_key = random.split(key)
output_shape, params = init_fn(net_key, input_shape)
fx_test = apply_fn(params, x_test)
print(fx_test)

x_train, y_train, x_test, y_test = datasets.get_dataset("mnist", 10, 1)
input_shape = (-1, 28, 28, 1)
x_train = x_train.reshape(input_shape)
x_test = x_test.reshape(input_shape)

key = random.PRNGKey(1)
key, net_key = random.split(key)
output_shape, params = init_fn(net_key, input_shape)
fx_test = apply_fn(params, x_test)
print(fx_test)

The output is:

[[-1.0211153   1.2753677   1.1458696  -0.851678   -0.86049545 -1.1933788
   0.6119092   0.24371225  2.8294983  -0.65209687]]
[[nan nan nan nan nan nan nan nan nan nan]]

Could this be an issue with the JAX library?

jaehlee commented 4 years ago

Hello, I think this is a simple layer-shape mismatch. MNIST images are 28x28, so after the two stride-(2, 2) groups the feature map shrinks to 7x7 instead of the 8x8 you get from 32x32 CIFAR-10 images. You can either change the AvgPool window to match (i.e. AvgPool((7, 7))), or use the GlobalAvgPool layer, which does not require you to specify the size. I confirmed that either fix gives non-NaN outputs for MNIST.
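To make the arithmetic concrete, here is a minimal pure-Python sketch that tracks the spatial size through the WideResnet above. It assumes that a 'SAME'-padded convolution with stride s produces ceil(n / s) outputs per side, and that AvgPool uses 'VALID' padding, so a window larger than the feature map fits at zero positions. The helper names are hypothetical and are not part of neural_tangents.

```python
import math


def same_conv_out(size, stride):
    # 'SAME' padding: output size per side is ceil(input / stride).
    return math.ceil(size / stride)


def valid_pool_out(size, window, stride=1):
    # 'VALID' padding (assumed): count positions where the whole window fits.
    # When window > size there are no positions, whatever the stride.
    return max(0, (size - window) // stride + 1)


def feature_map_size(image_size, group_strides=(1, 2, 2)):
    # The stem Conv has stride 1; each WideResnetGroup applies its stride once,
    # in its first block, and the remaining blocks keep the size unchanged.
    size = image_size
    for s in group_strides:
        size = same_conv_out(size, s)
    return size


for name, side in [("CIFAR-10", 32), ("MNIST", 28)]:
    fm = feature_map_size(side)
    print(name, fm, valid_pool_out(fm, 8))
# CIFAR-10 8 1
# MNIST 7 0
```

With a 7x7 feature map, an (8, 8) window has no valid position, whereas a (7, 7) window (or a global average over whatever size remains) yields exactly one output per channel.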

ybj14 commented 4 years ago

Thanks very much. I should have checked it more carefully.

jaehlee commented 4 years ago

Thanks for reporting the issue. For this specific case, we could provide a more informative error or warning.
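As a rough illustration of what such a check might look like, here is a minimal sketch of a validation a pooling layer could run before computing its output. The function name check_pool_window, its arguments, and its message are hypothetical and not part of neural_tangents; it also assumes 'VALID' padding, where a window larger than the input leaves the pooled output empty.

```python
def check_pool_window(feature_shape, window_shape):
    """Raise a descriptive error if a 'VALID' pooling window cannot fit.

    Hypothetical helper, not part of neural_tangents. Both arguments are
    spatial sizes only, e.g. (7, 7) for the feature map and (8, 8) for the window.
    """
    for axis, (size, window) in enumerate(zip(feature_shape, window_shape)):
        if window > size:
            raise ValueError(
                f"Pooling window {tuple(window_shape)} does not fit the spatial "
                f"shape {tuple(feature_shape)} (axis {axis}: window {window} > "
                f"size {size}); the pooled output would be empty. Consider "
                "GlobalAvgPool or a smaller window.")


check_pool_window((8, 8), (8, 8))  # CIFAR-10 case: passes silently
try:
    check_pool_window((7, 7), (8, 8))  # MNIST case: raises
except ValueError as err:
    print(err)
```

Failing loudly at this point would point directly at the mismatched window instead of surfacing later as NaN outputs.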