Closed — jacoverster closed this issue 1 year ago
Running the example with the JAX backend:
os.environ["KERAS_BACKEND"] = "jax"
Produces this keras.ops.one_hot-related error. I'm not familiar enough with JAX to comment, but it might be of interest to the keras_core team.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-7e84aa32f170> in <cell line: 30>()
28 split=['train', 'test'],
29 )
---> 30 train_dataset = train_dataset.batch(BATCH_SIZE).map(
31 lambda x, y: preprocess_data(x, y, augment=True),
32 num_parallel_calls=tf.data.AUTOTUNE).prefetch(
34 frames
[... skipping hidden 8 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py in _canonicalize_dtype(x64_enabled, allow_opaque_dtype, dtype)
146 dtype_ = np.dtype(dtype)
147 except TypeError as e:
--> 148 raise TypeError(f'dtype {dtype!r} not understood') from e
149
150 if x64_enabled:
TypeError: in user code:
File "<ipython-input-3-7e84aa32f170>", line 31, in None *
lambda x, y: preprocess_data(x, y, augment=True)
File "<ipython-input-3-7e84aa32f170>", line 20, in preprocess_data *
labels = keras.ops.one_hot(labels, NUM_CLASSES)
File "/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/nn.py", line 965, in one_hot *
x, num_classes, axis=axis, dtype=dtype or backend.floatx()
File "/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/nn.py", line 391, in one_hot *
return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/nn/functions.py", line 466, in one_hot *
return _one_hot(x, num_classes, dtype=dtype, axis=axis)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback **
return fun(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 491, in common_infer_params
avals.append(shaped_abstractify(a))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 565, in shaped_abstractify
return _shaped_abstractify_slow(x)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 552, in _shaped_abstractify_slow
dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 163, in canonicalize_dtype
return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 148, in _canonicalize_dtype
raise TypeError(f'dtype {dtype!r} not understood') from e
TypeError: dtype tf.int64 not understood
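The root cause is visible in the last frame: JAX canonicalizes dtypes through `np.dtype(...)`, and NumPy does not recognize TensorFlow dtype objects such as `tf.int64`. A minimal sketch of that failure mode, using a hypothetical stand-in class so it runs without a TensorFlow install:

```python
import numpy as np

class FakeTfDtype:
    """Hypothetical stand-in for a TensorFlow dtype object like tf.int64."""
    def __repr__(self):
        return "tf.int64"

# JAX's _canonicalize_dtype calls np.dtype(dtype); any object NumPy
# cannot interpret as a dtype raises TypeError, which JAX re-raises as
# the "dtype tf.int64 not understood" message in the traceback above.
try:
    np.dtype(FakeTfDtype())
except TypeError:
    print("dtype tf.int64 not understood")
```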
@jacoverster thanks for the report!
The Jax issue is likely because you're using a Keras Core op in a TF-only preprocessing workflow:
def preprocess_data(images, labels, augment=False):
    # labels = tf.one_hot(labels, NUM_CLASSES)  # <--- CHANGE ---
    labels = keras.ops.one_hot(labels, NUM_CLASSES)
    inputs = {"images": images, "labels": labels}
    outputs = augmenter(inputs) if augment else inputs
    return outputs['images'], outputs['labels']
The preprocessing pipeline here is running in tf.data, so it should use pure TF ops internally. I suspect if you switch back to tf.one_hot this error will go away.
With respect to the error with the TF backend -- this is less obvious to me. One thing I'd try is the tf.one_hot fix. I will take a deeper look at this as soon as I have a chance -- thanks again for the thorough report!
@ianstenbit sadly, using tf.one_hot does not fix the issue. I get this error (gist):
NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
My guess is that KerasCV with the JAX backend creates JAX augmentation layers, which cannot consume a tf.Tensor. Combining tf.data with other backends could be tricky unless you juggle both Keras Core and tf.keras.
I had zero problems using the TensorFlow backend (gist).
Thanks for the quick response. The JAX backend works with the changes, but model.fit() still fails. I tested it using the TensorFlow backend with your gist and got this error:
ValueError: Exception encountered when calling FusedMBConvBlock.call().
tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Arguments received by FusedMBConvBlock.call():
• inputs=tf.Tensor(shape=(None, 75, 75, 32), dtype=float32)
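This "singleton tf.Variables" message generally means a layer is creating a tf.Variable inside its traced call rather than once at construction time. A hedged sketch of the general pattern (the Block class and its scale variable are illustrative, not the actual FusedMBConvBlock code):

```python
import tensorflow as tf

# Creating a tf.Variable inside a tf.function body means a new variable
# per trace, which triggers the "singleton tf.Variables" error when the
# function retraces. The fix is to create state once, outside the trace.
class Block(tf.Module):
    def __init__(self):
        super().__init__()
        # Variable created once, at construction time ...
        self.scale = tf.Variable(1.0)

    @tf.function
    def __call__(self, x):
        # ... not here, inside the traced call.
        return x * self.scale

block = Block()
print(block(tf.constant(2.0)).numpy())
```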
@jacoverster this looks like an issue with the MBConvBlock's dropout layer. I will send a fix shortly.
This will be fixed by https://github.com/keras-team/keras-cv/pull/1951
Great work, thanks guys.
I made small changes to the quickstart example for compatibility with the keras_core API (marked below).
Happy to contribute to a fix with some guidance.
The error is as follows:
Code to reproduce: