faustomilletari / VNet


Dice function with batch_size>1 #16

Open RSly opened 7 years ago

RSly commented 7 years ago

Hi,

I tried to correct the dice function to work for batch_size > 1, but didn't have much success, in particular with the backward function. Any chance you could update your implementation to support batch_size > 1?

That would really help :) Thanks

RSly commented 7 years ago

It seems it is solved by rewriting the normalization as `top[0].data[0] = np.sum(dice) / float(bottom[0].data.shape[0])`

I still have some problems... I will investigate and let you know
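
For reference, a minimal standalone sketch of that normalization, assuming the VNet-style squared-denominator Dice; the function name `batch_dice` and the `eps` smoothing term are illustrative, not taken from the repo:

```python
import numpy as np

def batch_dice(pred, label, eps=1e-7):
    """Per-sample Dice, averaged over the batch so the loss stays <= 1.

    pred, label: arrays of shape (batch, ...) holding foreground
    probabilities and binary ground-truth masks.
    """
    batch_size = pred.shape[0]
    dice = np.zeros(batch_size)
    for i in range(batch_size):
        p = pred[i].ravel()
        g = label[i].ravel()
        # VNet-style Dice with squared terms in the denominator
        dice[i] = 2.0 * np.sum(p * g) / (np.sum(p * p) + np.sum(g * g) + eps)
    # divide by the batch size, as in the one-liner above
    return np.sum(dice) / float(batch_size)
```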

RSly commented 7 years ago

Here are more details regarding the batch_size problem:

Figure 1: with batch_size 1, the network learns nicely and reaches 90% accuracy.

[image]

Figure 2: with batch_size 8, it stays at 20% accuracy even after many epochs.

[image]

gattia commented 7 years ago

Have you tried changing the learning rate when you change the batch size? This conversation says that this paper indicates the learning rate should be changed: if you are increasing the batch size, the learning rate should be decreased. The relationship they give implies that if you double the batch size, the learning rate should drop to ~0.7 of the original. This likely won't fix everything you are describing, but it might help.
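
A quick sketch of that rule of thumb as described above (scaling the learning rate by the inverse square root of the batch-size ratio, which gives the ~0.7 factor when the batch size doubles); this is only an illustration of the heuristic, not code from the repo:

```python
import math

def scaled_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate by 1/sqrt of the batch-size ratio."""
    return base_lr / math.sqrt(new_batch_size / float(base_batch_size))

# Doubling the batch size drops the learning rate to ~0.707 of the original,
# e.g. scaled_learning_rate(0.01, 1, 8) ≈ 0.0035
```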

I also thought there was another comment in the VNet issues indicating that VNet works with a batch size > 1 out of the box. That issue says the Dice would be reported as the sum over the 2 volumes in the batch; if I'm not mistaken, this could give Dice scores up to 2.0, but it shouldn't make a difference, since you can simply divide by 2 to get the average Dice of the 2 volumes.

RSly commented 7 years ago

@gattia, thanks a lot for the suggestion and the links. I tried the 0.7 decrease and it actually helps, so with this trick the results with batch_size 1 and batch_size 10 are now comparable! Nice!

However, I was hoping that a larger batch size would act as a data-balancing trick, since my data is very unbalanced (90% class A, 10% class B), but I still don't get better results using batch_size > 1...

=> Regarding the VNet Dice out of the box: it is true that it can work, but with batch_size > 1 this Dice is not normalized, so it cannot be combined with other normalized loss functions. That is the main reason it is a good idea to normalize it by the batch_size, so the Dice loss stays < 1, using: `top[0].data[0] = np.sum(dice) / float(bottom[0].data.shape[0])`
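
If the forward Dice is averaged over the batch like this, the backward pass needs the same scaling; here is a hedged sketch of the per-sample gradient of the squared-denominator Dice, divided by the batch size (names are illustrative, and the sign convention and copying into Caffe's `bottom[0].diff` are omitted):

```python
import numpy as np

def batch_dice_grad(pred, label, eps=1e-7):
    """Gradient of the batch-averaged Dice w.r.t. the foreground predictions."""
    batch_size = pred.shape[0]
    grad = np.zeros_like(pred, dtype=np.float64)
    for i in range(batch_size):
        p, g = pred[i], label[i]
        union = np.sum(p * p) + np.sum(g * g) + eps
        inter = np.sum(p * g)
        # d(dice_i)/dp = 2*g/union - 4*p*inter/union^2, then average over the batch
        grad[i] = (2.0 * g / union - 4.0 * p * inter / union ** 2) / float(batch_size)
    return grad
```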

gattia commented 7 years ago

Glad it helped!

sagarhukkire commented 7 years ago

Hi @gattia @RSly

I am facing an issue with the determinism of the loss values. For example, if I run 10 iterations I get 10 loss values; if I rerun the model I get a different 10 loss values. Is this normal, or do you know a workaround?

Thanks in advance sagar