Closed — AakashKumarNain closed this issue 2 weeks ago.
I know that the whole `call` is jitted when we call `predict`, but this seems to impose a restriction if we want some functionality inside `call` that requires explicit jitting with static arguments.
The documentation for `x.item()` in JAX says:

> Copy an element of an array to a standard Python scalar and return it.

It is not possible to access the eager value of an array inside a jitted function, which is why calling `item()` inside `call()` fails with `predict()` -- as you note, `predict()` is itself jitted.
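To illustrate the limitation with a minimal JAX-only sketch (independent of Keras): inside a jitted function, the argument is an abstract tracer with no concrete value, so `.item()` raises a concretization error.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(pos):
    # `pos` is an abstract tracer here; `.item()` needs a concrete value
    return pos.item()


def call_item_under_jit():
    try:
        f(jnp.array([2]))
        return "ok"
    except Exception as exc:  # in practice jax raises a ConcretizationTypeError
        return type(exc).__name__


print(call_item_under_jit())
```

Outside of `jit`, `jnp.array([2]).item()` works fine and returns the Python int `2`.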
You have two possible solutions:

1. Don't use `item()` inside `call()`, or any other eager-only method.
2. Run `predict()` eagerly. This can be done by setting `model.run_eagerly = True`. Note that performance may suffer.

@fchollet Thanks for the quick response.
Yes, I am aware of `item()`. Here is the problem though: if I use `predict`, then all inputs need to have the same dimensionality. If I don't use `item()` or `int(pos[0])` or `int(ops.get_item(pos, -1))`, there is no way to retrieve the position as an integer, and hence it is not possible to jit the `get_array_slice_and_repeat(...)` method, as the value of `pos` passed would be an `ArrayImpl` object, which isn't hashable.
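A minimal JAX-only sketch of the hashability issue being described: a Python `int` works as a static argument, while a JAX array does not.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def get_slice(x, pos):
    # `pos` is a compile-time constant here, so the output shape is fixed
    return x[:, :pos]


x = jnp.ones((2, 5))
out = get_slice(x, 3)  # works: a Python int is hashable
print(out.shape)       # (2, 3)


def static_with_array():
    try:
        get_slice(x, jnp.array(3))  # a jax array is not hashable
        return "ok"
    except (TypeError, ValueError):
        return "non-hashable static argument"


print(static_with_array())
```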
Unfortunately, running eagerly is very slow for my use case.
Note: I also don't see any way to use `ops.slice(...)` in this case, as it will complain about the Tracer in the shape values, i.e. `ops.slice(x, [0, 0, 0, 0], [x_shape[0], pos+1, x_shape[2], x_shape[3]])`.
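For the same reason, plain slicing with a traced bound fails under `jit` (a JAX-only sketch of what the slice operation runs into):

```python
import jax
import jax.numpy as jnp


@jax.jit
def bad_slice(x, pos):
    # the stop value is a tracer, so the output shape would be dynamic
    return x[:, :pos]


def slice_with_tracer():
    try:
        bad_slice(jnp.ones((2, 5)), jnp.array(3))
        return "ok"
    except Exception:  # jax rejects non-static slice bounds
        return "traced slice bound rejected"


print(slice_with_tracer())
```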
You can try asking the question to the JAX team to see what they recommend in this case (basically, how to use your `get_array_slice_and_repeat` inside a jit function). There might be a factoring that removes the need for the eager `item()` call. If not, then it's just a hard limitation of XLA.
@AakashKumarNain,

Is the `pos` argument static? Or is it dynamic per sample? Could it be a configuration of the model instead of an input?

Thanks, Fabien
> You can try asking the question to the JAX team to see what they recommend in this case (basically, how to use your get_array_slice_and_repeat inside a jit function)

There may be a way to convert the tracer to a concrete value using `pure_callback(...)`, but I am not sure. My worry is on the other side though, which is: does this mean we will never be able to pass a singleton array to `predict` where it is required to do some operation within the model, e.g. retrieving the KV cache values for each attention layer?
> @AakashKumarNain, is the `pos` argument static? Or is it dynamic per sample? Could it be a configuration of the model instead of an input?
@hertschuh this can't be a configuration on the model side. It's per sample
@AakashKumarNain the reason why `jax.jit` works with `get_array_slice_and_repeat` the way you have it is that you marked the `pos` argument as static with `static_argnums`. What this means is that it's not dynamic, it's a constant. The function is retraced and recompiled for each new value of `pos`. This won't work with dynamic values per sample.
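The retracing behavior can be observed directly (a small sketch using a Python-side counter, which only executes while the function is being traced, not on cached calls):

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnums=(1,))
def get_slice(x, pos):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return x[:, :pos]


x = jnp.ones((2, 5))
get_slice(x, 2)  # first call for pos=2: trace + compile
get_slice(x, 2)  # cache hit: no retrace
get_slice(x, 3)  # new static value: trace + compile again
print(trace_count)  # 2
```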
I looked into `jax.lax.dynamic_slice`, but it won't work in this case because the size of the slices is dynamic (one of the sizes is `pos`). Basically, there is no way around this: JAX doesn't support dynamic output shapes. The output of the model has to have a fixed shape for a fixed shape of the input. This is not the case in your test model.

The only way to make it work is to have some operation afterwards that brings it back to a fixed size (like a reduction or padding). And you may need to use masking instead of slicing. The discussion on this issue seems relevant: https://github.com/google/jax/issues/1007
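As an illustration of the masking approach (a sketch, not the poster's actual model): `pos` stays a regular traced input, and the output keeps a fixed shape; positions at or beyond `pos` are zeroed out instead of dropped.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_keep(x, pos):
    # Build a boolean mask over axis 1 instead of slicing it.
    idx = jnp.arange(x.shape[1])
    mask = (idx < pos)[None, :, None, None]
    return jnp.where(mask, x, 0.0)


x = jnp.ones((1, 5, 4, 3))
out = masked_keep(x, jnp.array(3))
print(out.shape)        # (1, 5, 4, 3): fixed, regardless of pos
print(out[0, :, 0, 0])  # [1. 1. 1. 0. 0.]
```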
@hertschuh I think either you misunderstood or maybe I failed to clarify. Let me elaborate a bit more.

The `pos` argument comes from a loop where the `model.predict(...)` method is called, something like this:

```python
for pos in range(output_length):
    out = model.predict([..., pos])
```
This is what I meant when I said it's per sample. Also, this is a full-fledged example for your reference to understand why it is this way.

Coming back to the arguments:

> the reason why the jax.jit works with get_array_slice_and_repeat the way you have it is that you marked the pos argument as static with static_argnums.

Yes, that means every time `pos` changes we will have a recompilation, but it's a one-time cost. Once compiled for each length, we won't have to recompile again.
> I looked into jax.lax.dynamic_slice, but it won't work in this case because the size of the slices is dynamic

Basically, the method I came up with is simpler in this case than using `dynamic_slice(...)`.

> And you may need to use masking instead of slicing

The problem with masking is that you can get a masked array, but if you apply `repeat` on top of it, you won't get the expected values.
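To make the difference concrete (a small sketch, not the original code): slice-then-repeat produces a `pos`-dependent shape, while mask-then-repeat keeps the full axis and repeats the zeroed rows along with the rest, so downstream code has to keep masking them out.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(1, 3, 2, 1)
pos = 2

# Eager slice-then-repeat: the output shape depends on pos.
sliced = jnp.repeat(x[:, :pos], 2, axis=2)
print(sliced.shape)  # (1, 2, 4, 1)

# Jit-friendly mask-then-repeat: the shape stays fixed at the full size,
# and the masked-out rows (index >= pos) come through as zeros.
mask = (jnp.arange(x.shape[1]) < pos)[None, :, None, None]
masked = jnp.repeat(jnp.where(mask, x, 0.0), 2, axis=2)
print(masked.shape)  # (1, 3, 4, 1)
```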
@AakashKumarNain here is one way to do it if I understand your use case correctly.

The idea is that `pos` is a model property that you can change instead of an input that you pass to `predict` (or `call`).

It is important to recompile the model every time `pos` is changed to force the model to retrace `call` before doing `predict`. (There are other ways to do that, but `compile` is the only public API for this.)
```python
import numpy as np
import jax.numpy as jnp
import keras


def get_array_slice_and_repeat(x, pos, num_repeats):
    x_slice = x[:, :pos, :, :]
    x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
    return x_slice


class CustomModel(keras.Model):
    def __init__(self, name=None):
        # `name` must go through the base class; it cannot be assigned directly.
        super().__init__(name=name)
        self.pos = None

    def call(self, inputs, training=True):
        return get_array_slice_and_repeat(inputs, self.pos, 2)


x = np.random.rand(1, 5, 4, 3)
model = CustomModel()

model.pos = 3
model.compile()
out = model.predict(x)

model.pos = 4
model.compile()  # recompile to retrace `call` with the new pos
out = model.predict(x)
```
Thanks @hertschuh for the detailed explanation and the workaround.
PS: Do you think there can be a better way to do this? This is a common pattern, and I feel there can be a better solution to it in general
@AakashKumarNain this pattern works "naturally" with Tensorflow and Torch. Unfortunately, it doesn't work with JAX because of the fixed size requirement when jitting. The general solution for JAX is to use a fixed total size and use padding / masking.
Closing now, let me know if you have more questions.
> @AakashKumarNain this pattern works "naturally" with Tensorflow and Torch. Unfortunately, it doesn't work with JAX because of the fixed size requirement when jitting. The general solution for JAX is to use a fixed total size and use padding / masking.

I am okay closing the issue for now, but the constraint is not on the JAX side (we can compile that). The problem is on our side, as we expect all the inputs passed in a list to have the same dimensions, which doesn't make sense for many cases. For the JAX backend, ideally there should be flexibility within the `predict` method for the end user to choose which parts to compile.
@AakashKumarNain

> For JAX backend, ideally there should be flexibility within the predict method for the end user to choose which parts to compile

Oh, that flexibility exists. All you have to do is turn `jit_compile` off in `model.compile()` and then jit the parts that you can. The `jitted_call` method below would have the bulk of the code. And `call` would only do the strict minimum, which is separating inputs and using `.item()` for the static parts.
```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import keras


class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__(name=name)

    # argnum 0 is `self`; `pos` (2) and `num_repeats` (3) are static,
    # so only `x` is traced.
    @partial(jax.jit, static_argnums=(0, 2, 3))
    def jitted_call(self, x, pos, num_repeats):
        x_slice = x[:, :pos, :, :]
        x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
        return x_slice

    def call(self, inputs, training=True):
        x, pos = inputs
        # `call` runs eagerly (jit_compile=False), so `.item()` is allowed.
        return self.jitted_call(x, pos.item(), 2)


x = np.random.rand(1, 5, 4, 3)
pos = jnp.array([2])
model = CustomModel()
model.compile(jit_compile=False)
model.predict((x, pos))
```
And you also have the option of overriding `model.predict_step()` if that helps.

This is nice! BTW, what about the layer level, within the `call` method of a model? Are the semantics the same for that as well?
@AakashKumarNain yes, the same semantics apply to layers.
Perfect! Thanks @hertschuh for the explanation. One last thing before I wrap up this thread (and I think this thread will become a good reference for such workflows): let's say you have a model that has some `n` custom layers, and the above operation is done in those layers. Should we `jit` the individual layers?
```python
from functools import partial

import jax
import jax.numpy as jnp
import keras
from keras import layers


class SimpleLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)

    @partial(jax.jit, static_argnums=(0, 2, 3))
    def jitted_call(self, x, pos, num_repeats):
        x_slice = x[:, :pos, :, :]
        x_slice = jnp.repeat(x_slice, num_repeats, axis=2)
        return x_slice

    def call(self, inputs, training=True):
        x, pos = inputs
        return self.jitted_call(x, pos.item(), 2)


class CustomModel(keras.Model):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.layer1 = layers.Dense(32)
        self.simple_layers = [SimpleLayer() for _ in range(5)]

    def call(self, inputs, training=True):
        x = self.layer1(inputs)
        for i, layer in enumerate(self.simple_layers):
            # wrap `i` in an array so the layer can call `.item()` on it
            x = layer((x, jnp.array([i])))
        return x
```
@AakashKumarNain Yes, that is the unfortunate part. If you can't jit at the model level, you'll have to jit each individual layer that supports it. And chances are the result will be slower than a fully jitted model.
Cool. Thanks @hertschuh for the explanation