roggirg / AutoBots


Confusion about training objective #21

Closed ChengkaiYang closed 5 months ago

ChengkaiYang commented 7 months ago

Thanks for your excellent work! In the following objective, the approximating posterior of the latent variable is p_old(z | y, x_1:t). The paper says it can be computed exactly because the latent variable z is discrete: p_old(z | y, x_1:t) = p_old(z | x_1:t) * p_old(y | z, x_1:t) / p_old(y, x_1:t). When computing the prior p_old(z | x_1:t), the original code writes `priors = modes_pred.detach().cpu().numpy()`, but when computing p_old(y | z, x_1:t), I wonder why no detach() operator is added?

[image: training objective from the paper]

Thanks for your patience!

ChengkaiYang commented 5 months ago

Sorry, I didn't notice that there is already a `with torch.no_grad()` block. This question is resolved!
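
For anyone else who hits this: here is a minimal sketch (hypothetical variable names, not the repository code) of how the posterior over a discrete latent z can be computed via Bayes' rule inside `torch.no_grad()`, which is why no explicit `detach()` is needed on the likelihood term:

```python
import torch

# Sketch only: `mode_probs` is assumed to hold the prior p_old(z | x_1:t)
# with shape [B, K], and `log_lik` holds log p_old(y | z=k, x_1:t) per mode
# with shape [B, K]. All names here are hypothetical.
def discrete_posterior(mode_probs: torch.Tensor, log_lik: torch.Tensor) -> torch.Tensor:
    # Everything runs inside torch.no_grad(), so no gradients flow through
    # the posterior and an explicit .detach() is unnecessary.
    with torch.no_grad():
        log_prior = torch.log(mode_probs + 1e-12)   # log p_old(z | x_1:t)
        log_joint = log_prior + log_lik             # log p_old(z, y | x_1:t)
        post = torch.softmax(log_joint, dim=-1)     # normalize over the K modes
    return post                                     # p_old(z | y, x_1:t), shape [B, K]
```

Note that dividing by p_old(y | x_1:t) is exactly the normalization over modes that the softmax of the joint log-probabilities performs.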

liuyueChang commented 4 months ago

There is a mistake here: p_old(z | y, x_1:t) = p_old(z | x_1:t) * p_old(y | z, x_1:t) / p_old(y | x_1:t)

I have a question about the training objective.

The loss calculation in the paper includes log p(Z | X) and log p(Y | Z, X), but in the code

AutoBots/utils/train_helpers.py, Line 72 in 3a61ad9:

`nll_k = nll_pytorch_dist(pred[kk].transpose(0, 1), data, rtn_loss=True) * post_pr[:, kk]`

only p(Y | Z, X) participates in the loss calculation.

Please give me some tips, thank you!

ChengkaiYang commented 3 months ago


log p(Z | X) participates in the loss calculation as part of the KL loss, KL(p_old(Z | Y, X_1:t) || p(Z | X)), rather than the reconstruction loss. So the CVAE loss has two parts: minimizing the reconstruction loss (the negative log-likelihood) and minimizing the KL loss. Hope this helps.
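
A minimal sketch of the two-part objective (hypothetical names, not the repository code): the posterior-weighted reconstruction NLL is where log p(Y | Z, X) appears, and the KL term is where log p(Z | X) enters through the prior.

```python
import torch

# Sketch only, assuming:
#   nll_per_mode: [B, K] negative log-likelihood -log p(y | z=k, x) per mode
#   post_pr:      [B, K] posterior p_old(z | y, x), computed without gradients
#   prior_pr:     [B, K] predicted prior p(z | x)
def cvae_loss(nll_per_mode: torch.Tensor,
              post_pr: torch.Tensor,
              prior_pr: torch.Tensor,
              kl_weight: float = 1.0) -> torch.Tensor:
    # Reconstruction: E_{z ~ posterior}[ -log p(y | z, x) ]
    recon = (nll_per_mode * post_pr).sum(dim=-1).mean()
    # KL(p_old(z | y, x) || p(z | x)) for a discrete z; this is where
    # log p(z | x) participates in the loss.
    kl = (post_pr * (torch.log(post_pr + 1e-12) - torch.log(prior_pr + 1e-12))).sum(dim=-1).mean()
    return recon + kl_weight * kl
```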