manujosephv / pytorch_tabular

A standard framework for modelling Deep Learning Models for tabular data
https://pytorch-tabular.readthedocs.io/
MIT License
1.26k stars, 131 forks

Multi-Target Classification #430

Open YonyBresler opened 3 months ago

YonyBresler commented 3 months ago

Is your feature request related to a problem? Please describe. Currently, regression allows for multi-target training, while classification only supports a single target.

Describe the solution you'd like Add support for multi-target classification, including data handling, sizing the output layer for each target, and an appropriate loss function.
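To make the request concrete, here is a minimal, dependency-free sketch of one way the output shaping and loss could work. It is an illustration under stated assumptions, not pytorch_tabular's API or the approach taken in any PR: it assumes a single flat output vector whose width is the sum of the per-target class counts, and a loss that averages the per-target softmax cross-entropies. The names `split_logits`, `cross_entropy`, `multi_target_loss`, and `cardinalities` are hypothetical.

```python
import math


def split_logits(logits, cardinalities):
    """Split one flat logit vector into one slice per target.

    `cardinalities` (hypothetical name) holds the number of classes of
    each target, so the output layer width is sum(cardinalities).
    """
    if len(logits) != sum(cardinalities):
        raise ValueError("logit width must equal the sum of class counts")
    slices, start = [], 0
    for n_classes in cardinalities:
        slices.append(logits[start:start + n_classes])
        start += n_classes
    return slices


def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for a single target."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[label]


def multi_target_loss(logits, labels, cardinalities):
    """Average the per-target cross-entropies (an assumed reduction)."""
    per_target = split_logits(logits, cardinalities)
    losses = [cross_entropy(l, y) for l, y in zip(per_target, labels)]
    return sum(losses) / len(losses)


# Two targets for one sample: the first has 3 classes, the second has 2.
cardinalities = [3, 2]
logits = [0.0, 0.0, 0.0, 0.0, 0.0]  # stand-in for a model's output row
labels = [1, 0]                     # one class index per target
loss = multi_target_loss(logits, labels, cardinalities)
```

Averaging rather than summing keeps the loss scale independent of the number of targets; per-target weights would be a natural extension if some targets matter more than others.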

Describe alternatives you've considered The current workaround is to train a separate model per target, which is slower and may not match the performance of a single model that learns a shared embedding.

Additional context I'm prepared to contribute this feature myself, but wanted to check:

  1. That no unworkable or undesirable roadblock currently prevents this from being possible
  2. That, if I complete it, the community is interested in integrating this feature

manujosephv commented 2 months ago

@YonyBresler This is something the community has requested in the past as well, and it would be an awesome addition. I can assure you that your contribution would be most welcome. I can work with you on the PR.

YonyBresler commented 2 months ago

Thanks for your feedback, @manujosephv. I finished a first pass at this and have a PR in #441. I'm happy to hear feedback from you or any other community member.