vanderschaarlab / synthcity

A library for generating and evaluating synthetic tabular data for privacy, fairness and data augmentation.
https://www.vanderschaar-lab.com/
Apache License 2.0
417 stars, 55 forks

TabDDPM generating NaNs for large datasets #267

Open HLasse opened 5 months ago

HLasse commented 5 months ago

Question

When I fit TabDDPM for more than a single iteration, the model generates NaNs, which leads to a ValueError during sampling here: https://github.com/vanderschaarlab/synthcity/blob/41e6e5acfd886dd4ebc0528039e9395a2a93b380/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py#L954-L955

Further Information

I'm fitting TabDDPM to a dataset of ~150k rows and ~1000 columns. All columns are numeric, contain no NaNs, and are scaled to z-scores. When training for more than a single iteration, I get the above error when calling model.generate(count=count, cond=cond). Any ideas about what might be happening? It seems to happen no matter what the other parameters are set to.

After subsampling to 15k rows, I was able to train for 1000 iterations, but only up to 500 timesteps. Might be some issues with larger datasets?

robsdavis commented 4 months ago

Hi @HLasse,

I have not seen this issue before. Have you tried experimenting with the batch size to see whether that gets around the issue?

One other thing to check would be numerical instability. There are a few divisions and logarithms in the code. Do any of the denominators or log arguments approach zero for your dataset?
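As a rough illustration of the check suggested here, the sketch below clamps log arguments and denominators away from zero and lists values that come dangerously close to it. It is not taken from synthcity: the helper names (safe_log, safe_div, find_risky) and the EPS floor are assumptions made for this example, and it uses plain Python floats rather than tensors.

```python
import math

EPS = 1e-30  # hypothetical floor; choose one that suits the dtype in use


def safe_log(x: float, eps: float = EPS) -> float:
    """Logarithm with the argument clamped away from zero."""
    return math.log(max(x, eps))


def safe_div(num: float, den: float, eps: float = EPS) -> float:
    """Division with the denominator's magnitude clamped away from zero."""
    if abs(den) < eps:
        den = eps if den >= 0 else -eps
    return num / den


def find_risky(values, threshold: float = 1e-12):
    """Return (index, value) pairs whose magnitude falls below the threshold."""
    return [(i, v) for i, v in enumerate(values) if abs(v) < threshold]


# Example: inspect a handful of would-be denominators or log arguments.
candidates = [1.0, 1e-15, 0.0, -2.0]
risky = find_risky(candidates)
```

Clamping changes the result only where the unclamped value would have been extreme anyway, but it does hide the underlying problem, so logging the output of something like find_risky first is usually more informative than clamping blindly.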

muellermarkus commented 3 months ago

I actually encountered the same problem on larger datasets. For me, it is not just that NaNs are sampled but the training loss also becomes NaN after a couple of iterations.
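One way to pin down when the loss first turns NaN is to record it per iteration and look for the first non-finite value. The sketch below assumes the losses are available as plain floats (with PyTorch, that would typically mean calling .item() on the loss tensor). The function name and the example history are made up for illustration.

```python
import math


def first_non_finite(losses):
    """Return the index of the first NaN or infinite loss, or None if all are finite."""
    for step, loss in enumerate(losses):
        if not math.isfinite(loss):
            return step
    return None


# Hypothetical loss history: finite for two iterations, then NaN.
history = [0.93, 0.71, float("nan"), float("nan")]
bad_step = first_non_finite(history)
```

Stopping at bad_step and inspecting the inputs and intermediate values of that iteration narrows the search considerably. PyTorch's torch.autograd.set_detect_anomaly(True) can additionally point at the backward operation that produced the NaN, at the cost of slower training.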

I agree with @robsdavis that this is related to numerical instability. In my case, I traced the error to:

https://github.com/vanderschaarlab/synthcity/blob/943fa280687d236d783e53f40302838f5924f422/src/synthcity/plugins/core/models/tabular_ddpm/utils.py#L151-L154

I managed to stabilize the function and can create a related pull request if you want. However, I cannot guarantee that this gives the same results on the datasets for which the non-adjusted variant works without issues. I only tried this for a couple of datasets and the results were close enough.
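The thread does not show the stabilized function, so the following is only a generic illustration of this kind of fix, not the actual change. It contrasts a naive log(exp(a) + exp(b)) with the standard max-shifted form and measures how far the two disagree on well-behaved inputs, which is the sort of "close enough" comparison described above. All function names and test values are assumptions made for this sketch.

```python
import math


def naive_log_add_exp(a: float, b: float) -> float:
    """Direct formula; exp() underflows to 0 or overflows for extreme inputs."""
    return math.log(math.exp(a) + math.exp(b))


def stable_log_add_exp(a: float, b: float) -> float:
    """Shift by the maximum so the largest exponent is exp(0) = 1."""
    m = max(a, b)
    if m == -math.inf:  # both inputs are -inf: log(0 + 0)
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def max_abs_diff(pairs):
    """Largest absolute disagreement between the two variants."""
    return max(abs(naive_log_add_exp(a, b) - stable_log_add_exp(a, b)) for a, b in pairs)


# On moderate inputs both variants should agree up to rounding error.
well_behaved = [(-1.0, -2.0), (0.5, -3.0), (-10.0, -10.5)]
diff = max_abs_diff(well_behaved)

# On extreme inputs only the stable variant returns a finite value.
extreme = stable_log_add_exp(-800.0, -800.0)
```

In plain Python the naive variant raises a ValueError for (-800.0, -800.0) because exp(-800) underflows to 0.0. With tensors, the same underflow would usually yield -inf and then propagate as NaN. PyTorch also provides torch.logaddexp, which avoids hand-rolling this particular operation.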