keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Is it possible to serialize models that use torch functions? #903

Closed: emi-dm closed this issue 1 year ago

emi-dm commented 1 year ago

Would it be possible to serialize torch loss functions and metrics used in a Keras model?

I think it would be quite a powerful feature!

Note: I passed the loss function to the compile method of a Keras Model subclass, and I swapped the order of the logits and the targets in the compute_loss call:

        y_pred = self(x, training=True)  # Forward pass; y_pred holds the logits
        probs = keras.activations.softmax(y_pred)  # Probabilities, used only for computing metrics

        loss = self.compute_loss(y=y_pred, y_pred=y)  # Loss from logits (from_logits=True); y and y_pred intentionally swapped

Thanks in advance!!
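For context on the swapped arguments: PyTorch loss modules such as torch.nn.CrossEntropyLoss are called as (input, target), i.e. logits first, whereas Keras documents a loss as any callable with the signature fn(y_true, y_pred). Instead of swapping y and y_pred at every compute_loss call, one option is to adapt the loss once with a small wrapper. This is a minimal sketch under stated assumptions, not Keras API: as_keras_loss and torch_style_cross_entropy are hypothetical names, and the latter is a pure-Python stand-in for a torch loss so the example runs without PyTorch.

```python
import math
from typing import Any, Callable, Sequence


def torch_style_cross_entropy(logits: Sequence[float], target: int) -> float:
    """Hypothetical stand-in for a PyTorch-style loss with (input, target) order.

    Returns the softmax cross-entropy of one example, computed from raw logits.
    """
    # Log-sum-exp with the max logit subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def as_keras_loss(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
    """Hypothetical adapter: expose an (input, target) loss as (y_true, y_pred)."""

    def loss(y_true: Any, y_pred: Any) -> Any:
        # Keras-style callers pass targets first; the wrapped loss wants logits first.
        return fn(y_pred, y_true)

    # Keep a readable name (functions have __name__; module instances may not).
    loss.__name__ = getattr(fn, "__name__", type(fn).__name__)
    return loss


keras_style_loss = as_keras_loss(torch_style_cross_entropy)
print(round(keras_style_loss(0, [2.0, 1.0, 0.1]), 4))  # prints 0.417
```

The same wrapper could in principle be applied to a torch.nn loss module instance; whether the result trains correctly under the torch backend is not verified here, and the wrapper by itself does nothing for serialization.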


fchollet commented 1 year ago

Not yet, but we want to add this feature. Here's a related issue: https://github.com/keras-team/keras/issues/18403

emi-dm commented 1 year ago

It would be great!!! Thanks @fchollet :)

emi-dm commented 1 year ago

Hi everyone!!! Here is a Colab notebook that reproduces the issue: https://colab.research.google.com/drive/1kFQ4_2qPlHm4P63Qb2ry-lC0OQ_zY0RJ?usp=sharing

emi-dm commented 1 year ago

@fchollet I thought maybe you could write a wrapper class around the torch.nn base class that implements the get_config and from_config methods, so that these objects could be serialized. I don't know much about the internal implementation of Keras, so I'm not sure whether I could do it myself...

Thanks in advance, guys!
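The get_config/from_config idea suggested above can be sketched in a framework-agnostic way: the wrapper stores only a registered class name plus JSON-friendly constructor kwargs, which is the information needed to rebuild the object after saving. Everything here is hypothetical: SerializableWrapper, register and ScaledAbsError (a pure-Python stand-in for a configurable torch.nn loss module) are illustrative names, and the sketch makes no claim about how the linked Keras issue is being addressed. In Keras itself, custom objects are usually made serializable by implementing get_config (plus from_config when needed) and registering the class with keras.saving.register_keras_serializable().

```python
import json
from typing import Any, Dict, Type

# Hypothetical name -> class registry (Keras keeps a similar registry for
# custom objects; this one is illustrative only).
_REGISTRY: Dict[str, Type] = {}


def register(cls: Type) -> Type:
    """Hypothetical class decorator that makes `cls` reconstructible by name."""
    _REGISTRY[cls.__name__] = cls
    return cls


class SerializableWrapper:
    """Hypothetical wrapper adding get_config/from_config to a wrapped object."""

    def __init__(self, class_name: str, **kwargs: Any) -> None:
        if class_name not in _REGISTRY:
            raise KeyError(f"unregistered class: {class_name!r}")
        self.class_name = class_name
        self.kwargs = kwargs
        self.wrapped = _REGISTRY[class_name](**kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Delegate to the wrapped object (e.g. a loss or metric callable).
        return self.wrapped(*args, **kwargs)

    def get_config(self) -> Dict[str, Any]:
        # Only JSON-friendly data: the registered name and constructor kwargs.
        return {"class_name": self.class_name, "kwargs": dict(self.kwargs)}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SerializableWrapper":
        return cls(config["class_name"], **config["kwargs"])


@register
class ScaledAbsError:
    """Pure-Python stand-in for a configurable torch.nn loss module."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, y_pred: float, y_true: float) -> float:
        return self.scale * abs(y_pred - y_true)


# Round trip: object -> JSON string -> new, equivalent object.
loss = SerializableWrapper("ScaledAbsError", scale=0.5)
saved = json.dumps(loss.get_config())
restored = SerializableWrapper.from_config(json.loads(saved))
print(saved)
print(restored(3.0, 1.0))  # prints 1.0
```

This only works when the wrapped object can be rebuilt from JSON-friendly constructor arguments; a torch loss configured with a tensor (for example a class-weight tensor) would need extra handling.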

qlzh727 commented 1 year ago

Closing this as a duplicate of https://github.com/keras-team/keras/issues/18403. We will track the fix there.