OHDSI / DeepPatientLevelPrediction

An R package for performing patient level prediction using deep learning in an observational database in the OMOP Common Data Model.
https://ohdsi.github.io/DeepPatientLevelPrediction

Implement TabNet #21

Open ChungsooKim opened 2 years ago

ChungsooKim commented 2 years ago

TabNet (paper) has recently been implemented in R (mlverse/tabnet), and it is even based on the torch package! It is known for its good performance on tabular data, so it would be interesting to implement TabNet and compare its performance with the traditional algorithms (lasso logistic regression, xgboost).


egillax commented 2 years ago

Good idea! Do you want to take a stab at it? They have a pretty high-level API. I can also help.

ChungsooKim commented 2 years ago

Great, I think we can implement this in our package. I'll try to write generic code for model fitting with the tabnet functions. Here is a video tutorial.
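For reference, a minimal sketch of what fitting with the mlverse/tabnet high-level API might look like (the data frame names `train`, `test`, and outcome `y` are hypothetical placeholders, not from our package):

```r
library(tabnet)

# Fit a TabNet classifier on a data.frame `train` with binary outcome `y`
fit <- tabnet_fit(y ~ ., data = train, epochs = 10, batch_size = 1024)

# Predicted class probabilities on held-out data
preds <- predict(fit, new_data = test, type = "prob")
```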

ChungsooKim commented 2 years ago

Hi @jreps, @egillax, I wrote generic code for TabNet in the develop branch. Could you review it and leave any comments? After that, I think we can compare its performance to the traditional algorithms or to other deep learning algorithms like the ResNet.

egillax commented 2 years ago

Hi @ted9219 ,

Great job! I looked at the code and made small changes to get it working on my end. I also added model saving and feature importances, and verified that everything works and displays correctly in the shiny results viewer. I pushed it to develop.

There are a few things that worry me, though.

ChungsooKim commented 2 years ago

Thanks, @egillax . I totally agree with your comments.

Using this repo as a reference, could we implement the TabNet module ourselves?

I also found that the performance is lower than expected, and lower than other modules like the ResNet. We need to figure out why the performance is consistently low. I'll try.

egillax commented 2 years ago

Just noting down a few observations I've made.

The output of the model was weird for me with the default loss function: it was logits (it had some negative values), but the mean of the output equalled the outcome rate, suggesting the default loss function treats the output as probabilities. So I switched to loss <- torch::nn_bce_with_logits_loss() in my config (the same loss function used by default in my estimator class). After converting the logits to probabilities with torch::nnf_sigmoid(), the output makes sense: the mean matches the outcome rate and the distribution is what I would expect.
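The logits-vs-probabilities distinction above can be sketched with the torch R package (toy values, not from the actual model):

```r
library(torch)

logits  <- torch_tensor(c(-2.1, 0.3, -0.7))  # raw model output, can be negative
targets <- torch_tensor(c(0, 1, 0))

# This loss expects logits and applies the sigmoid internally,
# so the model head should NOT apply a sigmoid itself
loss_fn <- nn_bce_with_logits_loss()
loss <- loss_fn(logits, targets)

# For reporting predictions, convert logits to probabilities explicitly
probs <- nnf_sigmoid(logits)  # values in (0, 1)
```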

Another thing: it seems the input columns need to be factors, otherwise tabnet treats them as numerical and skips the embedding layer. I think this could be the reason for the poor performance. I tried converting the columns to factors with trainDataset[] <- lapply(trainDataset, factor), but that takes a lot of time and then training is extremely slow. After running it overnight I'm still not getting better AUCs.
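One possible way to cheapen the conversion (a sketch only; the line above converts every column, which is slow on wide data) is to restrict it to the binary columns, which are the vast majority in our case:

```r
# Assumes trainDataset is a data.frame of mostly 0/1 columns
isBinary <- vapply(trainDataset,
                   function(x) all(x %in% c(0, 1)),
                   logical(1))

# Convert only the binary columns so tabnet routes them
# through its embedding layer instead of the numerical path
trainDataset[isBinary] <- lapply(trainDataset[isBinary], factor)
```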

I suspect the best way forward is to use the module you linked from the tabnet repo, possibly rewriting the embedding layer to handle our case where almost all variables are binary. I have already started implementing another transformer model from scratch, and in the process I made the estimator and the fit function more general, so in the future adding a model should only require writing a setModel function.

I will commit those changes later today, but I was stuck on the embedding layer with the same problem as for tabnet: how to efficiently create an embedding layer for a matrix of binary features. So when I solve that, I might have a solution for tabnet as well.
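One direction worth considering for the binary case (a sketch under my own assumptions, not the committed implementation): rather than embedding every column value, embed only the indices of the *active* (non-zero) features per patient and sum the resulting vectors, which nn_embedding_bag does efficiently:

```r
library(torch)

n_features <- 5L     # total number of binary features (toy size)
embedding_dim <- 3L

# mode = "sum" adds up the embeddings of the active features in each row
bag <- nn_embedding_bag(n_features, embedding_dim, mode = "sum")

# Two patients; each row lists the indices of that patient's active features
x <- torch_tensor(matrix(c(1L, 3L,
                           2L, 5L), nrow = 2, byrow = TRUE))

out <- bag(x)  # one embedding_dim-sized vector per patient
```

Since our sparse data already stores which features are present per patient, this avoids ever materializing the dense binary matrix.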