ishaanb92 / GeneralizedProbabilisticUNet

PyTorch model for the Generalized Probabilistic U-Net. For more details see: https://www.melba-journal.org/papers/2023:005.html
MIT License

Commented-out code in model.py #3

Status: Closed (closed by rodschermitg 11 months ago)

rodschermitg commented 11 months ago

Dear Ishaan,

Thank you for providing the source code for your interesting work!

I noticed that in the file src/probabilistic_unet/model.py there is a section of code that has been commented out. As I understand it, the commented-out forward method would (along with the other commented-out helper methods) implement a conventional forward call that can be used for both training and testing.

May I ask why that section was commented out? Compared to the methods that are not commented out, it seems like a more straightforward way to obtain the results and losses from a forward pass. After making some minor adjustments to the commented-out code, I was able to integrate this forward implementation into a standard training pipeline without issues. Are there any potential downsides to using this version?

Kind regards, Roger

ishaanb92 commented 11 months ago

Hi Roger,

Thanks for your interest in my work! You're right that the commented-out section implements a standard forward() method, and conceptually there is nothing wrong with doing things this way.

Inference

Assuming you have a single forward() method that both computes the U-Net features and performs latent space sampling, the code would look something like this:

# Every iteration repeats the full forward pass, i.e. the U-Net feature
# computation as well as the latent-space sampling.
for idx in range(N_samples):
    sample[idx] = model.forward()

In this case, each sampling iteration essentially re-computes the U-Net features, which is inefficient.

When the "standard" forward() method is broken into 2 parts i.e.

The code looks as follows:

# Compute the U-Net features (and latent distributions) once...
model.forward()
# ...then draw samples repeatedly, which only touches the latent space.
for idx in range(N_samples):
    sample[idx] = model.sample()

This avoids redundant re-computation of the U-Net features.
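For illustration, here is a minimal sketch of what such a split might look like inside the model class. The class, method, and attribute names are illustrative, not the exact ones used in model.py:

import torch


class ProbUNetSketch(torch.nn.Module):
    """Illustrative split of a probabilistic U-Net forward pass into two parts."""

    def __init__(self, unet, prior_net, combine_net):
        super().__init__()
        self.unet = unet                # deterministic U-Net backbone
        self.prior_net = prior_net      # maps the image to a latent distribution
        self.combine_net = combine_net  # fuses a latent sample with the U-Net features

    def forward(self, image):
        # Expensive part, run once per image: compute and cache the U-Net
        # features and the prior latent distribution.
        self.features = self.unet(image)
        self.prior_dist = self.prior_net(image)

    def sample(self):
        # Cheap part, run once per requested segmentation sample: draw a
        # latent vector and combine it with the cached U-Net features.
        z = self.prior_dist.rsample()
        return self.combine_net(self.features, z)

With this interface, model.forward(image) is called once per image and model.sample() is called N_samples times, matching the loop above.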

Training

During training, the usual practice is to draw a single (latent space) sample, combine it with the U-Net features, and compute the loss. This works well if the KL divergence between the prior and posterior latent distributions can be computed analytically. However, if this is not the case (e.g. if the latent space distributions are mixtures of Gaussians), one needs a Monte Carlo estimate of the KL-divergence integral, which requires multiple samples from the latent space distributions for a good approximation. In that case, too, breaking the forward() method into two parts makes things much more efficient.
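To make the Monte Carlo case concrete, here is a rough sketch of such a KL estimate. The function name and the toy Gaussian distributions are only for illustration; in the actual model the posterior and prior would come from the posterior and prior networks:

import torch


def mc_kl_divergence(posterior_dist, prior_dist, n_samples=10):
    # Monte Carlo estimate of KL(posterior || prior) = E_q[log q(z) - log p(z)],
    # used when no closed-form KL is available (e.g. mixture-of-Gaussian latents).
    # rsample() keeps the draws differentiable; for distributions without a
    # reparameterised sampler, sample() plus a suitable gradient estimator is needed.
    z = posterior_dist.rsample((n_samples,))
    return (posterior_dist.log_prob(z) - prior_dist.log_prob(z)).mean()


# Toy check against the analytic KL for diagonal Gaussians:
q = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(6), torch.ones(6)), 1)
p = torch.distributions.Independent(
    torch.distributions.Normal(torch.full((6,), 0.5), torch.ones(6)), 1)
print(mc_kl_divergence(q, p, n_samples=1000))   # Monte Carlo estimate
print(torch.distributions.kl_divergence(q, p))  # exact value (0.75 here)

Because only the latent-space sampling has to be repeated for the Monte Carlo estimate, the split keeps the per-step cost dominated by a single U-Net pass rather than many.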

Hope this makes things clear!

Cheers, Ishaan

rodschermitg commented 11 months ago

That makes a lot of sense. Thank you for the clarifications!