areiner222 opened this issue 1 year ago
Thanks for the suggestion. We are looking into this. The key APIs to modify would be `is_tensor`, `convert_to_tensor`, and `convert_to_numpy`. Maybe we can just extend those on the TF and JAX side.
Any progress on this matter? It would be fantastic to have the extension types work with Keras 3 :)
I've relied heavily on passing structured inputs to `call` in subclassed `Model` and `Layer` classes. Will Keras 3 support this?
I can't seem to pass a TensorFlow `ExtensionType` or a generic dataclass (a `PyTreeNode` in the JAX ecosystem) without hitting this value check.
I believe it should be possible to pass this kind of structured input, especially given the `tf_flatten` / `tf_unflatten` utilities and JAX's pytree registration functionality.
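To illustrate the JAX pytree registration point, here is a minimal sketch using plain JAX only (no Keras involved). `Batch` and `call` are hypothetical names chosen for illustration. It shows that a dataclass registered with `jax.tree_util.register_pytree_node_class` can be flattened into array leaves, transformed, rebuilt into the same structure, and passed through `jax.jit` intact. It does not show how Keras 3 itself would accept such inputs.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Batch:
    """Hypothetical structured input: both fields are treated as pytree leaves."""
    features: jax.Array
    mask: jax.Array

    def tree_flatten(self):
        # Return (children, aux_data): the array leaves, plus no static metadata.
        return (self.features, self.mask), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the dataclass from its leaves in the same field order.
        return cls(*children)


def call(inputs: Batch) -> Batch:
    # Hypothetical stand-in for a Layer.call body: flatten to leaves,
    # operate on each leaf, then restore the original structure.
    leaves, treedef = jax.tree_util.tree_flatten(inputs)
    doubled = [2.0 * leaf for leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, doubled)


batch = Batch(features=jnp.ones((2, 3)), mask=jnp.array([1.0, 0.0]))
# jit traces through the registered structure and returns a Batch.
out = jax.jit(call)(batch)
print(type(out).__name__, out.features.shape, out.mask)
```

The explicit flatten/unflatten pair mirrors the mechanism a framework would need internally; in user code, `jax.tree_util.tree_map(lambda x: 2.0 * x, inputs)` achieves the same result more concisely.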
TF extension type example: