alexgkendall / SegNet-Tutorial

Files for a tutorial to train SegNet for road scenes using the CamVid dataset
http://mi.eng.cam.ac.uk/projects/segnet/tutorial.html

why no batch normalization at the end? #9

Closed etienne87 closed 8 years ago

etienne87 commented 8 years ago

Hello there!

I'm trying out your (needless to say) awesome codebase, and I have just two questions:

1) Why is there no batch normalization after the final conv layer, just before the SoftMax?

2) Why upsample with the stored max pooling indices rather than with learned deconvolution? (The indices seem much cheaper computationally.)

Thanks a lot again for your time and answers,

Etienne

alexgkendall commented 8 years ago

Thanks Etienne!

1) I haven't actually thought about that. I was originally using the SegNet model for regression, where BN would not have been appropriate after the final conv layer. I would be interested to hear the results of your experiment, and if this achieves better performance.

2) Have a read of our SegNet paper, which gives a more detailed comparison of these two methods. We found that sparse upsampling with max pooling indices performed better (perhaps because it guides the decoder to retain some structure from the input) and is more computationally efficient, as you point out.
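For readers of the thread, a minimal NumPy sketch (not the repo's Caffe code) of what "sparse upsampling with max pooling indices" means: the encoder remembers where each maximum came from, and the decoder scatters the pooled values back to those positions instead of learning deconvolution filters.

```python
import numpy as np

def max_pool_with_indices(x, k=2):
    """k-by-k max pooling that also records the flat index of each maximum,
    mimicking what the SegNet encoder hands to its decoder."""
    h, w = x.shape
    pooled = np.zeros((h // k, w // k), dtype=x.dtype)
    indices = np.zeros((h // k, w // k), dtype=np.int64)
    for i in range(0, h - h % k, k):
        for j in range(0, w - w % k, k):
            window = x[i:i + k, j:j + k]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            pooled[i // k, j // k] = window[r, c]
            indices[i // k, j // k] = (i + r) * w + (j + c)
    return pooled, indices

def max_unpool(pooled, indices, out_shape):
    """Sparse upsampling: put each value back at its remembered position and
    leave everything else zero -- no weights to learn, unlike deconvolution."""
    out = np.zeros(out_shape, dtype=pooled.dtype)
    out.flat[indices.ravel()] = pooled.ravel()
    return out

x = np.arange(16, dtype=np.float32).reshape(4, 4)
pooled, idx = max_pool_with_indices(x)
print(max_unpool(pooled, idx, x.shape))
```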

Alex

etienne87 commented 8 years ago

Hello Alex,

1) sure I will let you know if there is any relevant performance gain.

EDIT 1: I tried three different architectures: the original, one with a final batch norm and no class reweighting, and one with a final batch norm that kept the class reweighting. It seems fine to use BN at the end; the training loss is even slightly better. I think removing the reweighting is pretty bad with BN, probably because BN produces noisy estimates due to the small batch size.
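For concreteness, a hedged pycaffe NetSpec sketch of what a "final batch norm" could look like: a BatchNorm + Scale pair inserted between the last conv layer and the SoftmaxWithLoss. The helper, the layer names and num_output=11 (CamVid classes) are assumptions, not the tutorial's actual prototxt; the SegNet Caffe branch's own BN layer would play the same role as the BatchNorm + Scale pair shown here.

```python
from caffe import layers as L

def classifier_with_final_bn(n, bottom, num_classes=11):
    # Hypothetical helper: final classifier conv, then BatchNorm + Scale,
    # then the usual loss. Assumes n.label already exists in the NetSpec.
    n.classifier = L.Convolution(bottom, kernel_size=3, pad=1,
                                 num_output=num_classes,
                                 weight_filler=dict(type='msra'))
    n.classifier_bn = L.BatchNorm(n.classifier, in_place=True)
    n.classifier_scale = L.Scale(n.classifier_bn, bias_term=True, in_place=True)
    n.loss = L.SoftmaxWithLoss(n.classifier_scale, n.label)
    return n
```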

EDIT 2: Training for 10K iterations, the original architecture gives me 77.96% accuracy, and the one with the additional BatchNorm before the SoftMax gives me 80%. The worst architecture (additional BN but no re-weighting strategy) gives only 57%!

[image: bn_testing]
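Since the class reweighting keeps coming up, here is a small NumPy sketch of median-frequency balancing (the Eigen & Fergus style weighting described in the SegNet paper), which is one way such class weights can be computed; the 11-class setup with label 11 as void is an assumption for CamVid.

```python
import numpy as np

def median_freq_balancing(label_maps, num_classes=11, ignore_label=11):
    """weight(c) = median(freq) / freq(c), where freq(c) is the pixel count of
    class c divided by the total pixels of the images in which c appears."""
    class_pixels = np.zeros(num_classes)
    image_pixels = np.zeros(num_classes)
    for lab in label_maps:
        valid = lab[lab != ignore_label].ravel()
        counts = np.bincount(valid, minlength=num_classes)[:num_classes]
        class_pixels += counts
        image_pixels[counts > 0] += lab.size
    freq = class_pixels / np.maximum(image_pixels, 1)
    return np.median(freq[freq > 0]) / np.maximum(freq, 1e-12)
```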

EDIT 3: For fun I tried adding ReLUs in the decoder, as in DeconvNet. It actually hurts training very badly... maybe the BN regularization is not enough?

[image: derelu_hurt]

2) Both papers are great! Just one question:

Etienne

alexgkendall commented 8 years ago

1) Cool - interesting results! Is this global accuracy? In my experiments I found the class balancing to produce marginally lower global accuracy but far higher class average accuracy.

2) Yes I guess that should work. At first guess I'd imagine your learning rate is too big? Have you tried making it much smaller?

etienne87 commented 8 years ago

1) Yes, I think so; I just averaged np.count_nonzero(mask == gt) / (w * h) over the test set. I will re-compute with the other metrics (see the sketch below).

2) yes & no : i tried around l_r: 1e-8 but loss plateaus very quickly. I'm wondering if putting iter_size > 1 would help getting more robust gradients (but i guess it will mess up the batch normalization...?)

PS: I over-trained from 4k up to 10k iterations, and the first convolutional layer went from looking like this: [image: segnet_first_layer_4k_iterations]

to this: [image: conv1]

which looks pretty sparse... do you think the weight decay could be a bit too big? Or the network capacity too high?

alexgkendall commented 8 years ago

Is this with supervised learning on the CamVid dataset? I do think the network is still over-parameterised (despite being much smaller than other proposals like FCN and DeconvNet). However, the CamVid dataset is very small...

I found that using encoder weights pretrained on ImageNet helps. The dropout regularisation in Bayesian SegNet also addresses this.
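A hedged pycaffe sketch of the encoder-pretraining suggestion: copy ImageNet-trained VGG16 weights into whichever encoder layers share names and blob shapes with the caffemodel, leaving the decoder randomly initialised. Both file paths are placeholders, not files shipped with this repo.

```python
import caffe

caffe.set_mode_gpu()
solver = caffe.SGDSolver('segnet_solver.prototxt')        # placeholder solver path
# copy_from only fills layers whose names and blob shapes match the caffemodel,
# so the VGG-style encoder convs receive ImageNet weights while every other
# layer keeps its random initialisation.
solver.net.copy_from('VGG_ILSVRC_16_layers.caffemodel')   # placeholder weights path
solver.solve()
```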

What are you referring to by iter_size?

etienne87 commented 8 years ago

Yep, this is training on the CamVid dataset, which is indeed a bit small. Pretraining the encoder could be a good solution; I will try that as well, in addition to autoencoder pretraining.

For iter_size: each weight update / solver iteration covers batch_size * iter_size inputs at a time; the gradients are accumulated by running iter_size forward+backward passes before updating. So basically it is a feature to help with a low memory budget, for example when training an autoencoder-style network like SegNet. The problem is I don't know if it plays well with Batch Normalization (the means and variances will not be accumulated over the whole effective batch as they should be).
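A framework-agnostic sketch (plain NumPy, not Caffe internals) of the mechanic described above: gradients from iter_size mini-batches are summed and applied as one averaged update, while any batch statistics are still computed per forward pass, so BN only ever sees batch_size samples at a time.

```python
import numpy as np

def solver_iteration(params, grad_fn, batches, iter_size, lr=1e-3):
    """One solver iteration with gradient accumulation (the effect of iter_size):
    iter_size forward/backward passes, then a single averaged weight update."""
    accum = [np.zeros_like(p) for p in params]
    for step in range(iter_size):
        grads = grad_fn(params, batches[step])  # any BN statistics inside grad_fn
        for a, g in zip(accum, grads):          # still only cover one mini-batch
            a += g
    for p, a in zip(params, accum):
        p -= lr * a / iter_size                 # update over the effective batch
    return params
```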

EDIT: I feel silly for not thinking of it earlier, but there are obvious work-arounds if your mini-batch is too big for your RAM. (This gives good results on a dataset like MS COCO with a similar architecture.)

At some point, Batch Norm will no longer produce poor estimates, and you are ready for business!

Sorry for writing all these trivialities here. Let's close the issue.