pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Question on loss calculation in CEVAE module. #3236

Closed sujeongsong closed 1 year ago

sujeongsong commented 1 year ago

https://github.com/pyro-ppl/pyro/blob/7e3d62ecf5ebf2fd3798d93e80cb554e7d38f61e/pyro/contrib/cevae/__init__.py#L592

I am totally new to Pyro, and I have a question about the CEVAE module. During the loss calculation using svi.step() in the line below,

loss = svi.step(x, t, y, size=len(dataset)) / len(dataset)

(x, t, y) is a minibatch of the pre-defined batch_size.
Why is size given as len(dataset), and why is the value divided by len(dataset)? len(dataset) is the total number of samples in the training set; shouldn't the size and the denominator be the size of a minibatch instead of len(dataset)?

Please correct me if I am wrong. Thank you in advance!

fehiepsi commented 1 year ago

Here size is the full data size. By default, Pyro scales the log prob of a batch by full_size / batch_size to get an MC estimate of the log prob of the full data. We divide by full_size to get an estimate of the average log prob over the full data. This way, the loss scale is agnostic to the full data size.

sujeongsong commented 1 year ago

> Here size is the full data size. By default, Pyro scales the log prob of a batch by full_size / batch_size to get an MC estimate of the log prob of the full data. We divide by full_size to get an estimate of the average log prob over the full data. This way, the loss scale is agnostic to the full data size.

Oh, I see. Then does what you said follow exactly the same mechanism as the code below?

from pyro.util import torch_isnan  # Pyro's NaN-check helper

for epoch in range(num_epochs):
    loss_per_epoch = 0.0
    for x, t, y in dataloader:
        # svi.step returns the (unnormalized) loss for this minibatch
        loss = svi.step(x, t, y)
        assert not torch_isnan(loss)
        loss_per_epoch += loss

    # normalize the accumulated epoch loss by the full dataset size
    normalizer_train = len(dataset)
    loss_per_epoch_mean = loss_per_epoch / normalizer_train
fehiepsi commented 1 year ago

I think so.

sujeongsong commented 1 year ago

Thank you very much, it helps a lot!