dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Row-wise sample weighting #563

Open alex-smith-bind opened 3 weeks ago

alex-smith-bind commented 3 weeks ago

Feature request

Sample (row-by-row) weighting.

What is the expected behavior?

Row-wise sample weights can be specified and applied during training.

What is the motivation or use case for adding/changing the behavior?

This will allow one to provide weights based on business knowledge of confidence in the labels.

How should this be implemented in your opinion?

The weights can be applied in both the loss function and the metric function. If the loss is computed with reduction='none', it returns one value per sample; each value can then be multiplied by its sample weight, and the batch loss is taken as the weighted average.
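To make the idea concrete, here is a minimal sketch of a per-sample weighted loss in plain PyTorch. The helper name `weighted_cross_entropy` and the example tensors are hypothetical and are not part of pytorch-tabnet; the choice of cross-entropy and the normalization by the sum of weights are assumptions about how the averaging could be done.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits, targets, sample_weights):
    """Hypothetical helper: cross-entropy weighted row by row."""
    # reduction="none" keeps one loss value per row instead of averaging.
    per_sample = F.cross_entropy(logits, targets, reduction="none")
    # Weighted average over the batch; the clamp avoids dividing by zero
    # if every weight in the batch happens to be zero.
    total_weight = sample_weights.sum().clamp_min(1e-12)
    return (per_sample * sample_weights).sum() / total_weight


# Toy batch: 3 rows, 2 classes (illustrative values only).
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5], [1.0, 1.0]])
targets = torch.tensor([0, 1, 0])
weights = torch.tensor([1.0, 0.2, 0.0])  # e.g. confidence in each label

loss = weighted_cross_entropy(logits, targets, weights)
uniform = weighted_cross_entropy(logits, targets, torch.ones(3))
```

With all weights equal to one, this reduces to the usual mean cross-entropy, and a row with weight zero contributes nothing to the loss.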

Modifications are also needed so that the weight data can be passed alongside the X and y data.
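As a rough illustration of carrying the weights next to X and y, the sketch below uses the generic PyTorch `TensorDataset` and `DataLoader`, so that each batch yields a third tensor of weights. This only shows the general pattern; it does not reflect how pytorch-tabnet builds its datasets internally, and the shapes and values are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative data: 8 rows, 4 features, binary labels, one weight per row.
X = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
w = torch.linspace(0.1, 1.0, 8)

# TensorDataset accepts any number of tensors with the same first dimension,
# so the weights stay aligned with their rows through batching.
dataset = TensorDataset(X, y, w)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Each batch now unpacks into features, labels, and weights.
batch_shapes = [
    (tuple(xb.shape), tuple(yb.shape), tuple(wb.shape)) for xb, yb, wb in loader
]
```

Because the weights are indexed together with the rows, they remain matched to the right samples even when shuffling is enabled.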

Are you willing to work on this yourself?

Yes. I have created a branch in a private copy of the repo and have a version of the logic nearly working. I would like to contribute this to the project because I believe it may be useful to others. I need access in order to push my branch so I can create a PR when it is ready. Is there anyone I could communicate with to discuss my approach? It looks like the Slack channel is defunct.

Optimox commented 3 weeks ago

You do not need any specific access to push a PR; I would be happy to review it.