dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.6k stars 482 forks

customize training loop and dataloaders #364

Closed bstockton closed 2 years ago

bstockton commented 2 years ago

Feature request

Create an example showing how custom training loops and dataloaders should be written with this library. Alternatively, or additionally, add helper methods designed to make custom training loops easier.

What is the expected behavior? Something similar to the PyTorch nn.Transformer base model and tutorial. https://pytorch.org/tutorials/beginner/translation_transformer.html

What is the motivation or use case for adding/changing the behavior? While the ability to just call fit on a model is nice, one of the biggest draws of PyTorch is how easy and elegant it is to write your own training loop. Custom training loops make small tweaks a lot easier, for example writing custom dataloaders.

How should this be implemented in your opinion? Similar to this example https://pytorch.org/tutorials/beginner/translation_transformer.html

Are you willing to work on this yourself? Yes, but it seems there should be a "preferred" way given how the library is written, and defining that seems more appropriate for the authors.

Optimox commented 2 years ago

Well, pytorch-tabnet basically consists of two things: the networks written in PyTorch, and the encapsulated training loop, predictions, etc., that make it user friendly. You can already change quite a lot of the training process (class weights, optimizer, early stopping, scheduler, learning rate, and so on), so I think it's reasonable to say that you can already customize your training loop.

If you want a fully custom training loop, you can always write your own code and use only the network architecture via from pytorch_tabnet.tab_network import TabNet.

I'm happy to discuss more specific needs to improve the customization experience, but your request seems too general at the moment.