Closed ybj14 closed 4 years ago
Hello, I think this is a simple layer mismatch problem. Since an MNIST image is 28x28, after two operations with stride (2, 2) the image is reduced to 7x7 instead of the 8x8 you get from a 32x32 CIFAR-10 image. So you could change the argument of the `AvgPool` layer to match this, or just use the `GlobalAvgPool` layer, for which you do not need to specify the size. I confirmed that either fix gives non-NaN outputs for MNIST.
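To make the size arithmetic concrete, here is a minimal sketch in plain Python. It is not the library's code: the helper names `spatial_size_after_strides` and `valid_pool_output_size` are hypothetical, and it assumes the strided layers use 'SAME' padding (output size is the input size divided by the stride, rounded up) and that the pooling window is applied without padding at stride 1.

```python
import math


def spatial_size_after_strides(size, strides):
    """Spatial size after a sequence of strided layers.

    Assumption: each layer uses 'SAME' padding, so the output
    size is ceil(input_size / stride).
    """
    for stride in strides:
        size = math.ceil(size / stride)
    return size


def valid_pool_output_size(size, window):
    """Number of window positions for an unpadded, stride-1 pooling window.

    Assumption: no padding, so a window larger than the input
    fits nowhere and the output is empty (size 0).
    """
    return max(0, size - window + 1)


mnist_size = spatial_size_after_strides(28, [2, 2])  # 28 -> 14 -> 7
cifar_size = spatial_size_after_strides(32, [2, 2])  # 32 -> 16 -> 8

print(mnist_size, cifar_size)
print(valid_pool_output_size(mnist_size, 8), valid_pool_output_size(cifar_size, 8))
```

Under these assumptions an 8x8 pooling window fits the CIFAR-10 feature map exactly once but does not fit the 7x7 MNIST feature map at all, which is consistent with the mismatch described above. A 7x7 window, or a global average pool, avoids the problem.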
Thanks very much. I should've checked it more carefully.
Thanks for reporting the issue. For this specific case, we could provide a more informative error or warning.
Hi, thanks for your effort in developing this library. When I ran some experiments, the WideResNet example code output NaN on the MNIST dataset but seemed to work normally on the CIFAR-10 dataset. My test file is as follows:
The result of running it is:
Maybe this is an issue with the `jax` library?