jrzaurin / pytorch-widedeep

A flexible package for multimodal deep learning that combines tabular data with text and images using Wide and Deep models in PyTorch
Apache License 2.0

Multi-target regression "val_split" issue #233

Closed · altar31 closed 1 month ago

altar31 commented 1 month ago

Hi,

In the multi-target regression setting, when I try to set val_split in the trainer.fit() method, it raises the following error:

[Screenshot 2024-09-19 18:30:04: traceback of the error raised when val_split is set]

As a direct consequence, it is not possible to use the LR scheduler, as shown below:

[Screenshot 2024-09-19 18:34:09: error raised when using the LR scheduler]

Please find the requirements.txt and the code below to reproduce this issue: requirements.txt

import random
import pandas as pd
import numpy as np
import pickle
import torch
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.losses_multitarget import MultiTargetRegressionLoss
from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW

# generate data
data = {
    "X1": [random.uniform(200, 250) for _ in range(1000)],
    "X2": [random.uniform(100, 200) for _ in range(1000)],
    "X3": [random.uniform(0, 15000) for _ in range(1000)],
    "Y1": [random.uniform(100, 2000) for _ in range(1000)],
    "Y2": [random.uniform(40, 100) for _ in range(1000)],
}
df = pd.DataFrame(data)

test_data = {
    "X1": [random.uniform(220, 250) for _ in range(1000)],
    "X2": [random.uniform(156, 200) for _ in range(1000)],
    "X3": [random.uniform(0, 15000) for _ in range(1000)],
}
df_test = pd.DataFrame(test_data)

# set the target
target = df[["Y1", "Y2"]].values.astype(np.float32)
# Tabular preprocessor
tab_preprocessor = TabPreprocessor(continuous_cols=["X1", "X2", "X3"])
X_tab = tab_preprocessor.fit_transform(df).astype(np.float32)
# Model architecture
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)
model = WideDeep(deeptabular=tab_mlp, pred_dim=2)  # pred_dim=2 -> two regression targets
print("The model architecture:\n", model)

# Multi-target loss: the two targets' losses are weighted equally
loss = MultiTargetRegressionLoss(weights=[0.5, 0.5], reduction="mean")
# early stopping
early_stopping = EarlyStopping()

# checkpoints
model_checkpoint = ModelCheckpoint(
    # filepath=f"model_check",
    save_best_only=True,
    verbose=1,
)

# optimizer
deep_opt = AdamW(model.deeptabular.parameters(), lr=0.001)

# lr scheduler (cannot be used while val_split fails, see the error above)
deep_sch = ReduceLROnPlateau(
    deep_opt,
    patience=2,
    min_lr=1e-5,
)
# Set the trainer
trainer = Trainer(
    model,
    objective="multitarget",
    custom_loss_function=loss,
    callbacks=[early_stopping, model_checkpoint],
    lr_schedulers={"deeptabular": deep_sch},
    optimizers={"deeptabular": deep_opt},
)
# Fit the model
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=10,
    batch_size=16,
    val_split=0.2,  # this argument triggers the error shown above
)

# save the model
torch.save(model.state_dict(), "multi_regression_torch.pt")
# save the preprocessor
with open("multi_regression_.pkl", "wb") as dp:
    pickle.dump(tab_preprocessor, dp)
print("Training completed and model saved.")

# Load the saved preprocessor
with open(f"multi_regression_.pkl", "rb") as tp:
    tab_preprocessor_new = pickle.load(tp)

# Load the trained model for inference on test data
X_test = tab_preprocessor_new.transform(df_test).astype(np.float32)
new_model = WideDeep(deeptabular=tab_mlp, pred_dim=2)  # same architecture as the trained model
new_model.load_state_dict(torch.load("multi_regression_torch.pt"))
new_model.eval()

# Use the reloaded model for prediction (new_model, not the training-time model)
trainer_new = Trainer(new_model, objective="multitarget", custom_loss_function=loss)
preds = trainer_new.predict(X_tab=X_test, batch_size=64)
print("Predictions array:\n", preds)

Thanks in advance 😃

jrzaurin commented 1 month ago

Thanks man

As before, try installing the branch and running the code again; it should work.
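
For reference, installing straight from a git branch usually looks like the line below; the branch name is not given in the thread, so fix-branch here is only a placeholder:

pip install git+https://github.com/jrzaurin/pytorch-widedeep.git@fix-branch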

altar31 commented 1 month ago

Perfect! Now I can set val_split.

As a consequence, early stopping and the LR scheduler work perfectly too! 👍

Thanks @jrzaurin