juglab / n2v

This is the implementation of Noise2Void training.

why does it work so well (theoretical arguments)? #78

Closed codingS3b closed 3 years ago

codingS3b commented 4 years ago

After reading your paper and being impressed by the results, I was curious about how you implemented the blind-spot training in Keras, so thanks for sharing your code!

From looking at it (and from what you also describe in the paper), it seems that you create image patches, randomly select pixels (sufficiently far apart from one another) whose values you change, and train a standard U-Net with a mean-squared-error loss to predict the original values of the changed pixels, disregarding all pixels that were not masked.

I am puzzled about why this works so well. What exactly keeps the network from still learning the identity there? Say we select a pixel whose (noisy, observed) value is 42 and we change its value to 30 through the masking procedure. I would assume the network should still learn to output a value very close to 42 in order to minimize the loss. Why does it not do that, and instead even come up with a denoised value? Did I miss something in the code that takes care of this, or is it some more fundamental part of the idea that I did not get right? I would be glad if you could help me out on that.
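For concreteness, the masking procedure described above could be sketched roughly like this (an illustrative NumPy sketch of the idea, not the actual n2v code; the function names and parameters here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)

def n2v_mask(patch, n_pixels=16, radius=2):
    """Replace randomly chosen pixels with values drawn from their local
    neighbourhood (the blind-spot masking). Returns the masked patch and
    a boolean mask marking which pixels were replaced."""
    masked = patch.copy()
    h, w = patch.shape
    ys = rng.integers(0, h, n_pixels)
    xs = rng.integers(0, w, n_pixels)
    mask = np.zeros_like(patch, dtype=bool)
    for y, x in zip(ys, xs):
        # pick a random offset inside a (2*radius+1)^2 window, excluding (0, 0)
        while True:
            dy, dx = rng.integers(-radius, radius + 1, 2)
            if (dy, dx) != (0, 0):
                break
        ny, nx = np.clip(y + dy, 0, h - 1), np.clip(x + dx, 0, w - 1)
        masked[y, x] = patch[ny, nx]
        mask[y, x] = True
    return masked, mask

def masked_mse(pred, target, mask):
    # loss only on the masked (blind-spot) pixels; all others are ignored
    return np.mean((pred[mask] - target[mask]) ** 2)
```

During training the network would see `masked` as input and the original noisy `patch` as target, with the loss restricted to the masked pixels via `masked_mse`.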

psteinb commented 4 years ago

(@tibuch @alex-krull please correct me if I am wrong.)

I always understood the core idea of n2v as based on the mental model of additive stochastic noise on top of a non-stochastic signal in an image.

Leaving hyperparameter issues aside for a moment, I always thought that the blind-spot architecture trained with SGD forces the network to estimate the intensity of the blinded pixel from a neighborhood of i.i.d. intensities. At the heart of this lies the assumption that mini-batched SGD inherently tries to produce a representation of the mean distribution of values (here pixel intensities) in the training data. If you now return to the model from above, the mean values must originate from the signal, as the noise is i.i.d. around the signal intensities.

This concept is also described in the noise2noise paper.

alex-krull commented 4 years ago

Thanks Peter, I agree. I hope the following explanation can help:

It helps to imagine a situation where an infinite amount of training data is available. During training, the network will see the same noisy blind-spot input patch with different conflicting target values, depending on the instantiation of the noise. Since we are using the MSE loss, the network will learn to predict the expected value when it is presented with conflicting target values. Note that it cannot predict the original noisy value 42, because this value will be different for each instantiation of the noise, and there is no way to infer the value from the (masked) input. We assume that the noise is zero-centred, which means that the conflicting target values for that pixel are centred around the true signal at the pixel, i.e. the expected value of the noisy values is equal to the expected value of the clean signal at the pixel.
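This point can be checked with a tiny experiment (an illustrative sketch, not n2v code; all numbers are made up): a one-parameter "network" trained by gradient descent on the MSE against zero-centred noisy targets converges to their expected value, i.e. the clean signal, not to any individual noisy draw.

```python
import numpy as np

rng = np.random.default_rng(1)

signal = 25.0                                           # true clean intensity
targets = signal + rng.normal(0.0, 5.0, size=100_000)   # conflicting noisy targets

# One-parameter predictor: output a constant p for this (fixed) input.
# Gradient descent on the MSE loss mean((p - target)^2).
p = 0.0
lr = 0.5
for _ in range(200):
    grad = 2.0 * np.mean(p - targets)
    p -= lr * grad

print(p)  # close to 25, the expected value, not any single noisy target
```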

This is indeed the same general approach as in noise2noise. We can use a noisy target, because we know that the expected value of the noisy target is the same as the expected value of the true signal we are interested in. You can view noise2void as noise2noise, with the central pixel removed from the input. Removing the pixel allows us to use it as our target. This is just as good as taking a pixel from a second noise image, as it's done in noise2noise.

Note that the noise does not have to be identically distributed for each pixel; Poisson noise would be an example that is zero-centred but distributed differently in every pixel, depending on the underlying signal.
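A quick numerical check of this (an illustrative NumPy sketch, values made up): sampling Poisson-distributed observations at several signal levels shows that the distribution changes with the signal (the variance tracks the signal), while the expected value stays equal to the clean signal, which is all the argument above needs.

```python
import numpy as np

rng = np.random.default_rng(2)

# Poisson noise: the variance depends on the signal, so the noise is not
# identically distributed across pixels, but E[noisy] still equals the signal.
for s in (3.0, 40.0, 250.0):
    noisy = rng.poisson(s, size=200_000).astype(float)
    print(s, noisy.mean(), noisy.var())  # mean and variance both track s
```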

codingS3b commented 4 years ago

Thanks for the explanations and the discussion in yesterday's meeting! I am probably making this more complicated now by throwing formulas in, but it would be great to be able to rigorously derive the n2v behaviour. Here is what I understood: You define the output of the CNN for a given pixel $i$ as

$$\hat{s}_i = f\left(x_{\mathrm{RF}(i)}; \theta\right)$$

For training n2v, you use the noisy observation $x_i$ as the label and, as input, the modified receptive field of $x_i$, in which only the value at pixel $i$ is replaced by some neighborhood pixel value. I denote that as

$$\tilde{x}_{\mathrm{RF}(i)}$$

Then (in the infinite-sample case) the following objective is minimized:

$$\mathbb{E}\left[\left(f\left(\tilde{x}_{\mathrm{RF}(i)}; \theta\right) - x_i\right)^2\right]$$

(Here I am already not sure whether the expectation runs over all pixels or only over the ones from the receptive field.)

From my statistics classes I recall that the mean squared error between two random variables $X$ and $Y$ (where $Y$ is observed and $X$ is the variable we want to infer),

$$\mathbb{E}\left[\left(X - g(Y)\right)^2\right],$$

is minimized over predictors $g$ by the conditional expectation (itself a random variable),

$$g(Y) = \mathbb{E}\left[X \mid Y\right].$$
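The fact that the conditional expectation minimizes the MSE can be illustrated numerically (a sketch under an assumed Gaussian signal-plus-noise model, chosen because the conditional expectation then has a closed form; none of this is from the n2v code):

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy model: signal s ~ N(mu, sig_s^2), observation y = s + N(0, sig_n^2).
mu, sig_s, sig_n = 10.0, 2.0, 3.0
s = rng.normal(mu, sig_s, size=500_000)
y = s + rng.normal(0.0, sig_n, size=s.size)

# In this Gaussian case, E[s | y] is a convex combination of y and the prior mean.
w = sig_s**2 / (sig_s**2 + sig_n**2)
cond_mean = w * y + (1 - w) * mu

mse_identity = np.mean((s - y) ** 2)         # "copy the noisy value" predictor
mse_cond = np.mean((s - cond_mean) ** 2)     # conditional-expectation predictor

print(mse_identity, mse_cond)  # the conditional expectation gives the smaller MSE
```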

However, I struggle to apply that fact to the n2v objective to see that (assuming the network optimization is able to find parameters $\theta^*$ achieving the minimum)

$$f\left(\tilde{x}_{\mathrm{RF}(i)}; \theta^*\right) = \mathbb{E}\left[s_i \mid \tilde{x}_{\mathrm{RF}(i)}\right].$$

Once again, sorry for the complicated formulation of the question, but I hope I am just missing some theoretical argument that you could point me to.

alex-krull commented 4 years ago

Hi Sebastian, I am sorry for my late reply. I hope the explanation below helps. If not, please let me know where I lose you.

Let me go through your equations one by one and see where we are on the same page and where not:

tibuch commented 3 years ago

I will close this discussion for now. Please feel free to reopen it. Another excellent place for this discussion would be forum.image.sc/.