jrzaurin / pytorch-widedeep

A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
Apache License 2.0
1.27k stars 188 forks

Problems running transformer models #199

Open rruizdeaustri opened 7 months ago

rruizdeaustri commented 7 months ago

Hello,

I'm trying to classify events for a dark matter direct detection experiment, which are tabulated in some optimal features (the data are continuous). When I run both the xgboost and lgbm algorithms I get AUCs of about 0.98. When I run an MLP model (without optimisation) I get about 0.93, which is a bit far from the decision trees, but maybe this is the best one can get with an MLP. The issue comes with the transformer models. From those I get something just like a random classifier (~0.5), so there must be something wrong in my script, but it is not obvious to me what the issue is. Could you please have a look at my script and tell me if you see something wrong? This is my script:

import numpy as np
import torch
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import (
    SAINT,
    Wide,
    WideDeep,
    TabPerceiver,
    FTTransformer,
    TabFastFormer,
    TabTransformer,
)
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.callbacks import (
    LRHistory,
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.initializers import XavierNormal, KaimingNormal

from torchmetrics import AUROC

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, roc_curve

use_cuda = torch.cuda.is_available()

if __name__ == "__main__":

csv_file_path = '/lustre/ific.uv.es/ml/ific005/projects/direct_detection/data_dd/events2024/processed/combined_data.csv'

# Load the CSV file into a DataFrame
df = pd.read_csv(csv_file_path)

print(df.head())

continuous_cols = ['pA_S1','pH_S1','pHT_S1','pL_S1','pL90_S1','pRMSW_S1','pHTL_S1','pA_S2','pH_S2','pHT_S2','pL_S2','pL90_S2','pRMSW_S2','pHTL_S2','pbot', 'ptop','pdiffT']

target = "Label"

df_train, df_valid = train_test_split(
    df, test_size=0.2, stratify=df[target], random_state=1
)

df_valid, df_test = train_test_split(
    df_valid, test_size=0.5, stratify=df_valid[target], random_state=1
)

tab_preprocessor = TabPreprocessor(
    continuous_cols=continuous_cols,
    scale=True,
    with_attention=True,
    )

X_tab_train = tab_preprocessor.fit_transform(df_train)
X_tab_valid = tab_preprocessor.transform(df_valid)
X_tab_test = tab_preprocessor.transform(df_test)

# target
y_train = df_train[target].values
y_valid = df_valid[target].values
y_test = df_test[target].values

wide = Wide(input_dim=np.unique(X_tab_train).shape[0], pred_dim=1)

tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=continuous_cols,
    embed_continuous=True,
    n_blocks=4    
)

saint = SAINT(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=continuous_cols,
    cont_norm_layer="batchnorm",
    n_blocks=4,
)

tab_perceiver = TabPerceiver(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=continuous_cols,
    n_latents=6,
    latent_dim=16,
    n_latent_blocks=4,
    n_perceiver_blocks=2,
    share_weights=False,
)

tab_fastformer = TabFastFormer(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=continuous_cols,
    n_blocks=4,
    n_heads=4,
    share_qv_weights=False,
    share_weights=False,
)

ft_transformer = FTTransformer(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=continuous_cols,
    input_dim=32,
    kv_compression_factor=0.5,
    n_blocks=3,
    n_heads=4,
)

for tab_model in [
    tab_transformer,
    saint,
    ft_transformer,
    tab_perceiver,
    tab_fastformer,
]:
    model = WideDeep(deeptabular=tab_model, pred_dim=1)

    wide_opt = torch.optim.Adam(model.parameters(), lr=0.01)
    deep_opt = torch.optim.Adam(model.parameters(), lr=0.01)
    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)

    optimizers = {"wide": wide_opt, "deeptabular": deep_opt}
    schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
    initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal}
    callbacks = [
        LRHistory(n_epochs=10),
        EarlyStopping(patience=5),
        ModelCheckpoint(filepath="model_weights/wd_out"),
    ]
    metrics = [Accuracy]

    trainer = Trainer(
        model,
        objective="binary",
        optimizers=optimizers,
        lr_schedulers=schedulers,
        initializers=initializers,
        callbacks=callbacks,
        metrics=metrics,
    )

    trainer.fit(
     X_train={"X_tab": X_tab_train, "target": y_train},
     X_val={"X_tab": X_tab_valid, "target": y_valid},
     n_epochs=10,
     batch_size=100,
    )

    df_pred = trainer.predict(X_tab=X_tab_test)

    print(classification_report(df_test[target].to_list(), df_pred))
    #print("Actual predicted values:\n{}".format(np.unique(df_pred, return_counts=True)))
    auc = roc_auc_score(df_test[target], df_pred)

    print('AUC', auc)

Thanks a lot !

Roberto
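
A side note on the last lines of the script above: the AUC is computed from the output of `trainer.predict`, which for a binary objective returns hard class labels, not probabilities. An AUC computed from thresholded labels is generally lower than one computed from scores, so it may not be directly comparable with the GBM numbers if those were computed from predicted probabilities (an assumption, since the thread does not show that code). This would not by itself explain a value near 0.5. The sketch below uses a small pure-Python AUC and made-up toy data to show the effect; the helper `roc_auc` is written for illustration and is not part of any library.

```python
def roc_auc(y_true, scores):
    """Rank-based ROC AUC: the fraction of (positive, negative) pairs
    in which the positive is scored higher, with ties counted as half."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data, made up for illustration (not from the thread).
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
proba = [0.10, 0.20, 0.30, 0.60, 0.40, 0.70, 0.80, 0.90]

# Hard labels at a 0.5 threshold, i.e. what a class-prediction call returns.
labels = [1 if p >= 0.5 else 0 for p in proba]

auc_from_proba = roc_auc(y_true, proba)    # uses the full ranking
auc_from_labels = roc_auc(y_true, labels)  # ranking collapsed to two values
print(auc_from_proba, auc_from_labels)
```

In the script this would correspond to scoring with something like `trainer.predict_proba(X_tab=X_tab_test)[:, 1]`, assuming `predict_proba` returns one column per class as in recent versions of the library.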

jrzaurin commented 7 months ago

Hey @rruizdeaustri

I will have a more detailed look, but in general here are some comments:

  1. Use a simpler model: forget about the wide component and simply use a deeptabular component with defaults. Also review the code in your example, since the optimizers and schedulers are not correctly defined. The Trainer not throwing an error is intentional (I might change it), but for now just define your Trainer as:
    trainer = Trainer(
        model,
        objective="binary",
        callbacks=[ModelCheckpoint(filepath="model_weights/wd_out")],
        metrics=[Accuracy],
    )
  2. The results with Transformer-based models depend A LOT on the parameters, far more than with GBMs, where XGBoost, LightGBM and CatBoost all perform close to their best out of the box. You could have a look at this relatively old post to see if it helps.

I hope this helps. Let me know how you get on with this and I'll see if I can help more.
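
The comment above does not spell out what is wrong with the optimizers. One plausible reading is that both Adam instances in the original script are built from `model.parameters()`, so each would update every weight, and that the dictionaries carry a `"wide"` key although the model was built with only a `deeptabular` component. The sketch below illustrates the first point with plain PyTorch. `TwoPartModel` is a hypothetical stand-in with the same attribute names, not the library's `WideDeep` class, and `param_ids` is a helper written for this example.

```python
import torch
from torch import nn


class TwoPartModel(nn.Module):
    """Hypothetical stand-in for a model with two named components."""

    def __init__(self, n_features: int):
        super().__init__()
        self.wide = nn.Linear(n_features, 1)
        self.deeptabular = nn.Sequential(
            nn.Linear(n_features, 8), nn.ReLU(), nn.Linear(8, 1)
        )

    def forward(self, x):
        return self.wide(x) + self.deeptabular(x)


model = TwoPartModel(n_features=17)  # 17 continuous columns, as in the script

# Pattern from the script: both optimizers receive every parameter.
shared_wide_opt = torch.optim.Adam(model.parameters(), lr=0.01)
shared_deep_opt = torch.optim.Adam(model.parameters(), lr=0.01)

# Per-component pattern: each optimizer only sees its own component.
wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = torch.optim.Adam(model.deeptabular.parameters(), lr=0.01)


def param_ids(optimizer):
    """Return the ids of all tensors that an optimizer will update."""
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


# Number of tensors claimed by both optimizers in each setup.
overlap_shared = param_ids(shared_wide_opt) & param_ids(shared_deep_opt)
overlap_split = param_ids(wide_opt) & param_ids(deep_opt)
print(len(overlap_shared), len(overlap_split))
```

With a single `deeptabular` component there is nothing to split, which is consistent with the suggestion above to drop the custom optimizers and schedulers and rely on the Trainer defaults.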

rruizdeaustri commented 7 months ago

Hi @jrzaurin,

I have made the modifications you suggested and the results make more sense now. I'm optimising hyper-parameters with Optuna for the resnet and transformer models, but the results are still far from those obtained with LightGBM: AUC ~0.93 versus ~0.98 for LightGBM.

Thanks !

Rbt

jrzaurin commented 7 months ago

Hey @rruizdeaustri

Thanks for sharing the results :)

A gap of 0.05 is perhaps a bit too much; maybe I can look at some examples if you are willing to share them. However, I am afraid that this is the "brutal" reality for most (truly) real-world cases when it comes to DL vs GBMs.

You could try some other libraries to see if their implementations are better or give you better results (?)

In my experience, I have used DL for tabular data on a few occasions, but never aimed to beat GBMs, since I knew it was a lost battle.

rruizdeaustri commented 7 months ago

Hi @jrzaurin,

Yes, the difference is too large!

I could share with you the files I'm using to train as well as the data if you like. Let me know !

Thanks !

jrzaurin commented 6 months ago

Hey @rruizdeaustri !

I am traveling at the moment, but if you join the Slack channel we can move the conversation there and share the files. I'll see if I have the time to give it a go myself! :)

Thanks!