Thank you very much for your great work! However, I am a little bit confused about this line, which returns NaN and makes the loss_multi loss NaN. Thank you very much for your reply!
Hi, thanks for your question! I don't quite understand it. In what situation does NaN appear in this line?
If it helps, you can find the code for sliced_logsumexp here. What the mentioned line does: we need to get probabilities from logits, so we apply softmax to each categorical feature. In our case we work with logarithms, so we get log_EV = unnormed_logprobs - sliced_logsumexp(unnormed_logprobs, self.offsets).
The line is based on this line.
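To illustrate what that subtraction computes (a minimal sketch, not the repo's code; the shapes and slice boundaries are made up): subtracting each feature's logsumexp from its logits is exactly a log-softmax applied to each categorical slice independently.

```python
import torch
import torch.nn.functional as F

# Two categorical features with 2 and 3 categories -> columns [a1 a2 | b1 b2 b3].
unnormed_logprobs = torch.randn(4, 5)

# Subtracting the per-slice logsumexp is a per-slice log-softmax:
log_EV = torch.cat(
    [F.log_softmax(unnormed_logprobs[:, 0:2], dim=-1),
     F.log_softmax(unnormed_logprobs[:, 2:5], dim=-1)],
    dim=-1,
)

# Each feature's probabilities now sum to 1.
print(log_EV[:, 0:2].exp().sum(-1))  # ~1.0
print(log_EV[:, 2:5].exp().sum(-1))  # ~1.0
```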
Thank you very much for your swift reply:D
I use "python scripts/pipeline.py --config exp/adult/ddpm_cb_best/config.toml --train --sample" to run the experiments, but I get "Step 1000/30000 MLoss: nan GLoss: nan Sum: nan ...".
After some debugging, I found that the NaN is caused by this line. More specifically, you use -inf to pad unnormed_logprobs, which causes the first column of "lse" to be NaN after the "torch.logcumsumexp" operation in this line.
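The behavior reduces to something like this (a paraphrase of the padding step, not the exact repo code):

```python
import torch
import torch.nn.functional as F

# Pad the logits with -inf on the left, as sliced_logsumexp does,
# then take the running logsumexp along the last dimension.
x = torch.randn(1, 5)
padded = F.pad(x, [1, 0], value=-float('inf'))
lse = torch.logcumsumexp(padded, dim=-1)

# On my torch version the first column comes out NaN;
# the padded -inf should instead yield -inf there.
print(lse[:, 0])
```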
Hmm, I cannot reproduce this. I suggest trying the other datasets. Try diabetes first, since it has no categorical features (just to check: if the problem remains there, then sliced_logsumexp is definitely not the reason). Then try churn2 to check whether it is an "adult-specific" problem; example commands are sketched below. And, of course, make sure you installed the repo and downloaded the data correctly :)
Note that we use torch==1.10.1+cu111.
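For example (the config paths here are assumptions mirroring your adult command; the actual directory names in the repo may differ):

```
python scripts/pipeline.py --config exp/diabetes/ddpm_cb_best/config.toml --train --sample
python scripts/pipeline.py --config exp/churn2/ddpm_cb_best/config.toml --train --sample
```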
As for -inf in sliced_logsumexp: you are right, but it should not affect the correctness of this function. Also, you may print unnormed_logprobs to check whether it is already NaN.
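For example (a hypothetical helper, not something from the repo):

```python
import torch

def check_nan(name: str, t: torch.Tensor) -> None:
    # Report NaNs before they propagate into the loss.
    if torch.isnan(t).any():
        print(f"{name} contains NaN")

# Call it right before the sliced_logsumexp call, e.g.:
# check_nan("unnormed_logprobs", unnormed_logprobs)
```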
I really appreciate your help!
I solved this issue by upgrading PyTorch from 1.7.1 to 1.10. BTW, I have a small question: why do you use sliced_logsumexp rather than the torch.logsumexp used in the multinomial diffusion model?
Thank you again:D
About sliced_logsumexp:
We need to get probabilities from logits. The problem is that we have multiple categorical features (let's say two), so we have to apply softmax to each of them independently. E.g., the vector [a1, a2 | b1, b2, b3] means that we have two categorical features with 2 and 3 categories, and we want to apply softmax to [a1, a2] and [b1, b2, b3] separately. We could just iterate and use torch.logsumexp, but that was pretty slow in our experiments, so we implemented sliced_logsumexp without loops; see the sketch below.
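A minimal sketch of the idea (illustrative code, not the repo's exact implementation): take one running logcumsumexp over the padded logits, difference it in log space at the slice boundaries, and broadcast each slice's logsumexp back over its own columns.

```python
import torch
import torch.nn.functional as F

def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Numerically stable log(exp(a) - exp(b)), assuming a >= b.
    return a + torch.log1p(-torch.exp(b - a))

def sliced_logsumexp(x: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Running logsumexp over columns; the prepended -inf acts as a
    # "zero" so cumulative sums can be differenced at any boundary.
    lse = torch.logcumsumexp(F.pad(x, [1, 0], value=-float('inf')), dim=-1)
    # logsumexp of each slice = cumsum[end] "minus" cumsum[start], in log space.
    slice_lse = log_sub_exp(lse[:, offsets[1:]], lse[:, offsets[:-1]])
    # Repeat each slice's value across that slice's columns.
    return torch.repeat_interleave(slice_lse, offsets[1:] - offsets[:-1], dim=-1)

# Two features with 2 and 3 categories.
x = torch.randn(4, 5)
offsets = torch.tensor([0, 2, 5])

# The loop-based baseline it replaces: one torch.logsumexp per slice.
loop = torch.cat(
    [torch.logsumexp(x[:, s:e], dim=-1, keepdim=True).expand(-1, e - s)
     for s, e in zip(offsets[:-1].tolist(), offsets[1:].tolist())],
    dim=-1,
)
print(torch.allclose(sliced_logsumexp(x, offsets), loop))  # True
```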
Very clear! Thank you very much for your explanation:)