XzwHan / CARD

Official PyTorch implementation for the paper "CARD: Classification and Regression Diffusion Models"
https://arxiv.org/abs/2206.07275

Question about categorical data #9

Closed: fountaindive closed this issue 1 year ago

fountaindive commented 1 year ago

Your paper is very interesting; it's really nice work!

I was wondering how you deal with categorical features, particularly for classification problems? How do you handle them during data generation from the model?

Thanks!

XzwHan commented 1 year ago

Hi @fountaindive, thank you very much for checking out our work and for your nice words.

By "categorical features", did you mean a categorical response variable? For classification problems, the label (the response variable) is one-hot encoded and treated as a point in continuous real space. Therefore, during data generation we apply the same procedure to classification problems as to regression problems (Algorithm 2 in our paper). See the function p_sample_loop_with_eval inside test_image_task, the evaluation function for the image classification task, in the classification/card_classification.py file; the y_0 variable corresponds to the reconstructed one-hot class label at timestep 0 of the reverse diffusion process.
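To illustrate what "a one-hot label treated as a point in continuous space" means in practice, here is a minimal NumPy sketch of the forward (noising) step applied to one-hot labels. It assumes the CARD forward marginal $q(\boldsymbol{y}_t \mid \boldsymbol{y}_0, f_{\phi}(\boldsymbol{x})) = \mathcal{N}\big(\sqrt{\bar\alpha_t}\,\boldsymbol{y}_0 + (1-\sqrt{\bar\alpha_t})\,f_{\phi}(\boldsymbol{x}),\ (1-\bar\alpha_t)\boldsymbol{I}\big)$, a linear $\beta$ schedule, and a made-up prior prediction `f_phi_x`. The helpers `one_hot`, `linear_alpha_bars`, and `q_sample` are hypothetical and may differ from the repository's code.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Encode integer class labels as one-hot float vectors (rows of an identity matrix)."""
    return np.eye(num_classes, dtype=np.float64)[labels]

def linear_alpha_bars(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t for a linear beta schedule (an assumed schedule)."""
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    return np.cumprod(1.0 - betas)

def q_sample(y0, f_phi_x, alpha_bar_t, noise):
    """Draw y_t given y_0; the mean interpolates between y_0 and the prior prediction f_phi(x)."""
    mean = np.sqrt(alpha_bar_t) * y0 + (1.0 - np.sqrt(alpha_bar_t)) * f_phi_x
    return mean + np.sqrt(1.0 - alpha_bar_t) * noise

rng = np.random.default_rng(0)
y0 = one_hot(np.array([2, 0]), num_classes=3)   # two samples, three classes
f_phi_x = np.full_like(y0, 1.0 / 3.0)           # hypothetical prior prediction
alpha_bars = linear_alpha_bars(1000)
y_t = q_sample(y0, f_phi_x, alpha_bars[499], rng.standard_normal(y0.shape))
print(y_t.shape)  # (2, 3)
```

Because the label is just a real-valued vector here, the same noising and denoising machinery used for regression targets applies unchanged.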

After obtaining the generated labels, you can convert them to probability vectors (i.e., vectors whose dimensions sum to $1$) using Eq. (10) in our paper, which is the softmax of a temperature-weighted Brier score. This step can be viewed as an alternative to applying a softmax layer to the logits output by a deep neural network classifier to obtain probabilities.
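A minimal sketch of that conversion, assuming Eq. (10) takes the form $\Pr(y=k) = \frac{\exp\left(-(y_0^{k}-1)^2/\tau\right)}{\sum_{i=1}^{C}\exp\left(-(y_0^{i}-1)^2/\tau\right)}$, i.e., a softmax over the negative squared distance of each generated dimension from $1$, scaled by a temperature $\tau$. The function name `labels_to_probs` and the example `tau` are illustrative and not taken from the repository.

```python
import numpy as np

def labels_to_probs(y0, tau=1.0):
    """Map generated labels y0 (shape [..., C]) to probability vectors.

    Each dimension is scored by the negative squared distance to 1 (a per-class
    Brier-style term) divided by the temperature tau, then softmax-normalized.
    """
    logits = -((y0 - 1.0) ** 2) / tau
    logits = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)

y0 = np.array([[0.1, 0.9, -0.05],
               [0.4, 0.5, 0.3]])       # two generated labels, C = 3
probs = labels_to_probs(y0, tau=0.1)
print(probs.sum(axis=-1))              # each row sums to 1
print(probs.argmax(axis=-1))           # -> [1 1], the dimension closest to 1
```

A smaller `tau` sharpens the resulting distribution, in the same way as the temperature in an ordinary softmax.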

One extra note: for the metrics reported in our paper, the $t$-test results are computed with the probability vectors (see how the gen_y_all_class_probs variable is handled in the function compute_and_store_cls_metrics), while the PIW results are computed in the space before applying the temperature scaling or the softmax function (the raw_prob_val variable in compute_and_store_cls_metrics), because we found that the contrast in interval width is much more pronounced in the raw space.
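To make the PIW metric more concrete, here is a hedged sketch that computes a prediction interval width from repeated samples, assuming PIW is the width of the central 95% interval (2.5th to 97.5th percentile) of the sampled values for each test point and class dimension. It works on whatever array is passed in, so it can be applied to raw values or to probability vectors; the function name `piw` and the `[n_samples, n_points, n_classes]` layout are assumptions, not the repository's.

```python
import numpy as np

def piw(samples, lower=2.5, upper=97.5):
    """Prediction interval width per test point and class.

    samples: array of shape [n_samples, n_points, n_classes] holding repeated
    draws (e.g. of y_0 in raw space) for each test point.
    Returns an array of shape [n_points, n_classes].
    """
    lo = np.percentile(samples, lower, axis=0)
    hi = np.percentile(samples, upper, axis=0)
    return hi - lo

rng = np.random.default_rng(0)
# Hypothetical draws: class 0 is sampled with a much wider spread than class 1.
samples = np.stack([rng.normal(0.0, 1.0, size=(100, 4)),
                    rng.normal(1.0, 0.1, size=(100, 4))], axis=-1)
widths = piw(samples)
print(widths.shape)  # (4, 2)
```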

XzwHan commented 1 year ago

Meanwhile, in our paper we didn't apply CARD to classification tasks other than image classification, and we plan to address those tasks in future work. For the UCI regression tasks, some datasets include categorical features; we one-hot encode them before concatenating them with the numerical features (see the function onehot_encode_cat_feature in the regression/data_loader.py file).
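A minimal sketch of that preprocessing step: one-hot encode each categorical column and concatenate the result with the numerical columns. This is an illustrative stand-in, not the repository's onehot_encode_cat_feature; the helper name `encode_mixed_features` and the assumption that categories are already integer-coded are hypothetical.

```python
import numpy as np

def encode_mixed_features(X, cat_cols):
    """One-hot encode integer-coded categorical columns and append them to the numeric ones.

    X: 2-D array of shape [n_rows, n_cols]; columns listed in cat_cols hold integer codes.
    Returns the numeric columns first, then one block of one-hot columns per categorical feature.
    """
    num_cols = [j for j in range(X.shape[1]) if j not in cat_cols]
    blocks = [X[:, num_cols].astype(np.float64)]
    for j in cat_cols:
        codes = X[:, j].astype(np.int64)
        categories = np.unique(codes)  # sorted distinct codes seen in this column
        blocks.append((codes[:, None] == categories[None, :]).astype(np.float64))
    return np.concatenate(blocks, axis=1)

X = np.array([[0.5, 2, 1.3],
              [1.5, 0, 0.7],
              [2.5, 2, 0.1]])
encoded = encode_mixed_features(X, cat_cols=[1])
print(encoded.shape)  # (3, 4): two numeric columns + one-hot columns for codes {0, 2}
```

In practice the category set would typically be determined on the training split and reused for the test split, so that both produce the same columns.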

fountaindive commented 1 year ago

Hi @XzwHan, sorry for not being more specific but I was wondering about categorical input features.

I'm coming from the angle of trying to build a generative model for tabular datasets with mixed feature types, i.e., both continuous and categorical features. A couple of papers I've found on dealing with categorical features are: 1) https://arxiv.org/abs/2006.09790 and 2) https://arxiv.org/abs/2102.05379. I think that's actually how I found your paper, since you cite 2).

I really appreciate your description of your work; thank you very much!

XzwHan commented 1 year ago

Thank you for pointing out the two papers! Both GraphCNF and Argmax Flows appear to aim at generating or reconstructing the categorical features themselves, which differs from our work: we intend to incorporate categorical features into the diffusion model (the noise network $\boldsymbol{\epsilon}_{\theta}$) through both $\boldsymbol{x}$ and $f_{\phi}(\boldsymbol{x})$ to better predict the response variable $\boldsymbol{y}$. We've come across another line of work that applies diffusion models to generate tabular data, such as https://arxiv.org/pdf/2209.15421.pdf, which might be related to the research problem you mentioned.
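To illustrate how input features can reach the noise network in this setup, here is a hedged sketch that assembles the input to $\boldsymbol{\epsilon}_{\theta}$ by concatenating the (one-hot encoded) features $\boldsymbol{x}$, the noisy response $\boldsymbol{y}_t$, and the prior prediction $f_{\phi}(\boldsymbol{x})$. Plain concatenation and the name `build_eps_input` are assumptions for illustration; the actual CARD networks may combine these inputs differently (for example, with timestep-conditioned layers).

```python
import numpy as np

def build_eps_input(x, y_t, f_phi_x):
    """Concatenate features, noisy response, and prior prediction along the feature axis.

    x:       [batch, d_x] features (categorical columns already one-hot encoded)
    y_t:     [batch, d_y] noisy response at timestep t
    f_phi_x: [batch, d_y] prior mean prediction from the pre-trained model f_phi
    The timestep t would be supplied to the network separately (not shown).
    """
    if y_t.shape != f_phi_x.shape:
        raise ValueError("y_t and f_phi(x) must have the same shape")
    return np.concatenate([x, y_t, f_phi_x], axis=1)

x = np.array([[0.5, 1.3, 0.0, 1.0]])  # two numeric + two one-hot columns
y_t = np.array([[0.2]])
f_phi_x = np.array([[0.1]])
print(build_eps_input(x, y_t, f_phi_x).shape)  # (1, 6)
```

Under this sketch, categorical information influences the prediction of $\boldsymbol{y}$ twice: directly through $\boldsymbol{x}$, and indirectly through the prior prediction $f_{\phi}(\boldsymbol{x})$.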

fountaindive commented 1 year ago

Ah, thanks for the link; I found it yesterday too! I think it actually implements the diffusion-based approach from the Argmax Flows paper. I tried to use their code (TabDDPM), but it was very much research code and not very functional. Thanks for the great dialogue about this!

All the best

XzwHan commented 1 year ago

Happy to engage in the discussion!