yandex-research / tab-ddpm

[ICML 2023] The official implementation of the paper "TabDDPM: Modelling Tabular Data with Diffusion Models"
https://arxiv.org/abs/2209.15421
MIT License

NaN when training tabular diffusion model #2

Closed YiDongOuYang closed 2 years ago

YiDongOuYang commented 2 years ago

Thank you very much for your great work! However, I am a little confused about this line, which returns NaN and makes the loss_multi loss NaN as well. Thank you very much for your reply!

rotot0 commented 2 years ago

Hi, thanks for your question! I don't quite understand it. In what situation does NaN appear in this line?

If it helps, you can find the code for sliced_logsumexp here. The line in question turns logits into probabilities by applying a softmax to each categorical feature separately. Since we work in log space, this becomes log_EV = unnormed_logprobs - sliced_logsumexp(unnormed_logprobs, self.offsets).

The line is based on this line.
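
To make the formula above concrete, here is a minimal sketch of a per-feature log-softmax. It is written in plain Python rather than PyTorch and is not the repository's code. The function names and the layout of `offsets` as cumulative block boundaries are assumptions made for illustration.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a non-empty list of finite floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_softmax_per_feature(logits, offsets):
    """Normalize each categorical feature's block of logits independently.

    offsets holds block boundaries (assumed layout): [0, 2, 5] means
    feature A uses logits[0:2] and feature B uses logits[2:5].
    """
    out = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        block = logits[start:end]
        lse = logsumexp(block)
        # log-softmax: subtract the block's own logsumexp from each logit
        out.extend(x - lse for x in block)
    return out

logits = [0.5, -1.0, 2.0, 0.0, 1.0]   # [a1, a2 | b1, b2, b3]
offsets = [0, 2, 5]
log_probs = log_softmax_per_feature(logits, offsets)
probs = [math.exp(v) for v in log_probs]
# Each feature's probabilities sum to 1 on their own.
print(sum(probs[0:2]), sum(probs[2:5]))
```

Subtracting one logsumexp over the whole vector instead would make all five probabilities sum to 1 together, which is not what is wanted for two separate categorical features.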

YiDongOuYang commented 2 years ago

Thank you very much for your swift reply:D

I ran "python scripts/pipeline.py --config exp/adult/ddpm_cb_best/config.toml --train --sample" to launch the experiment, but I got "Step 1000/30000 MLoss: nan GLoss: nan Sum: nan ...".

After some debugging, I found that the NaN is caused by this line. More specifically, you pad unnormed_logprobs with -inf, which makes the first column of "lse" NaN after the "torch.logcumsumexp" operation in this line.
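
As background on how a `-inf` entry can become NaN at all: in IEEE 754 arithmetic, `(-inf) - (-inf)` is NaN, so a max-shifted logsumexp without a special case for an all-`-inf` input returns NaN instead of `-inf`. The plain-Python sketch below only illustrates that mechanism. It makes no claim about what any particular PyTorch version does internally, and both function names are hypothetical.

```python
import math

NEG_INF = float("-inf")

def naive_logsumexp(xs):
    """Max-shifted logsumexp with no special case for an all -inf input."""
    m = max(xs)
    # If every entry is -inf, x - m is (-inf) - (-inf), which is NaN.
    return m + math.log(sum(math.exp(x - m) for x in xs))

def guarded_logsumexp(xs):
    """Same computation, but log(0) is returned as -inf instead of NaN."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

print(naive_logsumexp([NEG_INF]))         # NaN: the shift itself is undefined
print(guarded_logsumexp([NEG_INF]))       # -inf: log of an empty sum
print(guarded_logsumexp([NEG_INF, 1.0]))  # the pad contributes exp(-inf) = 0
```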

rotot0 commented 2 years ago

Hmm, I cannot reproduce this. I suggest trying other datasets. Try diabetes first, since it has no categorical features: if the problem remains there, then logsumexp is definitely not the cause. Then try churn2 to check whether it is an "adult-specific" problem.

And, of course, make sure you installed the repo and downloaded the data correctly :) Note that we use torch==1.10.1+cu111.

As for -inf in sliced_logsumexp, you are right, but it should not affect the correctness of this function. Also, you may print unnormed_probs to check whether it is NaN.
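
When checking intermediate values like this, it can help to report which columns contain NaN, so that a NaN already present in the inputs can be told apart from one introduced by the logsumexp step. The helper below is a hypothetical plain-Python sketch over lists of rows, and the sample values are made up for illustration. On tensors, `torch.isnan(t).any(dim=0)` gives the same per-column information.

```python
import math

def nan_columns(rows):
    """Return the sorted column indices that contain at least one NaN."""
    bad = set()
    for row in rows:
        for j, v in enumerate(row):
            if math.isnan(v):
                bad.add(j)
    return sorted(bad)

# Made-up values: clean inputs, but a NaN in the first column of the result.
inputs = [[0.1, 0.2], [0.3, 0.4]]
lse = [[float("nan"), 0.5], [float("nan"), 0.7]]

print(nan_columns(inputs))  # no NaN columns in the inputs
print(nan_columns(lse))     # NaN only in column 0
```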

YiDongOuYang commented 2 years ago

I really appreciate your help!

I solved this issue by upgrading PyTorch from 1.7.1 to 1.10. BTW, I have a small question: why do we use sliced_logsumexp rather than torch.logsumexp, as used in the Multinomial diffusion model?

Thank you again:D

rotot0 commented 2 years ago

About sliced_logsumexp:

So, we need to get probabilities from logits. The problem is that we have multiple categorical features (let's say two), so we have to apply softmax to each one independently. For example, the vector [a1, a2, | b1, b2, b3] means that we have two categorical features with 2 and 3 categories, and we want to apply softmax to [a1, a2] and to [b1, b2, b3] separately. We could just iterate over the features and call torch.logsumexp on each, but that was pretty slow in our experiments, so we implemented sliced_logsumexp without loops.
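
One way to avoid a per-feature pass is a prefix (cumulative) logsumexp followed by differences at the feature boundaries. The thread mentions the `-inf` pad and `torch.logcumsumexp`. The subtraction step and all names below are assumptions, and this plain-Python sketch is not the repository's implementation. It compares the prefix approach against a straightforward per-feature loop.

```python
import math

NEG_INF = float("-inf")

def logaddexp(a, b):
    """Stable log(exp(a) + exp(b)); -inf is treated as log(0)."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def logsubexp(a, b):
    """Stable log(exp(a) - exp(b)) for a > b; b may be -inf."""
    if b == NEG_INF:
        return a
    return a + math.log1p(-math.exp(b - a))

def sliced_logsumexp_loop(x, offsets):
    """Reference: one logsumexp per feature, repeated across that feature's width."""
    out = []
    for s, e in zip(offsets[:-1], offsets[1:]):
        block = x[s:e]
        m = max(block)
        lse = m + math.log(sum(math.exp(v - m) for v in block))
        out.extend([lse] * (e - s))
    return out

def sliced_logsumexp_prefix(x, offsets):
    """Prefix variant: one cumulative pass, then a difference per feature."""
    # prefix[i] = log(sum(exp(x[:i]))); prefix[0] is the -inf pad (empty sum).
    prefix = [NEG_INF]
    for v in x:
        prefix.append(logaddexp(prefix[-1], v))
    out = []
    for s, e in zip(offsets[:-1], offsets[1:]):
        lse = logsubexp(prefix[e], prefix[s])
        out.extend([lse] * (e - s))
    return out

x = [0.5, -1.0, 2.0, 0.0, 1.0]   # [a1, a2 | b1, b2, b3]
offsets = [0, 2, 5]
print(sliced_logsumexp_loop(x, offsets))
print(sliced_logsumexp_prefix(x, offsets))
```

In a tensor library, the cumulative pass and the lookups at the offsets are each a single batched operation, which is where the speedup over a Python-level loop would come from. The trade-off is that subtracting prefix values can lose precision when a later feature's logits are much smaller than the earlier ones.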

YiDongOuYang commented 2 years ago

Very clear! Thank you very much for your explanation:)