keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Unusual behavior of `predict` for JAX backend #19639

Closed AakashKumarNain closed 2 weeks ago

AakashKumarNain commented 2 weeks ago

Here is a simple jitted function doing something useful:

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=(1, 2))
def get_array_slice_and_repeat(x, pos, num_repeats):
    x_slice = x[:, :pos, :, :]
    x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
    return x_slice

# This works
x = jnp.asarray(np.random.rand(1, 5, 4, 3))
y = get_array_slice_and_repeat(x, 2, 3)

Let's say I have a keras model that needs to use that operation. Keeping this example very simple:

import keras

class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__()
        self.name = name if name is not None else "custom_model"

    def call(self, inputs, training=True):
        x, pos = inputs
        x = get_array_slice_and_repeat(x, pos.item(), 2)
        return x

x = np.random.rand(1, 5, 4, 3)
pos = jnp.array([2])
model = CustomModel()
out = model.predict((x, pos))

This fails with the following error:

---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
Cell In[34], line 1
----> 1 out = model.predict((x, pos))

File /keras/src/utils/traceback_utils.py:113, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    110 @wraps(fn)
    111 def error_handler(*args, **kwargs):
    112     if not is_traceback_filtering_enabled():
--> 113         return fn(*args, **kwargs)
    115     filtered_tb = None
    116     try:

File /keras/src/backend/jax/trainer.py:675, in JAXTrainer.predict(self, x, batch_size, verbose, steps, callbacks)
    673 else:
    674     state = (state[0], non_trainable_variables)
--> 675 batch_outputs, non_trainable_variables = self.predict_function(
    676     state, x
    677 )
    678 outputs = append_to_outputs(batch_outputs, outputs)
    679 callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})

    [... skipping hidden 12 frame]

File /keras/src/backend/jax/trainer.py:322, in JAXTrainer.make_predict_function.<locals>.compiled_predict_step(state, data)
    320 @jax.jit
    321 def compiled_predict_step(state, data):
--> 322     return predict_step(state, data)

File /keras/src/backend/jax/trainer.py:296, in JAXTrainer.make_predict_function.<locals>.one_predict_step(state, data)
    294 def one_predict_step(state, data):
    295     data = data[0]
--> 296     return self.predict_step(state, data)

File /keras/src/backend/jax/trainer.py:210, in JAXTrainer.predict_step(self, state, data)
    207     kwargs["training"] = False
    209 x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
--> 210 outputs, non_trainable_variables = self.stateless_call(
    211     trainable_variables, non_trainable_variables, x, **kwargs
    212 )
    213 (
    214     _,
    215     non_trainable_variables,
   (...)
    222     metrics_variables=None,
    223 )
    224 return outputs, non_trainable_variables

File /keras/src/utils/traceback_utils.py:113, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    110 @wraps(fn)
    111 def error_handler(*args, **kwargs):
    112     if not is_traceback_filtering_enabled():
--> 113         return fn(*args, **kwargs)
    115     filtered_tb = None
    116     try:

File /keras/src/layers/layer.py:975, in Layer.stateless_call(self, trainable_variables, non_trainable_variables, return_losses, *args, **kwargs)
    973     outputs = self.quantized_call(*args, **kwargs)
    974 else:
--> 975     outputs = self.call(*args, **kwargs)
    976 if return_losses:
    977     losses = self.losses

Cell In[32], line 8, in CustomModel.call(self, inputs, training)
      6 def call(self, inputs, training=True):
      7     x, pos = inputs
----> 8     x = get_array_slice_and_repeat(x, pos.item(), 2)
      9     return x

    [... skipping hidden 1 frame]

File /jax/_src/numpy/array_methods.py:76, in _item(a, *args)
     74 def _item(a: Array, *args) -> bool | int | float | complex:
     75   """Copy an element of an array to a standard Python scalar and return it."""
---> 76   arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array")
     77   if dtypes.issubdtype(a.dtype, dtypes.extended):
     78     raise TypeError(f"No Python scalar type for {a.dtype=}")

File /jax/_src/core.py:1509, in concrete_or_error(force, val, context)
   1507     return force(val.aval.val)
   1508   else:
-> 1509     raise ConcretizationTypeError(val, context)
   1510 else:
   1511   return force(val)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[1].
This occurred in the item() method of jax.Array
The error occurred while tracing the function compiled_predict_step at /keras/src/backend/jax/trainer.py:320 for jit. This concrete value was not available in Python because it depends on the value of the argument data[0][0][1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
AakashKumarNain commented 2 weeks ago

I know that the whole call is jitted when we call predict, but this seems to put a restriction on cases where we want some functionality inside call that needs to be jitted explicitly with static arguments.

fchollet commented 2 weeks ago

The documentation for x.item() in JAX says:

Copy an element of an array to a standard Python scalar and return it.

It is not possible to access the eager value of an array inside a jitted function, which is why calling item() inside call() fails with predict() -- as you note, predict() is itself jitted.
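A minimal standalone sketch of the same failure outside Keras (plain JAX only):

import jax
import jax.numpy as jnp

@jax.jit
def get_pos(pos):
    # inside jit, pos is an abstract tracer, so .item() raises ConcretizationTypeError
    return pos.item()

get_pos(jnp.array([2]))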

You have two possible solutions:

AakashKumarNain commented 2 weeks ago

@fchollet Thanks for the quick response.

Yes, I am aware of item(). Here is the problem though: if I use predict, then all inputs need to have the same dimensionality. If I don't use item() or int(pos[0]) or int(ops.get_item(pos, -1)), there is no way to retrieve the position as an integer, and hence no way to jit the get_array_slice_and_repeat(...) method, because the value of pos passed in would be an ArrayImpl object, which isn't hashable.

Unfortunately, running eagerly is very slow for my use case.

Note: I also don't see any way to use ops.slice(...) in this case, as it will complain about the Tracer in the shape values, i.e. ops.slice(x, [0, 0, 0, 0], [x_shape[0], pos + 1, x_shape[2], x_shape[3]]).

fchollet commented 2 weeks ago

You can try asking the question to the JAX team to see what they recommend in this case (basically, how to use your get_array_slice_and_repeat inside a jitted function). There might be a refactoring that removes the need for the eager item() call. If not, then it's just a hard limitation of XLA.

hertschuh commented 2 weeks ago

@AakashKumarNain ,

Is the pos argument static? Or is it dynamic per sample? Could it be a configuration of the model instead of an input?

Thanks, Fabien

AakashKumarNain commented 2 weeks ago

You can try asking the question to the JAX team to see what they recommend in this case (basically, how to use your get_array_slice_and_repeat inside a jit function)

There may be a way to convert a tracer to a concrete value using pure_callback(...), but I am not sure. My worry is on the other side though: does this mean we will never be able to pass a singleton array to predict where it is required for some operation within the model, e.g. retrieving the KV cache values for each attention layer?

AakashKumarNain commented 2 weeks ago

@AakashKumarNain ,

Is the pos argument static? Or is it dynamic per sample? Could it be a configuration of the model instead of an input?

Thanks, Fabien

@hertschuh this can't be a configuration on the model side. It's per sample

hertschuh commented 2 weeks ago

@hertschuh this can't be a configuration on the model side. It's per sample

@AakashKumarNain the reason jax.jit works with get_array_slice_and_repeat the way you have it is that you marked the pos argument as static with static_argnums. That means it's not dynamic, it's a constant: the function is retraced and recompiled for each new value of pos. This won't work with values that are dynamic per sample.
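For illustration, a tiny standalone sketch of that retracing behaviour (hypothetical slice_to function):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def slice_to(x, pos):
    print("tracing for pos =", pos)  # runs only while tracing
    return x[:, :pos]

x = jnp.ones((1, 5))
slice_to(x, 2)  # traces and compiles for pos=2
slice_to(x, 2)  # cached, no retrace
slice_to(x, 3)  # traces and compiles again for pos=3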

I looked into jax.lax.dynamic_slice, but it won't work in this case because the size of the slices is dynamic (one of the sizes is pos). Basically, there is no way around this: JAX doesn't support dynamic output shapes. The output of the model has to have a fixed shape for a given fixed shape of the input, and this is not the case in your test model.

The only way to make it work is to have some operation afterwards that brings it back to a fixed size (like having a reduction or adding padding). And you may need to use masking instead of slicing. The discussion on this issue seems relevant: https://github.com/google/jax/issues/1007
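For illustration, a rough sketch of the masking idea (hypothetical mask_and_repeat helper; the output keeps a fixed shape and positions at or beyond pos are zeroed, so whether this helps depends on what consumes the output):

import jax
import jax.numpy as jnp

@jax.jit
def mask_and_repeat(x, pos, num_repeats=2):
    # x: (batch, seq, h, c); pos can be a traced scalar because the output shape stays fixed
    mask = (jnp.arange(x.shape[1]) < pos)[None, :, None, None]
    x_masked = jnp.where(mask, x, 0.0)
    return jnp.repeat(x_masked, num_repeats, axis=2)

out = mask_and_repeat(jnp.ones((1, 5, 4, 3)), jnp.array(2))  # shape (1, 5, 8, 3)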

AakashKumarNain commented 2 weeks ago

@hertschuh I think either you misunderstood or maybe I failed to clarify. Let me elaborate a bit more.

The pos argument comes from a loop around the model.predict(...) call, something like this:

for pos in range(output_length):
    out = model.predict([..., pos])   

This is what I meant when I said it's per sample. Also, this is a full-fledged example for your reference to understand why it is this way.

Coming back to the arguments:

the reason why the jax.jit works with get_array_slice_and_repeat the way you have it is that you marked the pos argument as static with static_argnums.

Yes, that means every time pos changes we will get a recompilation, but it's a one-time cost. Once compiled for each length, we won't have to recompile again.

I looked into jax.lax.dynamic_slice, but it won't work in this case because the size of the slices is dynamic

Basically, the method I came up with is simpler in this case than using dynamic_slice(...).

And you may need to use masking instead of slicing

The problem with masking is that you can get a masked array, but if you apply repeat on top of it, you won't get the expected values.

hertschuh commented 2 weeks ago

@AakashKumarNain here is one way to do it if I understand your use case correctly. The idea is that pos is a model property that you can change instead of an input that you pass to predict (or call).

It is important to recompile the model every time pos is changed to force the model to retrace call before doing predict. (There are other ways to do that, but compile is the only public API for this).

def get_array_slice_and_repeat(x, pos, num_repeats):
    x_slice = x[:, :pos, :, :]
    x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
    return x_slice

class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__()
        self.name = name if name is not None else "custom_model"
        self.pos = None

    def call(self, inputs, training=True):
        x = get_array_slice_and_repeat(inputs, self.pos, 2)
        return x

x = np.random.rand(1, 5, 4, 3)
model = CustomModel()

model.pos = 3
model.compile()
out = model.predict(x)

model.pos = 4
model.compile()
out = model.predict(x)
AakashKumarNain commented 2 weeks ago

Thanks @hertschuh for the detailed explanation and the workaround.

PS: Do you think there can be a better way to do this? This is a common pattern, and I feel there can be a better solution to it in general

hertschuh commented 2 weeks ago

@AakashKumarNain this pattern works "naturally" with Tensorflow and Torch. Unfortunately, it doesn't work with JAX because of the fixed size requirement when jitting. The general solution for JAX is to use a fixed total size and use padding / masking.

hertschuh commented 2 weeks ago

Closing now, let me know if you have more questions.

AakashKumarNain commented 2 weeks ago

@AakashKumarNain this pattern works "naturally" with Tensorflow and Torch. Unfortunately, it doesn't work with JAX because of the fixed size requirement when jitting. The general solution for JAX is to use a fixed total size and use padding / masking.

I am okay with closing the issue for now, but the constraint is not on the JAX side (we can compile that). The problem is on our side: we expect all the inputs passed in a list to have the same dimensions, which doesn't make sense in many cases. For JAX backend, ideally there should be flexibility within the predict method for the end user to choose which parts to compile

hertschuh commented 2 weeks ago

@AakashKumarNain

For JAX backend, ideally there should be flexibility within the predict method for the end user to choose which parts to compile

Oh, that flexibility exists. All you have to do is turn jit_compile off in model.compile() and then jit the parts that you can. The jitted_call method below holds the bulk of the code, and call does only the strict minimum: separating the inputs and using .item() for the static parts.

class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__(name=name)

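    # self (0), pos (2) and num_repeats (3) are static; only x (1) is traced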
    @partial(jax.jit, static_argnums=(0, 2, 3))
    def jitted_call(self, x, pos, num_repeats):
        x_slice = x[:, :pos, :, :]
        x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
        return x_slice

    def call(self, inputs, training=True):
        x, pos = inputs
        return self.jitted_call(x, pos.item(), 2)

x = np.random.rand(1, 5, 4, 3)
pos = jnp.array([2])
model = CustomModel()
model.compile(jit_compile=False)
model.predict((x, pos))

And you also have the option of overriding model.predict_step() if that helps.

AakashKumarNain commented 2 weeks ago

This is nice! BTW, what about doing this at the layer level, within the call method of the model? Are the semantics the same for that as well?

hertschuh commented 1 week ago

@AakashKumarNain yes, the same semantics apply to layers.

AakashKumarNain commented 1 week ago

Perfect! Thanks @hertschuh for the explanation. One last thing before I wrap up this thread (and I think this thread will become a good reference for such workflows): let's say you have a model with n custom layers, and the above operation is done in those layers. Should we then jit the individual layers?


from keras import layers

class SimpleLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)

    @partial(jax.jit, static_argnums=(0, 2, 3))
    def jitted_call(self, x, pos, num_repeats):
        x_slice = x[:, :pos, :, :]
        x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
        return x_slice

    def call(self, inputs, training=True):
        x, pos = inputs
        return self.jitted_call(x, pos.item(), 2)

class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.layer1 = layers.Dense(32)
        self.simple_layers = [SimpleLayer() for _ in range(5)]

    def call(self, inputs, training=True):
        x, pos = inputs
        x = self.layer1(x)
        for layer in self.simple_layers:
            x = layer((x, pos))
        return x
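Presumably this would be driven the same way as the earlier example (an untested sketch):

x = np.random.rand(1, 5, 4, 3)
pos = jnp.array([2])
model = CustomModel()
model.compile(jit_compile=False)  # keep the outer predict un-jitted
model.predict((x, pos))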
hertschuh commented 1 week ago

@AakashKumarNain Yes, that is the unfortunate part. If you can't jit at the model level, you'll have to jit each individual layer that supports it. And chances are the result will be slower than a fully jitted model.

AakashKumarNain commented 1 week ago

Cool. Thanks @hertschuh for the explanation