somepago / saint

The official PyTorch implementation of the paper "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training"
Apache License 2.0

Fixed [CLS] token during inference? #5

Closed dwang-sflscientific closed 3 years ago

dwang-sflscientific commented 3 years ago

Hi,

For inference, the [CLS] token (L157 and L160 in train.py) is still based on the ground-truth label; should it be a static [CLS] token instead?

dwang-sflscientific commented 3 years ago

For pretraining and fine-tuning, I don't understand why ground-truth labels are used as the [CLS] token either.

somepago commented 3 years ago

Hi, you don't need to base the [CLS] token on the actual label. It's an artifact from another project; the static token is generated in the embed_data_mask function (L321).
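To illustrate the idea of a static [CLS] token, here is a minimal hedged sketch (not SAINT's actual `embed_data_mask` code): instead of embedding the ground-truth label, a dedicated index in the category-embedding table is reserved for [CLS] and prepended to every row, so the token is identical for all samples and never depends on labels. The class name `StaticCLSEmbedder` and the reserved index 0 are assumptions for this example.

```python
import torch
import torch.nn as nn

class StaticCLSEmbedder(nn.Module):
    """Hypothetical sketch: prepend a label-independent [CLS] token.

    Index 0 of the embedding table is reserved for [CLS]; real
    categorical values are shifted up by one to make room.
    """

    def __init__(self, num_categories: int, dim: int):
        super().__init__()
        # +1 slot for the static [CLS] token at index 0
        self.embed = nn.Embedding(num_categories + 1, dim)

    def forward(self, x_categ: torch.Tensor) -> torch.Tensor:
        # x_categ: (batch, n_features) integer category codes
        batch = x_categ.shape[0]
        # static [CLS] index, the same for every row, no label involved
        cls = torch.zeros(batch, 1, dtype=torch.long, device=x_categ.device)
        x = torch.cat([cls, x_categ + 1], dim=1)  # shift real categories by 1
        return self.embed(x)  # (batch, n_features + 1, dim)
```

Because the [CLS] slot always maps to the same embedding row, the token is identical across samples at train and inference time, which is the behavior the question is asking about.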

dwang-sflscientific commented 3 years ago

Ah, I see. Thanks for the explanation.