Closed rodschermitg closed 11 months ago
Hi Roger,
Thanks for your interest in my work! You're right that the commented-out section does implement a standard forward() method, and conceptually there is nothing wrong with doing things this way.
Assuming you have a single forward() method that both computes the U-Net features and performs latent space sampling, the code would look something like this:
for idx in range(N_samples):
    sample[idx] = model.forward()
In this case, each sampling iteration essentially re-computes the U-Net features, which is inefficient.
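To make the cost concrete, here is a minimal, dependency-free sketch of that pattern. The class and method names (ToyProbUNet, _compute_features) are hypothetical stand-ins for illustration, not the actual repository code; the counter simply tracks how often the expensive U-Net feature pass would run.

```python
import random

class ToyProbUNet:
    """Hypothetical stand-in for a probabilistic U-Net (illustration only)."""

    def __init__(self):
        self.feature_computations = 0  # counts the expensive U-Net passes

    def _compute_features(self, x):
        self.feature_computations += 1           # stands in for the costly U-Net encoder/decoder
        return [xi * 2.0 for xi in x]            # placeholder "features"

    def forward(self, x):
        features = self._compute_features(x)     # recomputed on EVERY call
        z = random.gauss(0.0, 1.0)               # latent-space sample
        return [f + z for f in features]

model = ToyProbUNet()
x = [1.0, 2.0, 3.0]
N_samples = 8
samples = [model.forward(x) for _ in range(N_samples)]
print(model.feature_computations)  # -> 8: features recomputed once per sample
```

With N_samples draws, the feature computation runs N_samples times even though its output never changes between draws.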
When the "standard" forward() method is broken into two parts, i.e. a feature-computation step and a separate latent-sampling step, the code looks as follows:
model.forward()
for idx in range(N_samples):
    sample[idx] = model.sample()
This avoids redundant re-computation of the U-Net features.
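The split version can be sketched the same way. Again, the names here (ToySplitProbUNet, _features) are hypothetical illustrations: forward() computes and caches the features once, and sample() reuses the cache for each latent draw.

```python
import random

class ToySplitProbUNet:
    """Hypothetical sketch of the two-part forward/sample design (illustration only)."""

    def __init__(self):
        self.feature_computations = 0  # counts the expensive U-Net passes
        self._features = None          # cache filled by forward()

    def forward(self, x):
        self.feature_computations += 1           # the costly pass happens once
        self._features = [xi * 2.0 for xi in x]  # placeholder "features", cached

    def sample(self):
        z = random.gauss(0.0, 1.0)               # fresh latent-space sample
        return [f + z for f in self._features]   # reuse the cached features

model = ToySplitProbUNet()
model.forward([1.0, 2.0, 3.0])
samples = [model.sample() for _ in range(8)]
print(model.feature_computations)  # -> 1: features computed once for all 8 samples
```

The per-sample cost drops to just the latent draw plus the (cheap) combination step, while the expensive feature pass is amortized over all samples.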
During training, the usual practice is to draw a single (latent space) sample, combine it with the U-Net features, and compute the loss. This works well if the KL divergence between the prior and posterior latent distributions can be computed analytically. However, if this is not the case (e.g. if the latent space distributions are mixtures of Gaussians), one needs to use a Monte Carlo estimate for the KL-divergence integral. This requires multiple samples from the latent space distributions for a good approximation. In that case, breaking up the forward() method into two parts makes things much more efficient.
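For intuition on the Monte Carlo estimate itself, here is a small self-contained sketch for univariate Gaussians (where the analytic KL is available for comparison). The function names are my own for illustration; the estimator KL(q||p) ≈ (1/N) Σ [log q(z_i) − log p(z_i)] with z_i ~ q applies unchanged to distributions without a closed-form KL, such as Gaussian mixtures.

```python
import math
import random

def gauss_logpdf(z, mu, sigma):
    """Log-density of a univariate Gaussian N(mu, sigma^2) at z."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (z - mu) ** 2 / (2 * sigma ** 2)

def kl_monte_carlo(mu_q, sig_q, mu_p, sig_p, n=200_000, seed=0):
    """MC estimate of KL(q || p) using samples z ~ q (the posterior)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu_q, sig_q)
        total += gauss_logpdf(z, mu_q, sig_q) - gauss_logpdf(z, mu_p, sig_p)
    return total / n

def kl_analytic(mu_q, sig_q, mu_p, sig_p):
    """Closed-form KL between two univariate Gaussians, for comparison."""
    return (math.log(sig_p / sig_q)
            + (sig_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sig_p ** 2) - 0.5)

mc = kl_monte_carlo(0.0, 1.0, 1.0, 1.0)
exact = kl_analytic(0.0, 1.0, 1.0, 1.0)  # = 0.5 for these parameters
```

The estimate converges to the analytic value as n grows, which is exactly why multiple latent samples, and hence the cheap sample() path, are needed when the KL has no closed form.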
Hope this makes things clear!
Cheers, Ishaan
That makes a lot of sense. Thank you for the clarifications!
Dear Ishaan,
Thank you for providing the source code for your interesting work!
I noticed that in the file src/probabilistic_unet/model.py there is a section of code that was commented out. As I understand it, the commented-out forward method would (along with the other commented-out helper methods) implement a conventional forward call that can be used for both training and testing.

May I ask why that section was commented out? Compared to the active methods, it seems like a more straightforward way to obtain the results and losses from a forward pass. After making some minor adjustments to the commented-out code, I was able to successfully integrate this forward implementation into a standard training pipeline. Are there any potential downsides to using this version?

Kind regards, Roger