keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Matmul - tensorflow does not broadcast/expand dimensions correctly #920

Closed lbortolotti closed 10 months ago

lbortolotti commented 10 months ago

The following code works correctly with the jax backend:

import os

os.environ['KERAS_BACKEND'] = 'jax'
from keras_core import ops

A = ops.numpy.arange(10).reshape(2, 5)
b = ops.numpy.arange(5)

out = A @ b

print(out.shape)

With the tensorflow backend, the same code fails with:

AttributeError: EagerTensor object has no attribute 'reshape'. 
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()
      . Did you mean: '_shape'?

This is similar to the error identified in the second part of issue https://github.com/keras-team/keras-core/issues/919.

Package versions:

jax==0.4.14
jaxlib==0.4.14
keras-core==0.1.5
numpy==1.24.3
tensorflow==2.13.0

fchollet commented 10 months ago

Thanks for the report.

First of all, do not use tensor methods like x.reshape(): your x is always a backend-native tensor, so which methods it has depends on the backend. Instead, use ops.reshape(x, shape). Alternatively, if you want to enable tensor methods on TensorFlow tensors, you can do so via this utility, but I would generally not recommend doing that. It is much cleaner to stick to ops functions.

Second, indeed it seems that the broadcasting behavior of numpy/jax differs from TensorFlow. We'll look into it.

lbortolotti commented 10 months ago

Hi @fchollet,

Yes, sorry, you've hit the nail on the head. I somehow managed to omit the core issue from the original report...

For completeness, with the corrected code:

import os

os.environ['KERAS_BACKEND'] = 'tensorflow'
from keras_core import ops

A = ops.reshape(ops.numpy.arange(10), (2, 5))

b = ops.numpy.arange(5)

out = A @ b

print(out.shape)

The error becomes:

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:CPU:0}} In[0] and In[1] has different ndims: [2,5] vs. [5] [Op:MatMul] name: 

With the jax backend, by contrast, broadcasting is applied correctly.
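For reference, here is a minimal sketch of the NumPy matmul convention that the jax result matches. It uses plain NumPy rather than keras_core, and the name `explicit` is purely illustrative. When the second operand is 1-D, it is treated as a column for the product and the added axis is dropped afterwards, which is why shapes (2, 5) and (5,) are accepted:

```python
import numpy as np

A = np.arange(10).reshape(2, 5)  # shape (2, 5)
b = np.arange(5)                 # shape (5,)

# NumPy convention: a 1-D second operand is promoted to shape (5, 1)
# for the product, and the appended axis is removed from the result.
out = np.matmul(A, b)

# The same computation with the promotion written out by hand.
explicit = np.squeeze(np.matmul(A, b[:, None]), axis=-1)

print(out.shape)     # (2,)
print(out.tolist())  # [30, 80]
```

The TensorFlow error above reports that the two inputs have different ndims ([2,5] vs. [5]), so the raw MatMul op reached through `@` does not appear to apply this promotion.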

fchollet commented 10 months ago

An update on this: the error occurs because the @ operator calls into the backend tensor's own methods rather than going through keras_core. Again, you should not use tensor methods.

If you do out = ops.matmul(A, b) instead, broadcasting works the same in all backends and your code snippet runs.