gaudetcj / DeepQuaternionNetworks

Code for Deep Quaternion Networks
MIT License

Loss going to NaN with Batch Normalization #4

Open Phirefly9 opened 5 years ago

Phirefly9 commented 5 years ago

Good morning,

I'm very interested in your work and in recreating the results from your paper. I've attempted to run the small two-block network and am finding that the loss goes to NaN after 1 epoch on CIFAR-10 and after 3 epochs on CIFAR-100.

I have removed the quaternion batch normalization and found that the network then trains fine on CIFAR-10, although the results are about 10% lower than those published in the paper.

I noticed that you had changed the learning rate from the value in the paper, so I matched it to the paper's schedule for the first epochs, and the loss still diverged after 3 epochs on CIFAR-10.

For the record, I am running the following versions: tensorflow-gpu 1.12 from conda, and keras 2.2.4 installed from source (the version in conda had a bug).
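
For anyone comparing setups, a quick way to confirm the versions in a given environment (just a convenience snippet, not something from this repo):

```python
import tensorflow as tf
import keras

print(tf.__version__)     # expected here: 1.12.x
print(keras.__version__)  # expected here: 2.2.4
```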

Do you have any thoughts on what may be causing this behavior?

gaudetcj commented 5 years ago

Good morning,

Glad to hear you are interested. The NaN is an issue a lot of people hit when first running the code, and I think it may be a version issue or a potential bug in the BN code.

This repo's code was lost when the computer all the work was on died; this backup was a little behind the actual code. I am currently trying to rewrite all of it in PyTorch and redo the entire repo.

I am working full time, so I do not have as much time as I would like to work on it, but keep an eye out and I will update it as soon as I can. It will include scripts to train each experiment and reproduce the results from the paper, as well as some experiments that did not make it into the original paper.

Phirefly9 commented 5 years ago

I did some experiments and have tracked the issue down to the "scaling" mode of BN. Using centering only gives the following final result (this is the two-block model):

1407/1407 [==============================] - 52s 37ms/step - loss: 0.0846 - acc: 0.9959 - val_loss: 0.6796 - val_acc: 0.9190

Using scaling only, on the other hand, still results in the NaN.
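
For context, here is a minimal NumPy sketch (not the actual layer code from this repo, and the eigendecomposition here is only for illustration) of what the scaling/whitening step of quaternion BN has to compute. If the 4x4 covariance of the quaternion components is near-singular and the regularization epsilon is too small, the inverse square root blows up, which would be consistent with NaNs appearing only when scaling is enabled:

```python
import numpy as np

def quaternion_whiten(x, eps=1e-4):
    """Whiten centered quaternion activations x of shape (N, 4) = (r, i, j, k).

    Multiplies by the inverse square root of the 4x4 component covariance.
    If that covariance is near-singular and eps is too small, 1/sqrt(w)
    explodes, which is one plausible source of the NaNs reported above.
    """
    cov = x.T @ x / x.shape[0]   # 4x4 covariance of the components
    cov += eps * np.eye(4)       # regularization keeps it invertible
    w, v = np.linalg.eigh(cov)   # symmetric eigendecomposition
    inv_sqrt = v @ np.diag(1.0 / np.sqrt(w)) @ v.T
    return x @ inv_sqrt

# Sanity check: whitened components should have roughly identity covariance.
rng = np.random.default_rng(0)
x = rng.normal(size=(1024, 4)) * np.array([1.0, 0.5, 2.0, 0.1])
x -= x.mean(axis=0)
y = quaternion_whiten(x)
print(np.round(y.T @ y / y.shape[0], 3))
```

If something like this is what is happening in the layer, increasing the epsilon added to the covariance (or otherwise bounding its eigenvalues away from zero) might be worth trying before dropping the scaling mode entirely.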