LarsKue opened this issue 8 months ago.
I should probably add: the already existing `KerasTensor` is not sufficient for type-hinting. Consider the following:
```python
# lib.py
import keras

def f(x: keras.KerasTensor):
    pass
```

```python
# user.py
import tensorflow as tf
from lib import f

y: tf.Tensor = tf.constant(1)
f(y)  # Expected type 'KerasTensor', got 'Tensor' instead
```
Hi @LarsKue,

A `KerasTensor` is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions.

I tried to reproduce the issue in Colab but couldn't. Please refer to the attached gist. Could you please check and confirm how to reproduce the issue?
I think you may have misunderstood my example. A static type checker will display the error, but running it works fine of course (the hints are just hints after all).
For reference, I am using PyCharm's built-in type checker, but I assume other type checkers will yield similar results.
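To illustrate why running the example succeeds while the IDE complains: Python does not enforce annotations at runtime; they are only stored on the function and read by static checkers. A minimal, stdlib-only sketch, where `KerasTensor` and `Tensor` are hypothetical stand-in classes, not the real Keras or TensorFlow types:

```python
import typing


class KerasTensor:
    """Hypothetical stand-in for keras.KerasTensor."""


class Tensor:
    """Hypothetical stand-in for tf.Tensor (unrelated to KerasTensor)."""


def f(x: KerasTensor) -> str:
    # The annotation is metadata only; Python never checks it when f is called.
    return type(x).__name__


# Runs without error, although a static checker (e.g. PyCharm's inspection
# or mypy) would flag the argument type as incompatible.
result = f(Tensor())

# The annotations are simply stored on the function object.
hints = typing.get_type_hints(f)
```

Here `result` is `"Tensor"`: the call goes through, and only a static checker would ever report the mismatch.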
Hi @LarsKue, I think I understand now. You are probably seeing the reported error in your IDE. Could you please add a screenshot?
@SuryanarayanaY I added a screenshot to the original post.
@LarsKue, would the following approach work for your use case? It works if you don't use type annotations.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import ops
import tensorflow as tf

def f(x):
    return ops.floor(x)

y: tf.Tensor = tf.constant(1.37865)
print(f(y))
```
@divyashreepathihalli Not using type annotations does not seem like a fix to me. It just circumvents the issue by implicitly declaring the type of `x` as `Any`, which turns off the type checker for that argument.

Type annotations can vastly improve the usability of library code, particularly when multiple types are allowed for an argument, as is often the case for convenience or utility functions, or when there is no meaningful name for the argument, as is often the case for arguments like `input` or `x`.
I tried it with type annotations as well, and it works too. Here is an example code snippet:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import ops
import tensorflow as tf

def func(x: keras.KerasTensor):
    return ops.floor(x)

y: tf.Tensor = tf.constant(1.37865)
print(func(y))
```
@divyashreepathihalli I assume you are using Colab? In that case, this would be because Colab does not have any form of static type checking. Consider e.g. the following, where no type checking warning or error is raised:
@divyashreepathihalli I found an option to turn on type checking in Colab:

1. Go to Settings > Editor > Code Diagnostics.
2. Change it to "Syntax and type checking".

This yields an equivalent issue:
@LarsKue Adding this support in Keras would be very complicated. The recommendation from the team is to create your own custom type for your code.
Your custom type should include everything you want to be recognized as a tensor, which depends on your use case. For instance, it could be the union of `tf.Tensor`, JAX `jax.Array`, `torch.Tensor`, and `KerasTensor`, but it could also include more, like Variables (which can be passed to pretty much any function that takes tensors), NumPy arrays (same), TF `IndexedSlices`, etc.
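A minimal sketch of such a user-side type, assuming Keras 3 (which exports `keras.KerasTensor` and `keras.Variable`). The alias name `Tensor` and the helper `scale` are our own hypothetical choices, and the member list is just one possible selection following the comment above. The backend imports sit behind `typing.TYPE_CHECKING`, so only the static checker needs them; at runtime the alias falls back to `Any` and nothing is imported:

```python
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    import jax
    import keras
    import numpy as np
    import tensorflow as tf
    import torch

    # Hypothetical user-defined alias: everything we want treated as a tensor.
    Tensor = Union[
        tf.Tensor,
        tf.IndexedSlices,
        torch.Tensor,
        jax.Array,
        keras.KerasTensor,
        keras.Variable,
        np.ndarray,
    ]
else:
    # At runtime the alias is just Any, so no backend has to be installed.
    Tensor = Any


def scale(x: Tensor, factor: float) -> Tensor:
    # Hypothetical library function; `Tensor` only matters to the type checker.
    return x * factor
```

With this, a value annotated as `tf.Tensor` (like `y` in the original example) should be accepted by the checker when passed to `scale`. Checkers will still report unresolved imports for backends that are not installed, so you may need to trim the union to the backends you actually use.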
I understand that it is nontrivial, which is exactly why this is something that `keras` should implement, not the user.

If you need a simplification, I would argue that the three basic tensor types `tf.Tensor`, `torch.Tensor`, and `jax.Array` are sufficient for likely 99% of use cases. Any other functionality can be considered backend-dependent. In that case, the code snippet I already provided would be sufficient for the `Tensor` type.
What about the other types, like `Shape`? Surely this one is not hard, e.g. `tuple[int, ...] | torch.Size`. I am a little unfamiliar with `jax`, but the `Distribution` type also already exists in both `torch` and `tensorflow_probability`. There are likely many more shared types between the three backends that could benefit from a `keras` abstraction.
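For `Shape`, a minimal stdlib-only sketch. It relies on two facts: `torch.Size` is a subclass of `tuple`, so a tuple-based alias already accepts it at runtime, and Keras symbolic shapes may use `None` for dimensions that are unknown until runtime (typically the batch dimension). The names `Shape` and `num_elements` are hypothetical, not part of any library:

```python
from typing import Optional, Tuple

# Hypothetical backend-agnostic shape type: a tuple of dimension sizes,
# where None marks a dimension that is unknown until runtime (e.g. batch).
Shape = Tuple[Optional[int], ...]


def num_elements(shape: Shape) -> Optional[int]:
    """Return the total element count, or None if any dimension is unknown."""
    total = 1
    for dim in shape:
        if dim is None:
            # A symbolic dimension makes the total size unknown.
            return None
        total *= dim
    # An empty shape () describes a scalar, which has exactly one element.
    return total
```

For example, `num_elements((2, 3, 4))` gives `24`, while `num_elements((None, 3))` gives `None`.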
You may be interested in this `jaxtyping` issue about multi-backend tensor support: https://github.com/patrick-kidger/jaxtyping/issues/168
It would be nice if `keras` could provide backend-agnostic types that we can use to type-hint libraries built on top of `keras`. Some that come to mind are:

The following snippet works for the `Tensor` type, even from outside `keras`:

For reference, this is the issue I am referring to (exemplified in PyCharm 2023.3.4):