dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

The mechanism of ClassificationSMOTE() #491

Closed · Yuntian9708 closed this 1 year ago

Yuntian9708 commented 1 year ago

Hi, can somebody help explain the mechanism of ClassificationSMOTE()? I saw some answers in previously closed questions explaining that SMOTE can synthesize minority-class samples from real data. But after applying ClassificationSMOTE(), I didn't see any synthetic data, and the length of the data didn't change. I checked the code of ClassificationSMOTE(), and a comment from the author says “This will average a percentage p of the elements in the batch with other elements. The target will stay unchanged and keep the value of the most important row in the mix.” I think it's more like a type of data normalization for skewed data. @Optimox

Optimox commented 1 year ago

@Yuntian9708 SMOTE is a type of data augmentation: it runs on the fly and modifies the inputs the model sees during training. For classification, it consists of averaging two random rows to create a new one: a * row_1 + (1 - a) * row_2 = new_row. Note that the final class here is the class of row_1 if a > 0.5 and the class of row_2 if a < 0.5 (it would also be possible to average the classes).

In that way, new_row is a new random data point that lies on the segment between two of your existing data points (a convex combination of real samples).

When using ClassificationSMOTE you should simply monitor whether or not it improves your validation score; nothing else will be visible to you, since the augmentation is applied batch by batch, stochastically, in the background.
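[Editor's note: to make the mixing above concrete, here is a minimal sketch written against a plain PyTorch tensor. The function name mix_batch, the parameter p, and the uniform draw for a are illustrative assumptions, not the library's exact implementation.]

```python
import torch

def mix_batch(X, p=0.2):
    """Illustrative sketch of the mixing described above: average a
    fraction p of the rows in a batch with other randomly chosen rows,
    keeping the original labels. Not the library's exact code."""
    batch_size = X.shape[0]
    # Decide which rows get mixed (each row with probability p).
    to_mix = torch.rand(batch_size) < p
    # Draw mixing coefficients a in [0.5, 1) so that row_1 dominates
    # and the mixed row keeps row_1's label.
    a = torch.rand(batch_size) / 2 + 0.5
    # Each mixed row is paired with a random partner from the batch.
    partner = torch.randperm(batch_size)
    X = X.clone()
    coef = a[to_mix].unsqueeze(1)  # shape (n_mixed, 1) for broadcasting
    X[to_mix] = coef * X[to_mix] + (1 - coef) * X[partner][to_mix]
    return X  # the labels y are returned unchanged by the augmentation
```

Note that the batch size, and hence the length of the data, never changes: some rows are replaced in place, which is why no extra synthetic rows are visible.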

Yuntian9708 commented 1 year ago

OK, thank you for replying. I think I understand it now. It does improve the performance of my model; I just want to understand it from a mathematical point of view. Say we have a random sample A that needs to be augmented. I average A with sample B to get sample C, and then replace A with C. That's the whole process. A comment in the source code (# Ensure that the first element to switch has a probability > 0.5.) means that C and A have the same class (the augmentation doesn't change the label in the code implementation). The parameter "a" you mentioned above is a probability that decides the new sample's class. Did I understand it correctly? By the way, what's the purpose of ClassificationSMOTE()? It sounds like it reduces data skew. @Optimox

Optimox commented 1 year ago

When training a neural network you train by epochs, and one epoch usually corresponds to one pass over your training data.

If you do 20 epochs, this means you showed every data point to your model exactly 20 times in order to train it. When using ClassificationSMOTE, you randomly generate "new" points which are essentially unique and seen only once by your model. That's a way of augmenting your training data to avoid overfitting and, hopefully, increase your model's generalization capacity.

The parameter "a" is a random probability that will decide how close you are from row_1, the label will be the label of the closest point.

Yuntian9708 commented 1 year ago

So should I apply SMOTE to each training batch rather than to all the training data before training? Or are both ways correct?

Optimox commented 1 year ago

You do not need to do anything yourself; simply pass it to the fit parameter named augmentations. There is an example of how to use it in this notebook: https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb
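[Editor's note: for readers who don't open the notebook, the usage it demonstrates is roughly the following. X_train, y_train, X_valid, and y_valid are assumed to be numpy arrays you already have, and the parameter values are placeholders rather than recommendations.]

```python
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.augmentations import ClassificationSMOTE

# Sketch based on the linked census_example notebook; p=0.2 and
# max_epochs=50 are placeholder values, not recommendations.
aug = ClassificationSMOTE(p=0.2)

clf = TabNetClassifier()
clf.fit(
    X_train, y_train,                # assumed pre-existing numpy arrays
    eval_set=[(X_valid, y_valid)],
    max_epochs=50,
    augmentations=aug,               # applied per batch, on the fly
)
```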

Yuntian9708 commented 1 year ago

Yes, I understand, but I also want to apply it to other baseline models. I just found that if I apply it to all the training data before training starts, I get better results than when applying it to each batch.

Yuntian9708 commented 1 year ago

It seems that applying this augmentation to the whole training set is a more reasonable approach for my task, since I get a better test result. In contrast, if I apply it per batch, I can't reproduce the result when I re-run the training program multiple times; I guess it randomly augments different samples every time the program runs. Anyway, thank you for your explanation and patience. I'll keep my current implementation.
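[Editor's note: on the reproducibility point specifically, if the augmentation is done offline on the whole training set, fixing the random seed makes the augmented dataset identical across runs. Below is a hypothetical numpy sketch of that offline variant; augment_offline and all of its parameters are made up for illustration and are not part of the library. Seeding torch's and numpy's global RNGs before training would likewise make a per-batch variant repeatable.]

```python
import numpy as np

def augment_offline(X, y, n_new=500, seed=42):
    """Hypothetical offline variant of the same mixing idea: create
    n_new mixed rows from the whole training set once, before training.
    Seeding the generator makes the augmented dataset identical across
    runs, which addresses the reproducibility issue mentioned above."""
    rng = np.random.default_rng(seed)
    i = rng.integers(0, len(X), size=n_new)         # dominant rows
    j = rng.integers(0, len(X), size=n_new)         # partner rows
    a = rng.uniform(0.5, 1.0, size=n_new)[:, None]  # a > 0.5 keeps row i's label
    X_new = a * X[i] + (1 - a) * X[j]
    return np.vstack([X, X_new]), np.concatenate([y, y[i]])
```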