dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Scikit-Learn Compatibility #509

Open enesgencer18 opened 11 months ago

enesgencer18 commented 11 months ago

Feature request

I would like to add self._estimator_type = 'classifier' to TabNetClassifier.__post_init__, as shown below:

    def __post_init__(self):
        super(TabNetClassifier, self).__post_init__()
        self._task = 'classification'
        self._default_loss = torch.nn.functional.cross_entropy
        self._default_metric = 'accuracy'
        self._estimator_type = 'classifier'  # proposed: lets scikit-learn recognize this as a classifier
....
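A minimal, stdlib-only sketch of the proposed pattern follows. It assumes, as the __post_init__ hook in the snippet suggests, that the model classes are dataclasses whose subclasses call the parent's __post_init__. TabModelStub and TabNetClassifierStub are hypothetical stand-ins, not the real pytorch_tabnet classes, and n_d is just one illustrative field.

```python
from dataclasses import dataclass


@dataclass
class TabModelStub:
    """Hypothetical stand-in for the pytorch_tabnet base model (not the real class)."""
    n_d: int = 8  # one illustrative hyperparameter field

    def __post_init__(self):
        # The real base class does its own setup here; the stub does nothing.
        pass


@dataclass
class TabNetClassifierStub(TabModelStub):
    """Hypothetical stand-in mirroring the __post_init__ proposed in this issue."""

    def __post_init__(self):
        super().__post_init__()
        self._task = "classification"
        self._default_metric = "accuracy"
        # Proposed addition: advertise the estimator type to scikit-learn.
        self._estimator_type = "classifier"


clf = TabNetClassifierStub()
print(getattr(clf, "_estimator_type", None))  # prints: classifier
# Set in __post_init__, the attribute exists on instances only, not on the class.
print(getattr(TabNetClassifierStub, "_estimator_type", None))  # prints: None
```

One design note: because the attribute is assigned in __post_init__, checks performed on the class itself will not see it. Declaring _estimator_type = "classifier" as a class attribute (which is how scikit-learn's own ClassifierMixin historically provided it) avoids that. Newer scikit-learn releases (1.6 and later) move this information into estimator tags (__sklearn_tags__), so the attribute alone may not be sufficient there.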

What is the expected behavior? Better scikit-learn compatibility for TabNetClassifier(), so that it can be used as a base estimator in StackingClassifier() or VotingClassifier().

What is the motivation or use case for adding/changing the behavior? Without this attribute, we receive the following error:

ValueError: The estimator TabNetClassifier should be a classifier.
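To show where this error comes from, here is a simplified, stdlib-only re-implementation of the kind of check that StackingClassifier and VotingClassifier run on their base estimators (in scikit-learn releases before 1.6, sklearn.base.is_classifier decides this by looking up _estimator_type). validate_base_estimators and both stand-in classes are hypothetical; this is a sketch of the mechanism, not scikit-learn's actual code.

```python
def is_classifier(estimator):
    """Simplified check: an estimator counts as a classifier if it says so."""
    return getattr(estimator, "_estimator_type", None) == "classifier"


def validate_base_estimators(estimators):
    """Mimic an ensemble validating its (name, estimator) pairs."""
    for _name, est in estimators:
        if est != "drop" and not is_classifier(est):
            raise ValueError(
                f"The estimator {est.__class__.__name__} should be a classifier."
            )


class TabNetClassifier:
    """Hypothetical stand-in WITHOUT the attribute (current behavior)."""


class PatchedTabNetClassifier:
    """Hypothetical stand-in WITH the attribute (proposed behavior)."""

    def __init__(self):
        self._estimator_type = "classifier"


try:
    validate_base_estimators([("tabnet", TabNetClassifier())])
except ValueError as exc:
    print(exc)  # The estimator TabNetClassifier should be a classifier.

validate_base_estimators([("tabnet", PatchedTabNetClassifier())])  # no error
```

Because the check is pure duck typing on one attribute, setting _estimator_type is enough for the ensemble to accept the model in this simplified version.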

Are you willing to work on this yourself? Yes