keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0
62.01k stars 19.48k forks

Backend-Agnostic Types #19230

Open LarsKue opened 8 months ago

LarsKue commented 8 months ago

It would be nice if keras could provide backend-agnostic types that we can use to type-hint libraries built on top of keras. Some that come to mind are:

The following snippet works for the Tensor type, even from outside keras:

# this is ugly, but:
# 1. it is recognized by static type checkers (not possible with if-else branching)
# 2. it does not leave the Tensor type possibly undefined (not possible without nesting)
try:
    import jax
    Tensor = jax.Array
except ModuleNotFoundError:
    try:
        import tensorflow as tf
        Tensor = tf.Tensor
    except ModuleNotFoundError:
        import torch
        Tensor = torch.Tensor

For reference, this is the issue I am referring to (exemplified in PyCharm 2023.3.4):

Screenshot from 2024-02-27 12-02-39

LarsKue commented 8 months ago

I should probably add: the already existing KerasTensor is not sufficient for type-hinting. Consider the following:

# lib.py
import keras

def f(x: keras.KerasTensor):
    pass

# user.py
import tensorflow as tf

y: tf.Tensor = tf.constant(1)

f(y)  # Expected type 'KerasTensor', got 'Tensor' instead

SuryanarayanaY commented 8 months ago

Hi @LarsKue,

A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions.

I tried to reproduce the issue in Colab but couldn't. Please refer to the attached gist.

Could you please check and confirm how to reproduce the issue?

LarsKue commented 8 months ago

I think you may have misunderstood my example. A static type checker will display the error, but running it works fine of course (the hints are just hints after all).

For reference, I am using PyCharm's built-in type checker, but I assume other type checkers will yield similar results.
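A minimal sketch of that difference, using a hypothetical function `double` that is not part of keras: the annotation is stored as metadata only, so a mismatched call still runs even though a static type checker would flag it.

```python
def double(x: int) -> int:
    # The annotation is recorded in __annotations__ but is not enforced at runtime.
    return x * 2


# A static type checker would report this call (str passed where int is expected),
# yet the interpreter executes it without raising.
result = double("ab")
print(result)

# The hints are still available for tools to inspect.
print(double.__annotations__)
```

The same applies to the `KerasTensor` example above: the call succeeds at runtime, and only the checker complains.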

SuryanarayanaY commented 8 months ago

Hi @LarsKue, I think I got it now. You might be getting the reported error in the IDE. Could you please add a screenshot?

LarsKue commented 8 months ago

@SuryanarayanaY I added a screenshot to the original post.

divyashreepathihalli commented 8 months ago

@LarsKue, would the following approach work for your use case? It would work if you don't use type annotations.

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf

def f(x):
  return ops.floor(x)

y: tf.Tensor = tf.constant(1.37865)
print(f(y))

LarsKue commented 8 months ago

@divyashreepathihalli Not using type annotations does not seem like a fix to me. It just circumvents the issue by implicitly saying that type(x) is Any, which turns off the type checker.

Type annotations can vastly improve the usability of library code, particularly when multiple types are allowed for an argument as is often the case for convenience or utility functions, or there exists no meaningful name for the argument as is often the case for arguments like input or x.

divyashreepathihalli commented 8 months ago

I did try it with a type annotation and it works too. Here is an example code snippet:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf

def func(x: keras.KerasTensor):
  return ops.floor(x)

y: tf.Tensor = tf.constant(1.37865)
print(func(y))


LarsKue commented 8 months ago

@divyashreepathihalli I assume you are using Colab? In that case, this would be because Colab does not have any form of static type checking. Consider e.g. the following, where no type checking warning or error is raised:

Screenshot from 2024-02-29 11-44-36

LarsKue commented 8 months ago

@divyashreepathihalli I found an option to turn on type checking in Colab:

Settings > Editor > Code Diagnostics

Change to "Syntax and type checking"

This yields an equivalent issue: Screenshot from 2024-02-29 11-59-32

divyashreepathihalli commented 8 months ago

@LarsKue Adding this support in Keras would be very complicated. The recommendation from the team is to create your own custom type for your code.

fchollet commented 8 months ago

Your custom type should include everything you want to be recognized as a tensor, which is dependent on your use case. For instance, it could be the union of tf.Tensor, JAX Array, torch Tensor, and KerasTensor -- but it could also include more, like Variables (which can be passed to pretty much any function that takes tensors), NumPy arrays (same), TF IndexedSlices, etc.
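One way such a custom type might be written is sketched below. The names `TensorLike` and `describe` are hypothetical, and the sketch assumes that every backend listed is installed in the environment where the type checker runs; otherwise the checker will likely report unresolved imports. The backend imports sit under `typing.TYPE_CHECKING`, so nothing is imported at runtime.

```python
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    import jax
    import keras
    import numpy as np
    import tensorflow as tf
    import torch

    # Hypothetical alias; extend or shrink the union to fit your use case.
    TensorLike = Union[
        tf.Tensor, jax.Array, torch.Tensor, keras.KerasTensor, np.ndarray
    ]
else:
    # At runtime the alias is just Any, so no backend needs to be importable.
    TensorLike = Any


def describe(x: TensorLike) -> str:
    """Return the runtime type name of whatever was passed in."""
    return type(x).__name__


print(describe([1.0, 2.0]))
```

Which members belong in the union remains a per-project decision, as described above.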

LarsKue commented 8 months ago

I understand that it is nontrivial, which is why this is exactly something that keras should implement, not the user.

If you need simplification, I would argue that the three basic tensor types tf.Tensor, torch.Tensor, jax.Array are sufficient for likely 99% of use cases. Any other functionality can be considered backend-dependent. In that case, the code snippet I already provided would be sufficient for the Tensor type.

What about the other types, like Shape? Surely this one is not hard, e.g. tuple[int, ...] | torch.Size. I am a little bit unfamiliar with jax, but the Distribution type also already exists in both torch and tensorflow_probability. There are likely many more shared types between the three backends that could do with a keras abstraction.
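As a rough sketch of what a shape alias could look like (the names `Shape` and `as_shape` are hypothetical, not an existing keras API): since `torch.Size` is a subclass of `tuple`, a tuple-based alias may already cover it, and a small helper can normalize other iterables of dimensions to a plain tuple. Allowing `None` for unknown dimensions is an assumption of this sketch.

```python
from typing import Iterable, Optional, Tuple

# Hypothetical alias: a tuple of dimensions, with None for an unknown size.
Shape = Tuple[Optional[int], ...]


def as_shape(dims: Iterable[Optional[int]]) -> Shape:
    """Convert any iterable of dimensions to a plain tuple of int or None."""
    return tuple(None if d is None else int(d) for d in dims)


print(as_shape([2, None, 3]))
```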

mthiboust commented 5 months ago

You may be interested in this jaxtyping issue about multi-backend tensor support: https://github.com/patrick-kidger/jaxtyping/issues/168