yandex-research / tab-ddpm

[ICML 2023] The official implementation of the paper "TabDDPM: Modelling Tabular Data with Diffusion Models"
https://arxiv.org/abs/2209.15421
MIT License

Can't get the generated categorical values according to the README. #18

koseoyoung opened this issue 1 year ago (status: Open)

Hi, I'm trying to generate synthetic data by following the README. Training works fine and I can find the generated numerical values, but I can't get the generated categorical values. Specifically, exp/churn2/ddpm_mlp_best/X_cat_train.npy is empty, while exp/churn2/ddpm_mlp_best/X_num_train.npy contains values as expected. Is this a bug, or am I missing something? Thanks!
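For reference, this is a minimal sketch of how the generated arrays can be inspected (the paths match the config used in the command below; allow_pickle is only there because the broken file ends up holding Python objects):

import numpy as np

# Load the arrays written by the sampling step.
X_num = np.load("exp/churn2/ddpm_mlp_best/X_num_train.npy", allow_pickle=True)
X_cat = np.load("exp/churn2/ddpm_mlp_best/X_cat_train.npy", allow_pickle=True)

print(X_num.shape)  # numerical samples are present
print(X_cat)        # categorical samples come back empty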

python scripts/pipeline.py --config exp/churn2/ddpm_mlp_best/config.toml --train --sample 

[3 2 2 2]
16
{'num_classes': 2, 'is_y_cond': True, 'rtdl_params': {'d_layers': [512, 1024, 1024, 1024, 1024, 256], 'dropout': 0.0}, 'd_in': 16}
mlp
Step 500/30000 MLoss: 0.713 GLoss: 0.9151 Sum: 1.6280999999999999
Step 1000/30000 MLoss: 0.7205 GLoss: 0.6157 Sum: 1.3362
Step 1500/30000 MLoss: 0.7073 GLoss: 0.5363 Sum: 1.2436
Step 2000/30000 MLoss: 0.718 GLoss: 0.506 Sum: 1.224
Step 2500/30000 MLoss: 0.7069 GLoss: 0.484 Sum: 1.1909
Step 3000/30000 MLoss: 0.7195 GLoss: 0.4682 Sum: 1.1877
Step 3500/30000 MLoss: 0.721 GLoss: 0.4381 Sum: 1.1591
Step 4000/30000 MLoss: 0.7193 GLoss: 0.4203 Sum: 1.1396000000000002
Step 4500/30000 MLoss: 0.7173 GLoss: 0.412 Sum: 1.1293
Step 5000/30000 MLoss: 0.7159 GLoss: 0.4084 Sum: 1.1242999999999999
Step 5500/30000 MLoss: 0.7206 GLoss: 0.4068 Sum: 1.1274
Step 6000/30000 MLoss: 0.7188 GLoss: 0.4046 Sum: 1.1234
Step 6500/30000 MLoss: 0.7161 GLoss: 0.3968 Sum: 1.1129
Step 7000/30000 MLoss: 0.7148 GLoss: 0.3964 Sum: 1.1112
Step 7500/30000 MLoss: 0.7152 GLoss: 0.3985 Sum: 1.1137
Step 8000/30000 MLoss: 0.7154 GLoss: 0.3928 Sum: 1.1082
Step 8500/30000 MLoss: 0.6898 GLoss: 0.3933 Sum: 1.0831
Step 9000/30000 MLoss: 0.7257 GLoss: 0.3968 Sum: 1.1225
Step 9500/30000 MLoss: 0.7103 GLoss: 0.3969 Sum: 1.1072
Step 10000/30000 MLoss: 0.7095 GLoss: 0.3844 Sum: 1.0939
Step 10500/30000 MLoss: 0.7171 GLoss: 0.3905 Sum: 1.1076
Step 11000/30000 MLoss: 0.7225 GLoss: 0.3903 Sum: 1.1128
Step 11500/30000 MLoss: 0.6993 GLoss: 0.3919 Sum: 1.0912000000000002
Step 12000/30000 MLoss: 0.7088 GLoss: 0.3844 Sum: 1.0932
Step 12500/30000 MLoss: 0.726 GLoss: 0.385 Sum: 1.111
Step 13000/30000 MLoss: 0.7013 GLoss: 0.3832 Sum: 1.0845
Step 13500/30000 MLoss: 0.7045 GLoss: 0.382 Sum: 1.0865
Step 14000/30000 MLoss: 0.6991 GLoss: 0.3807 Sum: 1.0798
Step 14500/30000 MLoss: 0.7229 GLoss: 0.3812 Sum: 1.1040999999999999
Step 15000/30000 MLoss: 0.7096 GLoss: 0.3825 Sum: 1.0921
Step 15500/30000 MLoss: 0.7081 GLoss: 0.3788 Sum: 1.0869
Step 16000/30000 MLoss: 0.7158 GLoss: 0.3816 Sum: 1.0974
Step 16500/30000 MLoss: 0.7067 GLoss: 0.377 Sum: 1.0836999999999999
Step 17000/30000 MLoss: 0.6844 GLoss: 0.3759 Sum: 1.0603
Step 17500/30000 MLoss: 0.699 GLoss: 0.373 Sum: 1.072
Step 18000/30000 MLoss: 0.692 GLoss: 0.3802 Sum: 1.0722
Step 18500/30000 MLoss: 0.6822 GLoss: 0.3725 Sum: 1.0547
Step 19000/30000 MLoss: 0.7103 GLoss: 0.3775 Sum: 1.0878
Step 19500/30000 MLoss: 0.7161 GLoss: 0.373 Sum: 1.0891
Step 20000/30000 MLoss: 0.6964 GLoss: 0.3737 Sum: 1.0701
Step 20500/30000 MLoss: 0.6908 GLoss: 0.3772 Sum: 1.068
Step 21000/30000 MLoss: 0.687 GLoss: 0.3743 Sum: 1.0613000000000001
Step 21500/30000 MLoss: 0.6917 GLoss: 0.3731 Sum: 1.0648
Step 22000/30000 MLoss: 0.6878 GLoss: 0.3727 Sum: 1.0605
Step 22500/30000 MLoss: 0.7067 GLoss: 0.3743 Sum: 1.081
Step 23000/30000 MLoss: 0.6958 GLoss: 0.3685 Sum: 1.0643
Step 23500/30000 MLoss: 0.7281 GLoss: 0.3687 Sum: 1.0968
Step 24000/30000 MLoss: 0.7148 GLoss: 0.3705 Sum: 1.0853
Step 24500/30000 MLoss: 0.7021 GLoss: 0.3696 Sum: 1.0716999999999999
Step 25000/30000 MLoss: 0.6956 GLoss: 0.3667 Sum: 1.0623
Step 25500/30000 MLoss: 0.6925 GLoss: 0.3654 Sum: 1.0579
Step 26000/30000 MLoss: 0.6943 GLoss: 0.3682 Sum: 1.0625
Step 26500/30000 MLoss: 0.6899 GLoss: 0.3646 Sum: 1.0545
Step 27000/30000 MLoss: 0.6971 GLoss: 0.36 Sum: 1.0571000000000002
Step 27500/30000 MLoss: 0.7294 GLoss: 0.3662 Sum: 1.0956000000000001
Step 28000/30000 MLoss: 0.7321 GLoss: 0.3631 Sum: 1.0952
Step 28500/30000 MLoss: 0.6992 GLoss: 0.3626 Sum: 1.0618
Step 29000/30000 MLoss: 0.7026 GLoss: 0.3621 Sum: 1.0647
Step 29500/30000 MLoss: 0.711 GLoss: 0.3618 Sum: 1.0728
Step 30000/30000 MLoss: 0.6822 GLoss: 0.3598 Sum: 1.042
mlp
Sample timestep    0
Sample timestep    0
Sample timestep    0
Sample timestep    0
Sample timestep    0
Sample timestep    0
Discrete cols: [2, 4]
Num shape:  (52000, 7)
Elapsed time: 0:13:52
koseoyoung commented 1 year ago

Previously, loading the saved categorical file produced output like this:

 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [Unable to show a serialized python object.]]

The root cause is that np.save needs a NumPy data type, but the generated categorical values are plain Python str when the cat file is written to disk. I fixed this in PR #19.
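I don't have the repository's exact lines reproduced here, but conceptually the fix amounts to casting the decoded categorical samples to a NumPy string array before calling np.save (variable names below are illustrative, not the project's actual code; see PR #19 for the real change):

import numpy as np

# Decoded categorical samples come back as Python strings.
X_cat = [["Spain", "1", "1", "0"],
         ["Germany", "0", "1", "0"]]

# Cast to a NumPy array with an explicit string dtype before saving,
# so np.save writes a regular (non-object) array to disk.
X_cat = np.array(X_cat, dtype=str)
np.save("exp/churn2/ddpm_mlp_best/X_cat_train.npy", X_cat)

# Reloading now yields the expected 2-D array of strings.
print(np.load("exp/churn2/ddpm_mlp_best/X_cat_train.npy"))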

After this fix:

[['Spain', '1', '1', '0'],
 ['Germany', '0', '1', '0'],
 ['Germany', '0', '1', '0'],
 ['Germany', '1', '1', '0'],
 ['Germany', '1', '0', '0'],
 ['Spain', '0', '1', '1'],
 ['Spain', '1', '0', '0'],
 ['Germany', '0', '1', '1'],
 ['Spain', '0', '1', '0'],
 ['Spain', '1', '1', '0'],
 ['Spain', '1', '1', '0'],
 ['France', '1', '0', '0'],
 ['Germany', '1', '1', '0'],
 ['Germany', '0', '1', '0'],
 ['Spain', '0', '0', '1'],
 ['France', '1', '0', '0'],
 ['France', '0', '1', '1'],
 ['Spain', '1', '1', '0'],
 ['Spain', '1', '1', '0'],
 ['France', '0', '0', '1'],
 ['France', '0', '1', '0'],
 ['Spain', '0', '1', '0'],
 ['Germany', '0', '1', '0'],
JiangLei1012 commented 4 months ago

Hello, could you tell me the details of your solution? Thank you so much!