keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Enabling structured inputs to call for Keras 3 #18735

Open areiner222 opened 1 year ago

areiner222 commented 1 year ago

I've relied heavily on structured inputs for subclassed {Model, Layer}.call. Will Keras 3 support this?

I can't seem to pass a TensorFlow ExtensionType or a generic dataclass (a PyTreeNode in JAX) without hitting this value check.

I believe it should be possible to pass this kind of structured input, especially given the __tf_flatten__ / __tf_unflatten__ protocol in TensorFlow and the pytree registration functionality in JAX.
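
For the JAX side, here is a minimal sketch of what pytree registration looks like for a plain dataclass. The class name StructuredInput and its fields are hypothetical and only mirror the TF example in this post; this shows JAX's own flatten/unflatten round trip, not anything Keras does today.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical structured input: one dynamic array plus static metadata.
@jax.tree_util.register_pytree_node_class
@dataclass
class StructuredInput:
    value: jax.Array
    meta: int

    def tree_flatten(self):
        children = (self.value,)  # dynamic values (pytree leaves)
        aux_data = (self.meta,)   # static config, not traced
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild by keyword so the field order cannot be mixed up.
        return cls(value=children[0], meta=aux_data[0])


inp = StructuredInput(value=jnp.ones((10, 64)), meta=3)

# Only the array is a leaf; the metadata travels in the tree definition.
leaves = jax.tree_util.tree_leaves(inp)

# tree_map flattens, applies the function per leaf, then unflattens.
doubled = jax.tree_util.tree_map(lambda x: x * 2, inp)
```

Once registered, the object can be flattened into its array leaves and rebuilt, which is the property a backend would need in order to accept it as an input to call.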

TF extension type example:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras_core

class CompositeTensor(tf.experimental.ExtensionType):
    value: tf.Tensor
    meta: int

    def __tf_flatten__(self):
        metadata = (self.meta,)  # static config.
        components = (self.value,)  # dynamic values.
        return metadata, components

    @classmethod
    def __tf_unflatten__(cls, metadata, components):
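        # Rebuild by keyword: fields are declared as (value, meta), so
        # unpacking metadata first would swap the two arguments.
        return cls(value=components[0], meta=metadata[0])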

class ModelCheck(keras_core.Model):
    def __init__(self):
        super().__init__()
        self.layer = keras_core.layers.Dense(32)

    def call(self, inp, training=None):
        return self.layer(inp.value)

m = ModelCheck()

inp = CompositeTensor(value=tf.random.uniform((10, 64)), meta=3)
print([type(v) for v in tf.nest.flatten(inp)])
out = m(inp)

fchollet commented 12 months ago

Thanks for the suggestion. We are looking into this. The key APIs to modify would be is_tensor, convert_to_tensor, convert_to_numpy. Maybe we can just extend those on the TF and JAX side.

akensert commented 6 months ago

Any progress on this matter? It would be fantastic to have the extension types work with Keras 3 :)