basf / mamba-tabular

Mambular is a Python package that brings the power of Mamba architectures to tabular data, offering a suite of deep learning models for regression, classification, and distributional regression tasks. This includes models like Mambular, FT-Transformer, TabTransformer and tabular ResNets.
https://mambular.readthedocs.io
MIT License

[FAQ] On-training custom metrics #161

Open LordGedelicious opened 2 days ago

LordGedelicious commented 2 days ago

I'm wondering if there's a way to customize the metrics logged per step and per epoch during model training. Right now, I'm training a binary classifier (MambularClassifier), and the only metrics shown are train_loss_step, val_loss, and train_loss_epoch. Is it possible to add other metrics such as MSE, RMSE, accuracy, or val_accuracy, other than by extending the class implementation myself?

AnFreTh commented 2 days ago

Great suggestion! We will add it in one of the next versions. I will convert this issue to an "enhancement".
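Until custom metrics are logged during training, one workaround is to compute them after the fact from the model's predictions. The sketch below is not part of Mambular: it uses plain Python and assumes you already have validation labels and positive-class probabilities (for example from a scikit-learn style predict_proba call, which is an assumption about the estimator's interface). The values in y_val and p_val are made up for illustration.

```python
import math


def accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded probability matches the label."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mse(y_true, y_prob):
    """Mean squared error between labels and predicted probabilities.

    For binary labels this is the Brier score.
    """
    return sum((t - p) ** 2 for t, p in zip(y_true, y_prob)) / len(y_true)


def rmse(y_true, y_prob):
    """Square root of the mean squared error."""
    return math.sqrt(mse(y_true, y_prob))


# Hypothetical validation labels and positive-class probabilities.
y_val = [0, 1, 1, 0]
p_val = [0.2, 0.9, 0.4, 0.1]

metrics = {
    "val_accuracy": accuracy(y_val, p_val),
    "val_mse": mse(y_val, p_val),
    "val_rmse": rmse(y_val, p_val),
}
print(metrics)
```

This only gives metrics at the points where you call it (for example after fit), not per step or per epoch, so it is a stopgap rather than a replacement for the requested feature.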

LordGedelicious commented 2 days ago

Thank you for your hard work! A quick follow-up question, if you don't mind: since the models are built on scikit-learn's base classes, can learning curves and/or validation curves be produced with sklearn.model_selection.learning_curve? Right now I'm noting the values manually, since each epoch takes 3-4 minutes, so I can plot them later with matplotlib.

AnFreTh commented 2 days ago

Since all models are ultimately trained via Lightning, this is unfortunately not possible. Once logging custom metrics is possible, everything can be logged via the trainer kwargs in the fit method.
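As an alternative to noting loss values by hand, the logged metrics can be read back from a file. Lightning's CSVLogger writes a metrics.csv with epoch and step columns plus one column per logged metric, leaving cells empty where a metric was not logged on that row. The sketch below parses such a file into per-epoch curves using only the standard library. Whether and how a logger can be passed through Mambular's fit method is an assumption here, and the sample rows are invented to mimic the three metrics mentioned in this thread.

```python
import csv
import io
from collections import defaultdict


def epoch_curves(csv_file, metric_names):
    """Return {metric: {epoch: value}}, keeping the last value logged per epoch."""
    curves = defaultdict(dict)
    for row in csv.DictReader(csv_file):
        if not row.get("epoch"):
            continue  # skip rows without an epoch index
        epoch = int(row["epoch"])
        for name in metric_names:
            value = row.get(name)
            if value:  # empty cell: metric not logged on this row
                curves[name][epoch] = float(value)
    return dict(curves)


# Invented sample in the layout CSVLogger produces; in practice, open the
# real file instead, e.g. open("path/to/metrics.csv", newline="").
sample = """epoch,step,train_loss_step,val_loss,train_loss_epoch
0,49,0.69,,
0,99,,0.61,
0,99,,,0.65
1,149,0.55,,
1,199,,0.52,
1,199,,,0.54
"""

curves = epoch_curves(io.StringIO(sample), ["train_loss_epoch", "val_loss"])
for name, points in curves.items():
    print(name, sorted(points.items()))
```

Each entry in curves maps epochs to values, so sorted(points.items()) gives the x and y pairs to hand to matplotlib for a train versus validation loss plot.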