keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Issue with Backpropagation on keras.ops.while_loop (in Jax) #18957

Open Jacobiano opened 9 months ago

Jacobiano commented 9 months ago

I have implemented a layer that computes morphological reconstruction. When the layer is not on the backpropagation path, it works fine with all three backends. When gradients have to be computed through it, it works with the TensorFlow and PyTorch backends, but not with JAX.
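For readers unfamiliar with the operation, here is a minimal NumPy sketch of reconstruction by dilation on a single 2D image: dilate the marker with a 3x3 window, clip it by the mask, and repeat until nothing changes. The helper names `dilate3x3` and `reconstruct_by_dilation` are hypothetical and not part of the reported code; this only illustrates the idea, it is not the layer itself.

```python
import numpy as np

def dilate3x3(img):
    """3x3 max filter with 'same' output size (borders padded with -inf)."""
    img = np.asarray(img, dtype=float)
    h, w = img.shape
    padded = np.pad(img, 1, mode="constant", constant_values=-np.inf)
    # Collect the nine shifted views of the padded image and take their max.
    shifted = [padded[i:i + h, j:j + w] for i in range(3) for j in range(3)]
    return np.max(shifted, axis=0)

def reconstruct_by_dilation(marker, mask, max_steps=None):
    """Iterate min(dilate(marker), mask) until stable or max_steps is reached."""
    prev = np.asarray(marker, dtype=float)
    step = 0
    while max_steps is None or step < max_steps:
        new = np.minimum(dilate3x3(prev), mask)
        if np.array_equal(new, prev):
            break  # converged: another step would change nothing
        prev = new
        step += 1
    return prev

# Two blobs in the mask; the marker only touches the left one.
mask = np.array([[0, 0, 0, 0, 0],
                 [0, 1, 1, 0, 2],
                 [0, 0, 0, 0, 0]], dtype=float)
marker = np.zeros_like(mask)
marker[1, 1] = 1.0
result = reconstruct_by_dilation(marker, mask)
```

Only the blob connected to the marker is recovered in `result`; the isolated value on the right stays at zero. The number of iterations depends on the data, which is why the layer below uses a while loop.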


import os
#os.environ["KERAS_BACKEND"] = "jax"
#os.environ["KERAS_BACKEND"] = "torch"
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
print(keras.__version__)  #3.0.1

from keras.layers import Layer

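def condition_equal(last, new, mask):
    # Keep iterating while the last dilation step still changed the result.
    return keras.ops.logical_not(keras.ops.all(keras.ops.equal(last, new)))

def update_dilation(last, new, mask):
    # Loop state is (previous result, current result, mask).
    return [new, geodesic_dilation_step([new, mask]), mask]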

def geodesic_dilation_step(X):
    """
    One step of reconstruction by dilation.
    :param X: pair of tensors; X[0] is the marker and X[1] is the mask
    :Example:
    >>> Lambda(geodesic_dilation_step, name="reconstruction")([marker, mask])
    """
    # Perform a geodesic dilation with X[0] as marker and X[1] as mask
    return keras.ops.minimum(keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(1, 1), padding='same')(X[0]), X[1])

@keras.saving.register_keras_serializable()
class GeodesicalReconstructionLayer(Layer):
    def __init__(self,steps=None):
        super().__init__()
        self.steps = steps

    def call(self, inputs):
        rec = inputs[0]
        rec = geodesic_dilation_step([rec, inputs[1]])
        _, rec,_=keras.ops.while_loop(condition_equal, update_dilation, [inputs[0], rec, inputs[1]], maximum_iterations=self.steps)
        return rec

BATCH_SIZE = 32
EPOCHS = 2

num_classes = 10
input_shape = (28, 28, 1)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

NFILTERS=12
STEPS=10

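## Model 1: no trainable weights come before the reconstruction layer, so no gradient
## flows through keras.ops.while_loop. This works with the TF, JAX, and PyTorch backends.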
x0=keras.layers.Input(shape=input_shape)
xrec=GeodesicalReconstructionLayer(steps=10)([x0-.1,x0])
xf=keras.layers.Conv2D(NFILTERS,kernel_size=(3,3),activation='relu')(xrec)
xf=keras.layers.Flatten()(xf)
xf=keras.layers.Dense(num_classes,activation='softmax')(xf)

model=keras.Model(x0,xf)
model.compile(loss="SparseCategoricalCrossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_data=(x_test,y_test))

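## Model 2: a trainable Conv2D comes before the reconstruction layer, so backprop
## has to go through keras.ops.while_loop. This fails with the JAX backend.
xini=keras.layers.Input(shape=input_shape)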
xconv=keras.layers.Conv2D(1,kernel_size=(3,3),kernel_initializer='ones',activation='relu')(xini)
xrec=GeodesicalReconstructionLayer(steps=10)([xconv-.1,xconv])
xf=keras.layers.Conv2D(NFILTERS,kernel_size=(3,3),activation='relu')(xrec)
xf=keras.layers.Flatten()(xf)
xf=keras.layers.Dense(num_classes,activation='softmax')(xf)

modelwithBP=keras.Model(xini,xf)
modelwithBP.compile(loss="SparseCategoricalCrossentropy", optimizer="adam", metrics=["accuracy"])
modelwithBP.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_data=(x_test,y_test))

With the JAX backend, fitting modelwithBP fails with:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
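The same error can be reproduced without Keras. The sketch below is a minimal, hypothetical example (the function names `double_until` and `grad_error_message` are made up for illustration): the forward pass of a `jax.lax.while_loop` with a data-dependent trip count runs fine, but asking for its gradient raises the `ValueError` above.

```python
import jax
from jax import lax

def double_until(x, limit=100.0):
    # The number of iterations depends on the value of x, so the trip
    # count is only known at run time.
    return lax.while_loop(lambda v: v < limit, lambda v: v * 2.0, x)

def grad_error_message():
    """Return the error raised by reverse-mode autodiff, or None if it succeeds."""
    try:
        jax.grad(double_until)(1.0)
    except ValueError as err:
        return str(err)
    return None

forward = double_until(1.0)     # forward evaluation works
message = grad_error_message()  # reverse-mode differentiation does not
```

This suggests the failure comes from how JAX differentiates `while_loop`, not from anything specific to the reconstruction layer.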

The code is available at https://colab.research.google.com/drive/1bWQO6TAQeN_-a0y6iY7b_jlnGzv-XRdv?usp=sharing

sachinprasadhs commented 9 months ago

Hi,

Thanks for reporting the issue. It seems to be specific to JAX. Have you tried the solution suggested in the error message?
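As an illustration of the second suggestion in the error message, here is a small sketch of `jax.lax.fori_loop` with static start/stop values. When both bounds are Python integers the trip count is known at trace time, and reverse-mode differentiation is supported. The function `repeated_scale` is a hypothetical toy example, not code from the report.

```python
import jax
import jax.numpy as jnp
from jax import lax

def repeated_scale(x, w, steps=5):
    # `steps` is a plain Python int, so the loop bounds are static.
    return lax.fori_loop(0, steps, lambda i, v: v * w, x)

# d/dw (x * w**5) at x = 1, w = 2
grad_w = jax.grad(lambda w: repeated_scale(jnp.float32(1.0), w))(jnp.float32(2.0))
```

The trade-off is that the loop always runs the full `steps` iterations; there is no early exit on convergence.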

Jacobiano commented 9 months ago

I think it can be implemented in JAX in another way. I reported this issue because, given the goal of supporting multiple backends, I found it strange that the JAX backend cannot handle the while_loop function properly.

AakashKumarNain commented 9 months ago

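I think neither JAX nor Keras is at fault here. The while_loop works as expected, but it has limitations: JAX cannot apply reverse-mode differentiation to a loop whose trip count is only known at run time. For JAX, it is generally advisable to use scan, which runs a fixed number of steps and supports reverse-mode differentiation.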
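A possible scan-based version of the reconstruction, written directly in JAX, might look like the sketch below. It assumes NHWC inputs and a fixed number of steps (10 here, chosen only for illustration); `dilate3x3` and `reconstruct_scan` are hypothetical names, and the 3x3 max pooling is done with `jax.lax.reduce_window` instead of a Keras layer.

```python
import jax
import jax.numpy as jnp
from jax import lax

def dilate3x3(x):
    """3x3 max filter over a (batch, H, W, C) tensor with 'SAME' padding."""
    return lax.reduce_window(x, -jnp.inf, lax.max, (1, 3, 3, 1), (1, 1, 1, 1), "SAME")

def reconstruct_scan(marker, mask, steps=10):
    """Apply `steps` geodesic dilation steps using lax.scan."""
    def body(rec, _):
        # One geodesic dilation step: dilate, then clip by the mask.
        return jnp.minimum(dilate3x3(rec), mask), None
    rec, _ = lax.scan(body, marker, None, length=steps)
    return rec

def loss(x):
    # Same marker/mask construction as in the report: marker = image - 0.1.
    return reconstruct_scan(x - 0.1, x).sum()

image = jnp.array([[0, 0, 0, 0, 0],
                   [0, 1, 1, 0, 2],
                   [0, 0, 0, 0, 0]], dtype=jnp.float32)[None, :, :, None]
grads = jax.grad(loss)(image)  # reverse-mode differentiation works here
```

Because a converged result is a fixed point of the dilation step, running extra iterations after convergence does not change the output. The cost is that all `steps` iterations are always executed, and the result only matches a run-to-convergence loop when `steps` is large enough for the input.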