rosinality / glow-pytorch

PyTorch implementation of Glow
MIT License
519 stars 96 forks

Is something wrong with the loss? #13

Open zhengkw18 opened 4 years ago

zhengkw18 commented 4 years ago

I'm training on CelebA 64x64 5-bit (4 GPUs). After about 2 hours the loss is as low as 1.1, yet the sampled images have low visual quality. If I'm not mistaken, the final bpd of Glow on 5-bit CelebA in the paper is 1.02. How can the loss be so small when the model is clearly not well trained?

rosinality commented 4 years ago

The bpd in the paper is computed on 256px test-set images, so it won't match training-set statistics.
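For intuition when comparing numbers like this, bits-per-dimension is just the per-image negative log-likelihood in nats divided by log(2) times the number of dimensions. A minimal sketch (the NLL value below is a hypothetical number chosen for illustration, not from the paper):

```python
from math import log

def bits_per_dim(nll_nats: float, image_size: int, channels: int = 3) -> float:
    """Convert a per-image negative log-likelihood (in nats) to bits/dim."""
    n_pixel = image_size * image_size * channels
    return nll_nats / (log(2) * n_pixel)

# Hypothetical: an NLL of ~8710 nats on a 64x64x3 image
# works out to roughly 1.02 bits/dim.
print(round(bits_per_dim(8710.0, 64), 2))  # → 1.02
```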

matejgrcic commented 4 years ago

In my opinion, the -log(n_bins) * n_pixel term should be added per image, not per batch. Consequently, the loss should be calculated as follows:

from math import log

def calc_loss(log_p, logdet, image_size, n_bins):
    # Number of dimensions per image (H x W x RGB channels)
    n_pixel = image_size * image_size * 3

    # Dequantization constant: -log(n_bins) per dimension, added once per image
    loss = -log(n_bins) * n_pixel
    loss = loss + logdet.mean() + log_p.mean()

    # Convert the negative log-likelihood from nats to bits per dimension
    return (
        -loss / (log(2) * n_pixel),
        log_p.mean() / (log(2) * n_pixel),
        logdet.mean() / (log(2) * n_pixel),
    )