Thanks for the report.
First of all, do not use tensor methods like `x.reshape()`, since your `x` is always a backend-native tensor. Instead, use `ops.reshape(x, shape)`. Alternatively, if you want to enable tensor methods on TensorFlow tensors, you can do so via this utility, but I would generally not recommend doing that. Much cleaner to stick to `ops` functions.
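For illustration, a minimal sketch of that pattern (the values and shapes here are arbitrary placeholders, not taken from the original report):

```python
from keras_core import ops

x = ops.arange(10)            # backend-native tensor (tf.Tensor, jax.Array, or np.ndarray)
y = ops.reshape(x, (2, 5))    # portable across backends
# y = x.reshape((2, 5))       # avoid: relies on methods of the backend-native tensor type
print(y.shape)                # (2, 5)
```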
Second, indeed it seems that the broadcasting behavior of numpy/jax differs from TensorFlow. We'll look into it.
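To make that backend difference concrete, here is a small sketch outside of keras_core comparing raw NumPy and TensorFlow (float32 is used only to keep the MatMul kernel happy):

```python
import numpy as np
import tensorflow as tf

A = np.arange(10, dtype=np.float32).reshape(2, 5)
b = np.arange(5, dtype=np.float32)

# NumPy (and jax.numpy) promote the rank-1 `b` to a column and squeeze the result:
print(np.matmul(A, b).shape)  # (2,)

# tf.matmul expects both operands to have rank >= 2, so the same call fails:
try:
    tf.matmul(tf.constant(A), tf.constant(b))
except (tf.errors.InvalidArgumentError, ValueError) as e:
    print("tf.matmul rejected the rank-1 operand:", type(e).__name__)
```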
Hi @fchollet,
Yes, sorry, you've hit the nail on the head; that is the core issue, and I somehow managed to omit it from the original report...
For completeness, with the corrected code:
```python
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'

from keras_core import ops

A = ops.reshape(ops.numpy.arange(10), (2, 5))
b = ops.numpy.arange(5)
out = A @ b
print(out.shape)
```
The error becomes:
```
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:CPU:0}} In[0] and In[1] has different ndims: [2,5] vs. [5] [Op:MatMul] name:
```
While with the jax backend, broadcasting is applied correctly.
An update on this: it happens because when you use `@` you are calling into tensor methods, and (again) you should not use tensor methods. If you use `out = ops.matmul(A, b)` instead, broadcasting works the same in all backends and your code snippet runs.
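For reference, the corrected snippet would then look something like this (same arrays as above, with only the `@` call replaced):

```python
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'  # must be set before importing keras_core

from keras_core import ops

A = ops.reshape(ops.numpy.arange(10), (2, 5))
b = ops.numpy.arange(5)
out = ops.matmul(A, b)  # backend-agnostic; broadcasts the rank-1 `b` on every backend
print(out.shape)        # should print (2,)
```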
The following code works correctly with the jax backend:
While with the tensorflow backend:
This is similar to the error identified in the second part of issue https://github.com/keras-team/keras-core/issues/919.
Package versions: