dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.61k stars 485 forks

Transfer learning through TabNet #426

Closed yashGuleria closed 2 years ago

yashGuleria commented 2 years ago

Feature request: For sequential prediction tasks, XGBoost lets you reuse a previously trained model when making the second prediction (through the XGB.fit parameters). Is it possible to do something similar with TabNet?

What is the expected behavior? This would enable sequential prediction (like a continual learning approach).

What is motivation or use case for adding/changing the behavior?

How should this be implemented in your opinion?

Are you willing to work on this yourself? yes

Optimox commented 2 years ago

Currently you can simply call fit twice to retrain your model on new data. Is that what you want ?

yashGuleria commented 2 years ago

> Currently you can simply call fit twice to retrain your model on new data. Is that what you want ?

Let me explain with an example: if I have data from 2 test subjects who were each asked to perform the same task separately, I could train an XGBoost model on the data of the first subject (call this model M1) and then use this trained model as the base model when training on the data of the second subject (model M2), to check whether the knowledge from the first subject improves the performance of M2. In XGBoost I can set this with "M2.fit(X_train, Y_train, xgb_model=M1)". Would this be feasible with TabNet?

Optimox commented 2 years ago

warm_start only exists on the development branch; it has not been released yet.

With the current release, you always have warm_start=True.

So what you can do is the following (pseudo-code):

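from pytorch_tabnet.tab_model import TabNetClassifier

# my_params is a dict of constructor kwargs; data_1 = (X_1, y_1), data_2 = (X_2, y_2)
tabnet_model_1 = TabNetClassifier(**my_params)
tabnet_model_1.fit(X_1, y_1)  # model 1: first trained on data_1

# Baseline: same parameters, trained only on data_2
tabnet_model_2 = TabNetClassifier(**my_params)
tabnet_model_2.fit(X_2, y_2)

# Second fit call on model 1: continues training on data_2
tabnet_model_1.fit(X_2, y_2)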

Then you'll be able to see whether training tabnet_model_1 on data_1 and then on data_2 brings an uplift compared to tabnet_model_2, which is trained only on data_2.
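
To make that comparison concrete, here is a minimal sketch of how the uplift could be measured on a held-out set. The names warm_start_uplift, accuracy and make_model are hypothetical helpers, not part of pytorch-tabnet; the sketch only assumes a model object with fit(X, y) and predict(X), and that a second fit call continues from the existing weights as described above.

```python
from typing import Callable, Sequence, Tuple

Dataset = Tuple[Sequence, Sequence]  # (features, labels)


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of predictions that match the true labels."""
    pairs = list(zip(y_true, y_pred))
    return sum(1 for t, p in pairs if t == p) / len(pairs)


def warm_start_uplift(
    make_model: Callable[[], object],
    data_1: Dataset,
    data_2: Dataset,
    test: Dataset,
) -> Tuple[float, float, float]:
    """Return (sequential accuracy, baseline accuracy, uplift).

    make_model is a hypothetical factory returning a fresh, untrained
    model that exposes fit(X, y) and predict(X).
    """
    X_1, y_1 = data_1
    X_2, y_2 = data_2
    X_test, y_test = test

    # Sequential model: trained on data_1, then trained further on data_2.
    # Assumption: the second fit call keeps the weights from the first one.
    sequential = make_model()
    sequential.fit(X_1, y_1)
    sequential.fit(X_2, y_2)

    # Baseline model: trained on data_2 only.
    baseline = make_model()
    baseline.fit(X_2, y_2)

    acc_seq = accuracy(y_test, sequential.predict(X_test))
    acc_base = accuracy(y_test, baseline.predict(X_test))
    return acc_seq, acc_base, acc_seq - acc_base
```

With pytorch-tabnet, make_model could be something like lambda: TabNetClassifier(**my_params), with the datasets passed as NumPy arrays. A positive third value suggests that the knowledge from data_1 helped; a single run is noisy, so repeating it over several seeds would give a more reliable picture.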

yashGuleria commented 2 years ago

Thank you for the suggestion. I will close this issue now.