dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.59k stars, 483 forks

Can I use a PyTorch-like training process instead of the scikit-compatible way of using this lib? #551

Open kaiwang0112006 opened 3 weeks ago

kaiwang0112006 commented 3 weeks ago

I want to do research on TabNet with federated learning, which means I need to get the model weights out and set them back during each epoch of training. That would be easier with a PyTorch-like, epoch-by-epoch training loop than with the scikit-compatible way of training this lib provides.
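In plain PyTorch, "getting the weights out and setting them back" usually goes through `state_dict()` and `load_state_dict()`. The sketch below assumes plain federated averaging (FedAvg, weighted by each client's sample count). `fedavg`, `local_epoch`, `make_model` and `run_federated` are hypothetical helpers, not part of pytorch-tabnet, and the small `nn.Sequential` is only a placeholder for the TabNet module. Because batch-norm layers also store buffers, integer buffers such as `num_batches_tracked` are copied rather than averaged.

```python
import torch
from torch import nn


def fedavg(state_dicts, weights):
    """FedAvg: weighted average of client state_dicts.

    Floating-point tensors are averaged; integer buffers (e.g. BatchNorm's
    num_batches_tracked counter) are copied from the first client.
    """
    total = float(sum(weights))
    averaged = {}
    for key, first in state_dicts[0].items():
        if torch.is_floating_point(first):
            averaged[key] = sum(sd[key] * (w / total) for sd, w in zip(state_dicts, weights))
        else:
            averaged[key] = first.clone()
    return averaged


def local_epoch(model, X, y, lr=0.05, batch_size=16):
    """One epoch of plain mini-batch SGD on a single client's data."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for start in range(0, len(X), batch_size):
        xb, yb = X[start:start + batch_size], y[start:start + batch_size]
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()


def make_model():
    # Placeholder for the TabNet module (assumption). BatchNorm is included so
    # the state_dict carries buffers as well as parameters.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


def run_federated(n_rounds=3, seed=0):
    torch.manual_seed(seed)
    global_model = make_model()
    clients = []
    for n in (40, 24):  # sizes chosen so no mini-batch has a single row
        X = torch.randn(n, 4)
        clients.append((X, (X[:, 0] > 0).long()))
    for _ in range(n_rounds):
        states, sizes = [], []
        for X, y in clients:
            local = make_model()
            local.load_state_dict(global_model.state_dict())  # set the global weights in
            local_epoch(local, X, y)
            states.append(local.state_dict())  # get the trained weights out
            sizes.append(len(X))
        global_model.load_state_dict(fedavg(states, sizes))
    return global_model


global_model = run_federated()
```

The same pattern should apply to any `nn.Module`, since it only relies on the standard `state_dict` API; more elaborate aggregation schemes would replace `fedavg`.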

Optimox commented 3 weeks ago

Hello, I am not sure I understand your request. But if you want to use the TabNet network simply as a PyTorch module and insert it into your own pipeline, you can use the modules from here: https://github.com/dreamquark-ai/tabnet/blob/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/pytorch_tabnet/tab_network.py#L508
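Used as a plain `torch.nn.Module`, the network needs a hand-written training loop. The sketch below assumes that, as in the linked `tab_network.py`, the module's `forward` returns an `(output, M_loss)` pair, and it subtracts `lambda_sparse * M_loss` from the task loss, which appears to be what the library's own training step does; `lambda_sparse=1e-3` is assumed to match the library default. `StandInTabNet` and `train_epoch` are hypothetical names: the stand-in only mimics the return signature so the sketch runs without pytorch-tabnet installed. With the package available you would construct the linked `TabNet` class instead, after checking its constructor arguments for your installed version.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class StandInTabNet(nn.Module):
    """Hypothetical placeholder: like TabNet's forward, returns (output, M_loss)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.linear(x)
        m_loss = torch.zeros((), device=x.device)  # a real TabNet returns its mask sparsity term here
        return out, m_loss


def train_epoch(model, loader, optimizer, loss_fn, lambda_sparse=1e-3):
    """One epoch of a hand-written loop; returns the mean training loss."""
    model.train()
    total, n = 0.0, 0
    for xb, yb in loader:
        optimizer.zero_grad()
        output, m_loss = model(xb)
        # Task loss minus the sparsity regularisation term, weighted by lambda_sparse.
        loss = loss_fn(output, yb) - lambda_sparse * m_loss
        loss.backward()
        optimizer.step()
        total += loss.item() * len(xb)
        n += len(xb)
    return total / n


torch.manual_seed(0)
X = torch.randn(256, 10)
y = (X[:, 0] + X[:, 1] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)
model = StandInTabNet(input_dim=10, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = [train_epoch(model, loader, optimizer, nn.CrossEntropyLoss()) for _ in range(20)]
```

Since every epoch returns control to your own code, reading or replacing `model.state_dict()` between calls to `train_epoch` is straightforward.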

kaiwang0112006 commented 3 weeks ago

That's great! What should I pass to the parameter "group_attention_matrix"?

Optimox commented 3 weeks ago

This is an advanced feature; you can leave it as None. Otherwise you'll need to dig a bit into the code to use it. It's just a matrix of weights describing how the attention can work across different features.
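To illustrate what such a matrix of weights can look like, the sketch below builds one from a list of feature groups: one row per group (features not listed each get their own row), one column per input feature, and each group's weight spread evenly over its members so every row sums to 1. With no groups it returns the identity matrix, i.e. every feature is its own group. `build_group_matrix` is a hypothetical helper written for this example. Recent pytorch-tabnet releases appear to ship a comparable utility (`create_group_matrix` in `pytorch_tabnet.utils`), but treat that name and the expected shape as assumptions to verify against your installed version before passing anything to `group_attention_matrix`.

```python
import torch


def build_group_matrix(groups, input_dim):
    """Hypothetical helper: (n_groups, input_dim) matrix whose rows each sum to 1.

    groups: list of lists of feature indices; unlisted features form singleton groups.
    """
    seen = [i for group in groups for i in group]
    if len(seen) != len(set(seen)):
        raise ValueError("a feature may belong to at most one group")
    if any(i < 0 or i >= input_dim for i in seen):
        raise ValueError("feature index out of range")
    seen_set = set(seen)
    rows = [list(group) for group in groups if group]
    rows += [[i] for i in range(input_dim) if i not in seen_set]
    matrix = torch.zeros(len(rows), input_dim)
    for r, members in enumerate(rows):
        matrix[r, members] = 1.0 / len(members)  # spread the group's weight evenly
    return matrix


# Features 0 and 2 share one group; features 1 and 3 stay on their own.
example = build_group_matrix([[0, 2]], input_dim=4)
```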