YorkUCVIL / Wavelet-Flow

Wavelet Flow: Fast Training of High Resolution Normalizing Flows
MIT License

Per level loss values for 64x64 images #2

Closed: A-Vzer closed this issue 3 years ago

A-Vzer commented 3 years ago

I am trying to implement this model in PyTorch, but I am unsure whether my bits-per-dimension (bpd) values are correct. For example, level 0 starts at around 6.0, level 3 at around 10, and level 5 at around 2.82. Since training takes days, I want to be sure the values are roughly correct. The total likelihood is then the bpd values of every level added together, right? The values decrease very slowly with each epoch, so I am unsure whether I can ever reach the results in the paper.

Thanks a lot

A-Vzer commented 3 years ago

Hold on, I can just clone the repo and check it out myself.

JasonYuJjyu commented 3 years ago

I see that you closed the issue, but I had already written this out and it might be helpful for others.

Since you are not using the same code, I cannot be certain exactly what you are doing, but I can give some pointers. In general, the flows at the higher resolutions contribute more to the total log likelihood; in fact, almost all of the likelihood comes from the highest resolution. The total likelihood is computed as a sum over all the levels, but in my implementation there is a subtle difference.
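To make the "sum over levels" concrete: log-likelihoods (in nats) add across levels, but per-level bpd values do not, because each level's bpd is normalized by that level's own dimension count. Below is a minimal plain-Python sketch, assuming each level's log-likelihood for one image is already available in nats and that the image is Haar-decomposed all the way down to 1x1. The level layout and the helpers `level_dims`, `level_bits_per_dim` and `total_bits_per_dim` are hypothetical, not functions from the Wavelet-Flow repo, and the sketch ignores the low-pass rescaling correction described in the next paragraph.

```python
import math


def level_dims(image_size, channels):
    """Number of dimensions modelled at each level (hypothetical layout).

    Assumes a square image whose side is a power of two, decomposed down to 1x1:
    level 0 holds the final 1x1 low-pass (channels values) and every later level
    holds the three detail bands at its resolution (3 * channels values per pixel).
    """
    dims = [channels]
    side = 1
    while side < image_size:
        dims.append(3 * channels * side * side)
        side *= 2
    return dims


def level_bits_per_dim(log_likelihood, dims):
    """Convert one level's log-likelihood (nats) into bits per dimension."""
    return -log_likelihood / (dims * math.log(2.0))


def total_bits_per_dim(level_log_likelihoods, dims_per_level):
    """Sum log-likelihoods over levels, then normalize by the total dimension count."""
    total_ll = sum(level_log_likelihoods)
    total_dims = sum(dims_per_level)
    return -total_ll / (total_dims * math.log(2.0))


dims = level_dims(64, 3)
print(dims)       # [3, 9, 36, 144, 576, 2304, 9216]
print(sum(dims))  # 12288, i.e. 64 * 64 * 3
```

Under these assumptions the total bpd is the dimension-weighted average of the per-level bpd values rather than their sum, which is why the highest-resolution level, holding most of the dimensions, dominates the total.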

A consequence of the Haar transform is that the low-pass component is actually 2 times the average image. As you repeatedly apply the transform, the low-pass image grows in intensity as it shrinks in size. My implementation of the Haar transform compensates for this by dividing the low-pass component by 2 at each step. However, this means you need to adjust the log likelihood contribution at each level by h*w*c*log(0.5)*(highest_level - current_level).
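The factor of 2 can be checked directly. Below is a minimal sketch of one level of the 2D Haar transform on a single-channel image stored as nested Python lists. It assumes the orthonormal variant, where each output is a 2x2 block sum or difference divided by 2; the band names `d1`, `d2`, `d3` and the helpers `haar_level` and `lowpass_log_correction` are hypothetical, not taken from the repo.

```python
import math


def haar_level(img):
    """One level of the orthonormal 2D Haar transform (single channel, even sides).

    Every output is a +/- combination of a 2x2 block divided by 2, so the
    transform is orthonormal. The low-pass band is therefore 2 * block mean.
    """
    low, d1, d2, d3 = [], [], [], []
    for i in range(0, len(img), 2):
        rows = ([], [], [], [])
        for j in range(0, len(img[0]), 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            rows[0].append((a + b + c + d) / 2)  # low-pass = 2 * mean of block
            rows[1].append((a - b + c - d) / 2)  # detail band 1
            rows[2].append((a + b - c - d) / 2)  # detail band 2
            rows[3].append((a - b - c + d) / 2)  # detail band 3
        for band, row in zip((low, d1, d2, d3), rows):
            band.append(row)
    return low, (d1, d2, d3)


def lowpass_log_correction(h, w, c, highest_level, current_level):
    """Log-likelihood adjustment (nats) for halving the low-pass at every level."""
    return h * w * c * math.log(0.5) * (highest_level - current_level)


# A constant 4x4 image of ones: the unnormalized low-pass doubles at every level.
img = [[1.0] * 4 for _ in range(4)]
low1, _ = haar_level(img)   # [[2.0, 2.0], [2.0, 2.0]]
low2, _ = haar_level(low1)  # [[4.0]]
```

Dividing `low1` by 2 before the next level keeps it at the image's average intensity. Each halving of a dimension contributes log(0.5) to the log-determinant of the change of variables, which is where the h*w*c*log(0.5)*(highest_level - current_level) term comes from.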

To simplify things, you could skip dividing each low-pass component by 2, but then you need to be careful about the magnitude of the image intensities. A lazy way to handle this is to slap an Actnorm on the input of the flow and on its conditioning, and set it to a sensible magnitude with data-dependent initialization (DDI). Just remember to divide the image values back down when you visualize each scale.
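For readers unfamiliar with the trick: ActNorm (from Glow) is a per-channel affine layer whose bias and scale are initialized from the first batch so that its outputs have roughly zero mean and unit variance per channel. Below is a minimal plain-Python sketch of that idea, not the Wavelet-Flow code. It assumes inputs arrive as a list of samples, each a list of per-channel values; in PyTorch the bias and log-scale would be `nn.Parameter`s on tensors. `ActNorm` here and `to_display_scale` are hypothetical names.

```python
import math


class ActNorm:
    """Per-channel affine layer with data-dependent initialization (Glow-style sketch)."""

    def __init__(self, num_channels, eps=1e-6):
        self.bias = [0.0] * num_channels
        self.log_scale = [0.0] * num_channels
        self.eps = eps
        self.initialized = False

    def _data_dependent_init(self, batch):
        # batch: list of samples, each a list with one value per channel.
        n = len(batch)
        for ch in range(len(self.bias)):
            values = [sample[ch] for sample in batch]
            mean = sum(values) / n
            std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
            self.bias[ch] = -mean                           # centre each channel
            self.log_scale[ch] = -math.log(std + self.eps)  # scale to ~unit std
        self.initialized = True

    def forward(self, batch):
        """Return normalized outputs and the log-determinant per sample vector."""
        if not self.initialized:
            self._data_dependent_init(batch)
        out = [
            [(x + b) * math.exp(s) for x, b, s in zip(sample, self.bias, self.log_scale)]
            for sample in batch
        ]
        log_det = sum(self.log_scale)  # diagonal Jacobian: sum of log-scales
        return out, log_det


def to_display_scale(lowpass_value, num_transforms):
    """Undo the factor of 2 per unnormalized Haar level before visualizing."""
    return lowpass_value / (2 ** num_transforms)
```

After the first call, the layer's parameters are fixed and later batches are simply shifted and scaled with the same values; the log-determinant must still be added to the level's log-likelihood.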

A-Vzer commented 3 years ago

Thank you so much, I am glad you responded anyway. So if I understand correctly, the correction for the Haar low-pass (LP) component is 1*(highest_level - current_level) in bits per dimension? (i.e., dividing by -(h*w*c*log(2.0)))

JasonYuJjyu commented 3 years ago

Not quite. You just subtract h*w*c*log(2.0)*(highest_level - current_level) from the current log likelihood, where h, w, c are the dimensions at the current level. Think of it as dividing the values of a subset of the dimensions by 2, (highest_level - current_level) times.
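A minimal sketch of that adjustment, assuming each level's log-likelihood (in nats) was computed on the halved coefficients and that `(h, w, c)` are the dimensions at that level. `corrected_log_likelihood` and `total_corrected_log_likelihood` are hypothetical helpers; the linked Multi_scale_flow.py remains the reference for what the repo actually does.

```python
import math


def corrected_log_likelihood(log_likelihood, h, w, c, highest_level, current_level):
    """Subtract h*w*c*log(2) once for every time this level's values were halved."""
    num_halvings = highest_level - current_level
    return log_likelihood - h * w * c * math.log(2.0) * num_halvings


def total_corrected_log_likelihood(levels, highest_level):
    """levels: list of (current_level, log_likelihood, (h, w, c)) tuples."""
    return sum(
        corrected_log_likelihood(ll, h, w, c, highest_level, lvl)
        for lvl, ll, (h, w, c) in levels
    )
```

Subtracting h*w*c*log(2.0) per halving is the same as adding h*w*c*log(0.5) per halving, so this matches the expression given earlier in the thread, and the highest level needs no correction.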

See this for my implementation: https://github.com/YorkUCVIL/Wavelet-Flow/blob/0297dcb6f2e556d6f1a59485d1aca17efaa12600/src/models/shared/Multi_scale_flow.py#L76-L77