keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Adam with amsgrad=True + JAX backend is broken #918

Closed: lbortolotti closed this issue 10 months ago

lbortolotti commented 10 months ago

The following code:

import os
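# The backend must be selected before keras_core is imported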
os.environ['KERAS_BACKEND'] = 'jax'
import keras_core as keras
import numpy as np

X = np.random.random((1024, 100, 1))
Y = np.random.random((1024, 100, 1))

model = keras.Sequential([keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.summary()

model.compile(optimizer=keras.optimizers.Adam(amsgrad=True), loss=keras.losses.mean_squared_error)

model.fit(X, Y)

model.save('repro_1.keras')

Throws the following error:

Traceback (most recent call last):
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\repro_1.py", line 18, in <module>
    model.fit(X, Y)
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\venv\lib\site-packages\keras_core\src\utils\traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\lbortolotti\PerformanceMethods\keras_core_experiments\venv\lib\site-packages\keras_core\src\backend\jax\numpy.py", line 375, in maximum
    return jnp.maximum(x1, x2)
TypeError: Argument '<KerasVariable shape=(1, 128), dtype=float32, path=adam/sequential_dense_kernel_velocity_hat>' of type <class 'keras_core.src.backend.jax.core.Variable'> is not a valid JAX type.
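
The TypeError suggests that the AMSGrad code path hands the velocity_hat KerasVariable object itself to jnp.maximum rather than a JAX array. A minimal sketch of the same mismatch, assuming keras_core exposes keras.Variable and its .value attribute (those names are illustrative, not taken from the traceback):

import os
os.environ['KERAS_BACKEND'] = 'jax'
import jax.numpy as jnp
import keras_core as keras

v_hat = keras.Variable(jnp.zeros((1, 128)))  # stands in for the optimizer's velocity_hat slot
x = jnp.ones((1, 128))

# jnp.maximum(x, v_hat) reproduces the "not a valid JAX type" error on keras-core 0.1.5,
# while passing the variable's underlying array works:
out = jnp.maximum(x, v_hat.value)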

Switching to the TensorFlow backend or setting amsgrad=False works around the issue.
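
For reference, either workaround is a one-line change to the script above (a sketch of the two options just mentioned, no new APIs assumed):

# Option 1: keep the JAX backend but disable AMSGrad
model.compile(optimizer=keras.optimizers.Adam(amsgrad=False),
              loss=keras.losses.mean_squared_error)

# Option 2: keep amsgrad=True but switch backends; this must be set
# before keras_core is imported
os.environ['KERAS_BACKEND'] = 'tensorflow'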

Package versions:

jax==0.4.14
jaxlib==0.4.14
keras-core==0.1.5
numpy==1.24.3
tensorflow==2.13.0
fchollet commented 10 months ago

Thanks for the report. This is now fixed at HEAD.
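
Until the fix lands in a PyPI release, installing keras-core from the GitHub main branch (for example, pip install git+https://github.com/keras-team/keras-core.git) should pick it up.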