Open deburky opened 11 months ago
Sorry for replying so late. As you demonstrated, we can now deserialize it with the latest code on the master branch like this:
```python
import tempfile

tmpdir = tempfile.mkdtemp()
dt_fl.save(tmpdir)
model_load = deeptable.DeepTable.load(tmpdir, custom_objects={'FocalLoss': FocalLoss})  # KEYPOINT
model_load.evaluate(X_tst, y_tst)
```
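To see why the `custom_objects` argument is needed at all, here is a minimal plain-Python sketch of the lookup the deserializer performs; `resolve_class` and its registry dict are hypothetical stand-ins for what Keras does internally, not DeepTables API:

```python
# Hypothetical sketch: a saved model stores only the *name* of the loss
# class; deserialization must map that name back to a Python class.
def resolve_class(saved_name, custom_objects=None):
    registry = dict(custom_objects or {})  # user-supplied name -> class map
    if saved_name not in registry:
        raise ValueError(
            f"Unknown loss: {saved_name}. Pass it via custom_objects.")
    return registry[saved_name]

class FocalLoss:  # stand-in for the user's custom loss class
    pass

try:
    resolve_class("FocalLoss")  # fails: the name cannot be resolved
except ValueError as e:
    print(e)

# Supplying the mapping lets the saved name resolve to the real class.
cls = resolve_class("FocalLoss", custom_objects={"FocalLoss": FocalLoss})
print(cls is FocalLoss)
```

This is why loading fails with "custom loss function not defined" unless the class is passed in explicitly.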
But in your case, the `__init__` of `FocalLoss` should accept `**kwargs`:
```python
class FocalLoss(Loss):
    """
    We want to maximize the likelihood of correctly classifying challenging
    examples while giving less emphasis to well-classified examples.
    """
    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):  # KEYPOINT
        super(FocalLoss, self).__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma
```
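As a plain-Python illustration of why `**kwargs` is the key point (no TensorFlow needed; the `Loss` stub below is an assumption standing in for `keras.losses.Loss`): Keras rebuilds a loss roughly via `cls(**config)`, and the saved config carries base-class keys such as `name`, so an `__init__` that does not forward extras to `super()` raises a `TypeError` during deserialization.

```python
class Loss:
    """Stand-in for the Keras Loss base class."""
    def __init__(self, name=None):
        self.name = name

class FocalLossNoKwargs(Loss):
    """Broken variant: cannot absorb base-class config keys."""
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

class FocalLoss(Loss):
    """Fixed variant: forwards unknown kwargs to the base class."""
    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha, self.gamma = alpha, gamma

    def get_config(self):
        # Everything needed to reconstruct the object via cls(**config).
        return {"alpha": self.alpha, "gamma": self.gamma, "name": self.name}

config = {"alpha": 0.5, "gamma": 2.0, "name": "focal_loss"}

try:
    FocalLossNoKwargs(**config)  # unexpected keyword argument 'name'
except TypeError as e:
    print("without **kwargs:", e)

loss = FocalLoss(**config)
restored = FocalLoss(**loss.get_config())  # round-trips cleanly
print("with **kwargs:", restored.alpha, restored.name)
```

In real Keras code you would also merge `super().get_config()` into the returned dict; the flat dict above just keeps the sketch self-contained.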
System information
- keras==2.14.0
- ml-dtypes==0.2.0
- tensorboard==2.14.1
- tensorflow-estimator==2.14.0
- tensorflow-macos==2.14.0
Describe the current behavior
When trying to deserialize a DeepTables model with a custom objective function, deserialization fails because the custom loss function is not defined. A potential workaround could be:
However, using DeepTables with this object is not straightforward, and the documentation doesn't cover this aspect.
Describe the expected behavior
It should be possible to deserialize any DeepTables model containing custom objectives without extra effort. This is especially important for deployment (for example, there is no clear way to use the
mlem
package).
Standalone code to reproduce the issue
Link to Colab notebook
Let's create a virtual environment and install the needed packages:
After this, we can run the code below to reproduce the issue.
The output produced is:
Your support / feedback on this will be appreciated.