DataCanvasIO / DeepTables

DeepTables: Deep-learning Toolkit for Tabular data
https://deeptables.readthedocs.io
Apache License 2.0

Deserialization of a DeepTables model with a custom objective fails #88

Open deburky opened 11 months ago

deburky commented 11 months ago

System information

Describe the current behavior

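When deserializing a DeepTables model trained with a custom objective (loss) function, loading fails because the custom loss function is not defined at load time. A possible workaround is to save and load the underlying Keras model directly:

import tensorflow as tf
from deeptables.models.layers import MultiColumnEmbedding, FM  # DeepTables' custom Keras layers

# model is the object returned by dt.fit(); model.model is the underlying Keras model
tf.keras.models.save_model(model.model, 'model.h5')
model2 = tf.keras.models.load_model('model.h5', custom_objects={
    'MultiColumnEmbedding': MultiColumnEmbedding,
    'FM': FM,
    'FocalLoss': FocalLoss})  # FocalLoss is the custom loss class defined in the reproduction code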

However, this workaround returns a bare Keras model rather than a DeepTable object, so it is not straightforward to use with DeepTables, and the documentation does not cover this case.

Describe the expected behavior

It should be possible to deserialize any DeepTables model that uses a custom objective without extra effort. This is especially important for deployment; for example, there is currently no clear way to use the mlem package.

Standalone code to reproduce the issue

Link to Colab notebook

First, create a virtual environment and install the required packages:

python3.11 -m venv .venv
source .venv/bin/activate
pip install pandas numpy scikit-learn tensorflow deeptables dask jupyter

Then run the following code to reproduce the issue:

"""
Original file is located at
    https://colab.research.google.com/drive/1XECkcRpqYCPlgRLCuqPn0BDFwKx8ujzc
"""

import tensorflow as tf
from tensorflow.keras.losses import Loss
from deeptables.models import deeptable, deepnets
from deeptables.datasets import dsutils
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

class FocalLoss(Loss):
    """
    We want to maximize the likelihood of correctly classifying challenging
    examples while giving less emphasis to well-classified examples.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def call(self, y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())

        # Cast y_true to float32
        y_true = tf.cast(y_true, tf.float32)

        ce_loss = - (y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        # pt is the predicted probability of the true class, since ce_loss = -log(pt)
        pt = tf.math.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return tf.reduce_mean(focal_loss, axis=-1)

# Generate a synthetic dataset
X, y = make_classification(n_samples=10000, n_features=5, n_classes=2, n_informative=3, random_state=42)
X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, stratify=y, test_size=0.33, random_state=62)

config = deeptable.ModelConfig(
    nets=deepnets.DeepFM,
    loss=FocalLoss(alpha=0.1, gamma=0.0),
    metrics=["AUC"],
    auto_discrete=True
)
dt_fl = deeptable.DeepTable(config=config)
model_fl, history_fl = dt_fl.fit(X_trn, y_trn, epochs=10)

result = dt_fl.evaluate(X_tst, y_tst, batch_size=512, verbose=0)
print(result)

preds_fl = dt_fl.predict_proba(X_tst)

import tempfile
tmpdir = tempfile.mkdtemp()
dt_fl.save(tmpdir)
model_load = deeptable.DeepTable.load(tmpdir)
model_load.evaluate(X_tst, y_tst)

The output produced is:

ValueError: Unknown loss function: 'FocalLoss'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
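The FocalLoss in the reproduction down-weights well-classified examples by the factor (1 - pt)^gamma, where pt is the probability the model assigns to the true class. Because ce_loss = -log(pt), it follows that pt = exp(-ce_loss); with exp(+ce_loss) instead, pt exceeds 1 and (1 - pt) turns negative. The plain-Python sketch below (no TensorFlow; the function name binary_focal_loss is hypothetical) checks this scalar arithmetic. It keeps the uniform alpha used in the issue, whereas the common formulation weights positives by alpha and negatives by 1 - alpha.

```python
import math


def binary_focal_loss(y_true: int, p: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Scalar focal loss for one binary example (hypothetical helper, uniform alpha)."""
    eps = 1e-7
    p = min(max(p, eps), 1 - eps)  # clip like tf.clip_by_value in the issue
    ce = -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))
    pt = math.exp(-ce)  # probability of the true class, always in (0, 1)
    return alpha * (1 - pt) ** gamma * ce


# An easy example (p=0.9 for a positive) is down-weighted far more than a hard one (p=0.1)
print(binary_focal_loss(1, 0.9), binary_focal_loss(1, 0.1))
```

With gamma=0.0, as in the issue's ModelConfig, the modulating factor is 1 and the loss reduces to alpha times the binary cross-entropy.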

Your support / feedback on this will be appreciated.

oaksharks commented 2 months ago

Sorry for the late reply. Following your example, the model can now be deserialized with the latest code on the master branch by passing custom_objects to DeepTable.load:

import tempfile
tmpdir = tempfile.mkdtemp()
dt_fl.save(tmpdir)
model_load = deeptable.DeepTable.load(tmpdir, custom_objects={'FocalLoss': FocalLoss})  # KEYPOINT
model_load.evaluate(X_tst, y_tst)

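Note that in your case, FocalLoss.__init__ must also accept **kwargs and forward them to the parent class, so that the arguments Keras saved for the loss (such as name and reduction) can be passed back in on load: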

class FocalLoss(Loss):
    """
    We want to maximize the likelihood of correctly classifying challenging
    examples while giving less emphasis to well-classified examples.
    """
    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):  # KEYPOINT
        super(FocalLoss, self).__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma
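With only **kwargs added, the loss can be reconstructed, but the base Loss.get_config() stores only name and reduction, so alpha and gamma are not saved and a reloaded FocalLoss would fall back to its defaults (0.25 and 2.0). The sketch below also overrides get_config. It assumes TensorFlow 2.x with tf.keras, and that DeepTables restores the compiled loss through Keras' standard get_config/from_config path; it is an illustration, not DeepTables' documented API.

```python
import tensorflow as tf
from tensorflow.keras.losses import Loss


class FocalLoss(Loss):
    """Binary focal loss whose alpha/gamma survive a get_config/from_config round trip."""

    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):
        # **kwargs forwards 'name', 'reduction', etc. restored from a saved config
        super().__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma

    def call(self, y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)
        y_true = tf.cast(y_true, y_pred.dtype)
        ce_loss = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        pt = tf.math.exp(-ce_loss)  # probability of the true class
        return tf.reduce_mean(self.alpha * (1 - pt) ** self.gamma * ce_loss, axis=-1)

    def get_config(self):
        # The base config holds only name/reduction; add the loss hyperparameters
        config = super().get_config()
        config.update({"alpha": self.alpha, "gamma": self.gamma})
        return config


loss = FocalLoss(alpha=0.1, gamma=0.0)
restored = FocalLoss.from_config(loss.get_config())
print(restored.alpha, restored.gamma)
```

When loading the DeepTable, the class is still passed through custom_objects={'FocalLoss': FocalLoss}, as in the snippet above.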