dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Research : Binary Mask vs Sparse Mask? #125

Open Optimox opened 4 years ago

Optimox commented 4 years ago

Main Remark

The TabNet architecture uses the sparsemax function to perform instance-wise feature selection, and this is one of the important features of TabNet.

One of the interesting properties of sparsemax is that its outputs sum to 1, but do we really want this? Is it the role of the mask to perform both selection (0s for unused features) and importance (a value between 0 and 1)? I would say that the feature transformer should be used to create importance (by summing the values of the ReLU outputs, as is done in the paper), and the masks should output binary masks that do not sum to 1.
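To make the contrast concrete, here is a minimal sketch in plain PyTorch (not TabNet's actual code, and with softmax standing in for sparsemax) of how a mask whose entries sum to 1 rescales the selected features, whereas a binary mask only keeps or drops them:

```python
# Minimal sketch: a sum-to-1 mask (softmax as a stand-in for sparsemax)
# versus a binary mask that only selects/deselects features.
import torch

features = torch.tensor([[50.0, 3.0, 120.0]])       # e.g. age, children, weight
attention_logits = torch.tensor([[2.0, 0.1, 2.0]])

# Sparsemax-style behaviour: entries lie in [0, 1] and sum to 1,
# so selected features are also rescaled.
soft_mask = torch.softmax(attention_logits, dim=-1)
print(soft_mask, features * soft_mask)    # age is no longer 50 after masking

# Binary behaviour: a feature is either kept as-is or dropped.
binary_mask = (soft_mask > 0.2).float()
print(binary_mask, features * binary_mask)  # kept features keep their original value
```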

One problem I see with non-binary masks is that they change the values seen by the next layers: if someone is 50 years old and the attention layer thinks that age is half of the solution, then the attention for age would be 0.5 and the next layer would see age = 25. But how can the next layers differentiate between 75 / 3, 50 / 2 and plain 25? They can't really, so it seems that some information is lost along the way because of the masks. That's why I would be interested to see how binary masks perform!
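A toy illustration of this argument, using the same made-up numbers: three different (age, mask) pairs collapse to the same masked value, so a downstream layer that only sees the product cannot tell them apart.

```python
# Information-loss illustration: different (age, mask) combinations all
# produce the same masked value, so the next layer cannot distinguish them.
import torch

ages  = torch.tensor([75.0, 50.0, 25.0])
masks = torch.tensor([1 / 3, 0.5, 1.0])   # non-binary attention values

print(ages * masks)   # tensor([25., 25., 25.]) -- indistinguishable downstream

# With binary masks the value is either passed unchanged or zeroed out,
# so 75, 50 and 25 stay distinguishable when selected.
binary = torch.tensor([1.0, 1.0, 1.0])
print(ages * binary)  # tensor([75., 50., 25.])
```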

Proposed Solutions

I'm not quite sure if there are known solutions for this. Would thresholding a softmax work? Would you add this threshold as a parameter, or would it be learnt by the model itself? I'm not even sure that it would be.
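As a purely hypothetical sketch of the thresholded-softmax idea (none of this is in the repo), one way to keep it trainable would be a straight-through estimator: binarise the mask in the forward pass, but let gradients flow through the soft mask. Here the threshold is a fixed hyperparameter rather than something the model learns.

```python
# Hypothetical sketch, not part of pytorch_tabnet: hard-binarise the mask
# in the forward pass, use the soft mask's gradient in the backward pass.
import torch

def binary_mask_from_softmax(logits: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    soft = torch.softmax(logits, dim=-1)
    hard = (soft > threshold).float()
    # Straight-through trick: forward value equals `hard`,
    # gradients flow through `soft`.
    return hard + soft - soft.detach()

logits = torch.randn(4, 10, requires_grad=True)
mask = binary_mask_from_softmax(logits)
mask.sum().backward()              # gradients reach `logits` despite the hard threshold
print(mask[0], logits.grad.shape)
```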

If you feel like this is interesting and would like to contribute, please share your ideas in comments or open a PR!

pangjac commented 2 years ago

Hi @Optimox, could you please elaborate on the example below?

if someone is 50 years old and the attention layer thinks that age is half of the solution, then the attention for age would be 0.5 and the next layer would see age = 25. But how can the next layers differentiate between 75 / 3, 50 / 2 and 25?

I am a bit confused about how, once the attention layer thinks age is half of the solution, "the next layer" would see age = 25 (how does the 25 come about?). Thank you!

Optimox commented 2 years ago

Well, once the attention mask is applied, multiplying age by 0.5 gives you a totally different age. In practice it still works, but I wonder whether it would work better with completely binary masks. That's the point.

I once tried to add a very sharp layer that is 0 at 0 but goes up to 1 very quickly, but I remember that it did not change much (and the gradients exploded). It would be nice to perform an exhaustive comparison on multiple benchmark datasets.
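For reference, here is a rough sketch of what such a sharp gate could look like (my guess at a reconstruction, not the code that was actually tried): a scaled tanh that is exactly 0 at 0, saturates near 1 almost immediately, and whose slope near 0 grows with the sharpness, which would explain exploding gradients.

```python
# Rough reconstruction of a "very sharp" gate: 0 at 0, close to 1 for
# any small positive input. The `sharpness` factor is a made-up knob.
import torch

def sharp_gate(x: torch.Tensor, sharpness: float = 200.0) -> torch.Tensor:
    # Assumes x >= 0 (e.g. post-ReLU attention scores): exactly 0 at x = 0,
    # saturating near 1 for small positive inputs.
    return torch.tanh(sharpness * x)

x = torch.linspace(0.0, 0.02, 5, requires_grad=True)
y = sharp_gate(x)
y.sum().backward()
print(y)       # climbs towards 1 within a tiny input range
print(x.grad)  # slope near 0 is ~sharpness, which is how gradients can blow up
```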

haoliangjiang commented 2 years ago

A binary mask would help in the sense that we can be confident the attention mask lets the entire feature value join the computation of the next block. However, I believe one of the implicit goals is that the network can learn how to pass information so that it can classify or regress correctly. So I do not think binary masks will significantly boost performance, since both kinds of mask are mostly about making the network more interpretable. But I am not sure about the optimization side, where binary masks might do something different to the gradients.

Given a well-trained model, I can think of two situations in which the network outputs 25 instead of 50 after the attention. One is that the input data is noisy: the network corrects the data by moving the age into a reasonable range. The other is that the model somehow knows that 25 stands for 50, since the attention mask is mainly based on the input. Both of these situations help the prediction.

Not a pro. Just sharing my thoughts.