jrzaurin / pytorch-widedeep

A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
Apache License 2.0

SAINT Model #38

Closed: cmcmaster1 closed this issue 3 years ago

cmcmaster1 commented 3 years ago

This new model would be great to implement: https://arxiv.org/abs/2106.01342

They mention that they even found performance improvements in the TabTransformer architecture by embedding continuous variables (a variable-specific linear layer to go from 1 -> n_embeddings). I have made a fork and implemented this as an option in _tabtransformer. Okay to make a PR?
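A minimal PyTorch sketch of the idea described above, assuming one weight vector and one bias vector per continuous column (equivalent to a separate `nn.Linear(1, n_embeddings)` for each column, vectorised). `ContinuousEmbeddings` is a hypothetical name; this is not the code from the fork or from pytorch-widedeep's `_tabtransformer` module. Once embedded, the continuous columns become tokens with the same shape as the categorical embeddings and can go through the transformer blocks. In the original TabTransformer, by contrast, continuous features bypass the transformer and are concatenated to its output.

```python
import torch
import torch.nn as nn


class ContinuousEmbeddings(nn.Module):
    """Hypothetical per-column embedding of continuous features.

    Each continuous column j has its own weight w_j and bias b_j of size
    embed_dim, so the scalar x_j becomes the token x_j * w_j + b_j. This is
    equivalent to one nn.Linear(1, embed_dim) per column, computed in one go.
    """

    def __init__(self, n_cont_cols: int, embed_dim: int):
        super().__init__()
        # one (embed_dim,) weight row per column; small random init is an arbitrary choice
        self.weight = nn.Parameter(torch.randn(n_cont_cols, embed_dim) * 0.02)
        # one (embed_dim,) bias row per column
        self.bias = nn.Parameter(torch.zeros(n_cont_cols, embed_dim))

    def forward(self, x_cont: torch.Tensor) -> torch.Tensor:
        # (batch, n_cont_cols) -> (batch, n_cont_cols, 1), broadcast against
        # (n_cont_cols, embed_dim) -> (batch, n_cont_cols, embed_dim)
        return x_cont.unsqueeze(-1) * self.weight + self.bias


# Usage: embed 3 continuous columns and stack them with the embeddings of
# 2 categorical columns, so all 5 columns can be attended over together.
embed_dim = 8
# a single shared embedding table for the categorical columns, for simplicity
cat_embed = nn.Embedding(num_embeddings=10, embedding_dim=embed_dim)
cont_embed = ContinuousEmbeddings(n_cont_cols=3, embed_dim=embed_dim)

x_cat = torch.randint(0, 10, (4, 2))  # batch of 4 rows, 2 categorical columns
x_cont = torch.randn(4, 3)            # batch of 4 rows, 3 continuous columns

tokens = torch.cat([cat_embed(x_cat), cont_embed(x_cont)], dim=1)
print(tokens.shape)  # torch.Size([4, 5, 8])
```

Storing the per-column parameters as two `(n_cont_cols, embed_dim)` tensors avoids a Python loop over separate `nn.Linear` layers while keeping the weights variable-specific.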

jrzaurin commented 3 years ago

Hey @cmcmaster1, sorry it took me some time to reply!

We are indeed familiar with SAINT; the main author said on Twitter they were going to release the code soon, so we were waiting. I have a few concerns, but we will try to bring it to the package, as well as the model that won the Tabular Playground competition.

I am really busy at work these days, so development of the package is going a bit slowly, but I am going to get some help, as there are a number of upcoming improvements and a post benchmarking algorithms (see https://github.com/jrzaurin/tabulardl-benchmark). Be careful, the code is ugly at times :)

And regarding the PR, BY ALL MEANS! Please do, thank you! Branch off from the tabnet branch, as I am about to release v1 and it comes with TabNet adapted to Wide and Deep.

Maybe we can chat, and you can also update the docs?

Let me know, but yes, PR!

jrzaurin commented 3 years ago

@cmcmaster1 the SAINT code has been released: https://github.com/somepago/saint

Now it is just a matter of time. If you want to help, you are more than welcome 😉

jrzaurin commented 3 years ago

SAINT is now added