ersilia-os / ersilia

The Ersilia Model Hub, a repository of AI/ML models for infectious and neglected disease research.
https://ersilia.io
GNU General Public License v3.0

📑 Feature Request: Using Google's TabNet for Better Performance and Interpretable Results #587

Closed Femme-js closed 1 year ago

Femme-js commented 1 year ago

Is your feature request related to a problem? Please describe.

Most of the AI/ML models incorporated by Ersilia span different data domains, output formats, and prediction types, and many of them rely on classical machine learning algorithms and ensemble models.

For tabular data, ensemble models (like XGBoost) outperform many deep learning models and offer a better trade-off between explainability and architectural complexity. Explainability is necessary to understand which features matter to the model and why it behaves in a certain way. However, there is still scope to capture the performance boost offered by deep learning architectures.

Describe the solution you'd like.

Google's TabNet was proposed in 2019 with the idea of effectively using deep neural networks for tabular data.

TabNet is composed of a feature transformer, an attentive transformer, and feature masking, and it performs soft feature selection with controllable sparsity in end-to-end learning. A key reason for TabNet's high performance is that it focuses on the most important features, as chosen by the attentive transformer. At each decision step, the attentive transformer selects which features the model should reason from, while the feature transformer processes those features into more useful representations and learns complex data patterns. This improves interpretability and helps the model learn more accurately.
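For illustration, here is a minimal sketch of training TabNet on a binary classification task using the community pytorch-tabnet implementation; the data, split sizes, and hyperparameters below are placeholders rather than a concrete proposal for Ersilia:

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder tabular data: replace with a real featurized dataset.
X_train = np.random.rand(1000, 32).astype(np.float32)
y_train = np.random.randint(0, 2, size=1000)
X_valid = np.random.rand(200, 32).astype(np.float32)
y_valid = np.random.randint(0, 2, size=200)

# TabNet learns which features to attend to at each decision step.
clf = TabNetClassifier(n_d=8, n_a=8, n_steps=3, gamma=1.3, seed=42)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=["auc"],
    max_epochs=50,
    patience=10,
)

# Predicted probabilities for the positive class.
preds = clf.predict_proba(X_valid)[:, 1]
```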

The key benefit TabNet has is its explainability: it performs instance-wise feature selection using masks in its encoder, which focuses the model's learning capacity on the most important features. These masks can also be visualized, so we can experiment with them and explore which features are being used at the level of individual predictions.
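As a sketch of that interpretability side (again assuming the pytorch-tabnet API and the `clf` object from the snippet above), the aggregate feature importances and per-instance masks can be inspected like this:

```python
# Global importances: one weight per input feature, aggregated over steps.
global_importances = clf.feature_importances_

# Instance-wise explanation: which features each decision step attended to
# for every sample in X_valid.
explain_matrix, masks = clf.explain(X_valid)

# masks is a dict keyed by decision step; each value has shape
# (n_samples, n_features) and can be rendered as a heatmap.
for step, mask in masks.items():
    print(f"step {step}: mask shape {mask.shape}")
```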

Describe alternatives you've considered

No response

Additional context.

No response

GemmaTuron commented 1 year ago

This is being tackled in this repository, so I'll close this issue.