"""Minimal reproduction: keras_core ``Adam(amsgrad=True)`` under the JAX backend.

Builds a small stack of Dense layers, trains it for one ``fit`` call on
random data, then saves it. In the reported environment ``model.fit``
raises a ``TypeError`` from ``jnp.maximum`` on the optimizer's
``velocity_hat`` variable.
"""
import os

# Select the backend before keras_core is imported; the backend choice is
# taken from this environment variable.
os.environ['KERAS_BACKEND'] = 'jax'

import keras_core as keras  # noqa: E402  (must follow the backend selection)
import numpy as np  # noqa: E402

# Random inputs and targets, each of shape (1024, 100, 1). X is drawn first,
# then Y, so both consume the global NumPy RNG in that order.
X = np.random.random((1024, 100, 1))
Y = np.random.random((1024, 100, 1))

# Two ReLU hidden layers followed by a single-unit ReLU output layer.
_hidden_units = (128, 64)
_layers = [keras.layers.Dense(units, activation='relu') for units in _hidden_units]
_layers.append(keras.layers.Dense(1, activation='relu'))
model = keras.Sequential(_layers)
model.summary()

# amsgrad=True is the setting under test: per the report, this is what makes
# training fail on the JAX backend.
_optimizer = keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=_optimizer, loss=keras.losses.mean_squared_error)
model.fit(X, Y)
model.save('repro_1.keras')
Running this script throws the following error during `model.fit`:
```
Traceback (most recent call last):
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\repro_1.py", line 18, in <module>
    model.fit(X, Y)
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\venv\lib\site-packages\keras_core\src\utils\traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\venv\lib\site-packages\keras_core\src\backend\jax\numpy.py", line 375, in maximum
    return jnp.maximum(x1, x2)
TypeError: Argument '<KerasVariable shape=(1, 128), dtype=float32, path=adam/sequential_dense_kernel_velocity_hat>' of type <class 'keras_core.src.backend.jax.core.Variable'> is not a valid JAX type.
```
Switching to the TensorFlow backend, or setting `amsgrad=False`, works around the issue.
The following code:
Throws the following error:
Switching to tensorflow backend / setting amsgrad=False works around the issue.
Package versions: