keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Inconsistent type handling between backends #919

Closed: lbortolotti closed this issue 10 months ago

lbortolotti commented 10 months ago

The following code works correctly with the jax backend:

import os

os.environ['KERAS_BACKEND'] = 'jax'
from keras_core import ops

r1 = ops.numpy.arange(10, dtype='float32')

c1 = r1 + 1j * r1

print(c1)

Switching to tensorflow backend:

import os

os.environ['KERAS_BACKEND'] = 'tensorflow'
from keras_core import ops

r1 = ops.numpy.arange(10, dtype='float32')

c1 = r1 + 1j * r1

print(c1)

Throws: TypeError: Cannot convert 1j to EagerTensor of dtype float

Potentially related, the following runs OK with the jax backend:

import os

os.environ['KERAS_BACKEND'] = 'jax'
from keras_core import ops

r1 = ops.numpy.arange(10).astype('float32')

While the tensorflow backend:

import os

os.environ['KERAS_BACKEND'] = 'tensorflow'
from keras_core import ops

r1 = ops.numpy.arange(10).astype('float32')

Throws:

AttributeError: EagerTensor object has no attribute 'astype'. 
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()
      . Did you mean: 'dtype'?

Running enable_numpy_behavior() does resolve the issue, but I don't like the idea of having to modify the default TF behaviour to make it work. I assume that, if the developer goes via the keras_core.ops API, the same code should work with both backends as-is?

Package versions:

jax==0.4.14
jaxlib==0.4.14
keras-core==0.1.5
numpy==1.24.3
tensorflow==2.13.0

Thanks,

Luca

fchollet commented 10 months ago

As in the other linked issue, this is caused by the use of tensor methods. Your tensors are backend-native tensors, so which methods are available, and what they do, depends on the backend framework.

To get a fully standardized API surface, don't use tensor methods; use keras.ops instead, e.g. ops.add or ops.cast.

Do note that using Python operators such as + or @ will call into the object's methods. It will not call into keras.ops.
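The operator point can be seen in plain Python, without Keras at all. The sketch below uses a hypothetical FakeTensor class and a hypothetical library_add function (neither is part of Keras or any backend) to show that an expression like r1 + 1j * r1 is resolved entirely by the object's own methods, and that a library-level function is only involved when it is called explicitly.

```python
calls = []  # records which hook handled each operation


class FakeTensor:
    """Hypothetical stand-in for a backend-native tensor (not a Keras class)."""

    def __init__(self, values):
        self.values = list(values)

    def __add__(self, other):
        # Invoked for `tensor + tensor`.
        calls.append("FakeTensor.__add__")
        return FakeTensor(a + b for a, b in zip(self.values, other.values))

    def __rmul__(self, scalar):
        # Invoked for `scalar * tensor` once the scalar's own __mul__
        # has returned NotImplemented for the unknown operand type.
        calls.append("FakeTensor.__rmul__")
        return FakeTensor(scalar * v for v in self.values)


def library_add(x, y):
    """Hypothetical stand-in for a library-level function such as ops.add."""
    calls.append("library_add")
    return FakeTensor(a + b for a, b in zip(x.values, y.values))


r1 = FakeTensor([0.0, 1.0, 2.0])

# Same shape as the expression in the report: only object methods run.
c1 = r1 + 1j * r1
print(calls)      # ['FakeTensor.__rmul__', 'FakeTensor.__add__']
print(c1.values)  # [0j, (1+1j), (2+2j)]

# The library function is reached only through an explicit call.
c2 = library_add(r1, r1)
print(calls[-1])  # library_add
```

With real backend tensors, the methods that play the role of __add__ and __rmul__ belong to the backend's tensor class, which is why the same operator expression can succeed on one backend and raise on another.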