kaiwang0112006 opened this issue 3 weeks ago
Hello, I am not sure I understand your request. But if you want to use the TabNet network simply as a PyTorch module and insert it into your own pipeline, you can use the modules defined here: https://github.com/dreamquark-ai/tabnet/blob/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/pytorch_tabnet/tab_network.py#L508
That's great! What should I pass for the parameter `group_attention_matrix`?
This is an advanced feature: you can leave it as None; otherwise you'll need to dig a bit into the code to use it. It's simply a matrix of weights that controls how attention is shared across different features.
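For anyone who does want feature grouping, here is a sketch of what such a matrix can look like. The layout is an assumption based on my reading of the 4.x source (`pytorch_tabnet.utils.create_group_matrix`), not documented behavior: rows are attention groups, columns are input features, the features of a user-defined group share the weight `1 / len(group)`, and every ungrouped feature gets its own one-hot row, so "no grouping" gives the identity matrix. `build_group_matrix` is a hypothetical, stdlib-only helper.

```python
def build_group_matrix(groups, input_dim):
    """Hypothetical helper: rows = attention groups, columns = input features.

    Assumed layout (mirrors my reading of pytorch-tabnet 4.x, verify locally):
    features in a group share weight 1/len(group); ungrouped features each
    form their own single-feature group.
    """
    seen = set()
    for group in groups:
        if not group:
            raise ValueError("groups must not be empty")
        for idx in group:
            if not 0 <= idx < input_dim:
                raise ValueError(f"feature index {idx} out of range")
            if idx in seen:
                raise ValueError(f"feature {idx} appears in more than one group")
            seen.add(idx)

    rows = []
    # One row per user-defined group; its features share the weight equally.
    for group in groups:
        row = [0.0] * input_dim
        for idx in group:
            row[idx] = 1.0 / len(group)
        rows.append(row)
    # Every feature not covered by a group becomes its own group (one-hot row).
    for idx in range(input_dim):
        if idx not in seen:
            row = [0.0] * input_dim
            row[idx] = 1.0
            rows.append(row)
    return rows


# Features 0 and 1 grouped together, features 2 and 3 left on their own.
print(build_group_matrix([[0, 1]], 4))
```

You would convert the result with `torch.tensor(...)` before passing it as `group_attention_matrix`; check the shape and row convention against the source of the version you have installed.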
I want to do research on TabNet with federated learning, which means I need to get the model weights out and set them back during each epoch of training. This would be easier with a PyTorch-style, epoch-by-epoch training loop than with the library's scikit-learn-compatible way of training.
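One way to structure the federated part, sketched under assumptions: the aggregation below is FedAvg (a weighted average of client weights by local sample count), which is my choice for illustration rather than anything provided by pytorch-tabnet, and `fedavg` is a hypothetical helper. In PyTorch the weights come out with `model.state_dict()` and go back in with `model.load_state_dict(...)`; the helper only needs values that support `*` with a float and `+`, so it works on state dicts of tensors and, for illustration, on plain floats.

```python
def fedavg(client_states, client_sizes):
    """FedAvg-style aggregation (example choice, not part of pytorch-tabnet).

    client_states: list of dicts mapping parameter name -> value (e.g. state_dicts).
    client_sizes: number of local training samples per client, same order.
    Returns a dict where each entry is the sample-weighted average across clients.
    """
    if not client_states or len(client_states) != len(client_sizes):
        raise ValueError("need exactly one size per client state")
    total = float(sum(client_sizes))
    averaged = {}
    for key in client_states[0].keys():
        acc = None
        for state, size in zip(client_states, client_sizes):
            # Weight each client's value by its share of the total samples.
            term = state[key] * (size / total)
            acc = term if acc is None else acc + term
        averaged[key] = acc
    return averaged


# Two "clients" with scalar weights; the second has 3x more data.
print(fedavg([{"w": 1.0}, {"w": 3.0}], [1, 3]))
```

A round would then be: each client calls `model.load_state_dict(global_state)`, runs one or more local epochs, and returns its `state_dict()`. If one model object is reused for several clients, copy the tensors first (`{k: v.detach().clone() for k, v in model.state_dict().items()}`), since `state_dict()` returns references to the live parameters. Note that this also averages the BatchNorm running statistics; keeping them local per client is one alternative worth comparing.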