Open DLumi opened 2 weeks ago
@DLumi I suspect this is not an issue on Keras 3 actually. Keras 2 caches an attribute on self, which could plausibly get messed up if a fit() call fails in the middle. But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.
If you think we could apply the same approach to tf-keras, you are welcome to open up a PR there. Otherwise we will probably stick to this being fixed on Keras 3. I will close this for now, but if you can recreate a bug on Keras 3 please re-open!
(Also, as to why this exists: yes, it's to avoid some overhead from creating the dataset iterator. But Keras 3 handles this much more elegantly than Keras 2.)
> But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.
Uh, I'm pretty sure it's functionally exactly the same as in Keras 2, as I see little to no change in the actual code. Here's the Keras 2 code for comparison: https://github.com/keras-team/tf-keras/blob/c5f97730b2e495f5f56fc2267d22504075e46337/tf_keras/engine/training.py#L2236C1-L2241C55
Maybe I am missing something here, though?
Anyway, it would greatly help if I knew how to create a first-working-then-failing tf.Dataset at toy scale. That way I could recreate the setup and potentially give you a definite way to reproduce this with Keras 3.
Ah, my bad, I misread the code. Still, there is a key difference between Keras 2 and Keras 3 here. This line
In Keras 3, we always clear the cached dataset at the beginning of fit. Which is not true in Keras 2. So I see how a crashing fit could cause an issue in Keras 2, but not in Keras 3.
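To make the difference concrete, here is a minimal sketch (not the actual Keras source; the model classes are hypothetical, only the `_eval_data_handler` attribute name comes from the tf-keras code discussed in this thread) of why a stale cache can survive a crashed fit() in Keras 2 but not in Keras 3:

```python
# Sketch of the caching difference discussed above. Hypothetical classes;
# only `_eval_data_handler` mirrors the real tf-keras attribute name.

class Keras2LikeModel:
    def __init__(self):
        self._eval_data_handler = None

    def fit(self, val_ds):
        # The cache is only filled once; later fit() calls reuse the old
        # handler and ignore the dataset you actually passed in.
        if self._eval_data_handler is None:
            self._eval_data_handler = val_ds
        return self._eval_data_handler


class Keras3LikeModel(Keras2LikeModel):
    def fit(self, val_ds):
        # Cleared unconditionally at the start of fit(), so nothing stale
        # can leak across calls, even after a crash.
        self._eval_data_handler = None
        return super().fit(val_ds)


old_ds, new_ds = object(), object()
m2, m3 = Keras2LikeModel(), Keras3LikeModel()
m2.fit(old_ds)
m3.fit(old_ds)
assert m2.fit(new_ds) is old_ds  # Keras-2-like: stale cache wins
assert m3.fit(new_ds) is new_ds  # Keras-3-like: fresh dataset is used
```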
As for a crashing dataset, maybe something like this:
```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.range(100))

# Note: using tf.py_function as a decorator requires a recent TF release.
@tf.py_function(Tout=tf.int32)
def crasher(x):
    if x > 50:
        raise ValueError
    return x

ds = ds.map(crasher)
for x in ds:
    print(x)
```
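For completeness, a hedged sketch of how the reporter's `.ignore_errors()` scenario plays out with such a dataset: the same crashing map, but with failing elements dropped. This uses the plain `tf.py_function(func, inp, Tout)` call form (no decorator, for older TF) and the long-standing `tf.data.experimental.ignore_errors()` transform; on recent TF there is also a `ds.ignore_errors()` method.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)

def crasher(x):
    # Fails for elements past 50, simulating a broken preprocessing step.
    if x > 50:
        raise ValueError("boom")
    return x

ds = ds.map(lambda x: tf.py_function(crasher, [x], Tout=tf.int64))

# Drop the failing elements instead of crashing the iteration.
ok = ds.apply(tf.data.experimental.ignore_errors())
values = [int(v) for v in ok.as_numpy_iterator()]
```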
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
I'll preface by saying that I encountered this issue with tf_keras == 2.15, but since the source code around evaluation has hardly changed since v2.15, I feel it's still applicable here.

The issue is that, no matter what, `fit` forces `evaluate` to use the stored dataset object for the validation step instead of whatever object you supply to `fit`. This is super obscure, but it's probably done for performance reasons, so whatever. Why is this an issue? If you change something about your dataset mid-training (say, you initially forgot to turn on `.ignore_errors()`) and then pass the new DS instance to `fit`, it completely ignores this fact. And in this particular case, it would fail if any errors arise in the DS preprocessing steps.

Yes, you can cure it with `model._eval_data_handler = None`, which in turn forces `evaluate` to cache the new object, but to figure this out you have to spend some time diving into the source code.

So what I propose is: 1) a mention of said caching behavior in `fit`'s documentation, and 2) some actual public API for either clearing cached validation objects or disabling the caching behavior entirely.

P.S. I'd provide a Colab link, but it turns out that making a tf.Dataset that randomly fails when I want it to is actually way harder than it seems.
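For the P.S., here is one hedged sketch of a dataset that fails on demand: a `tf.py_function` that closes over a mutable Python flag, so the same dataset object first iterates cleanly and then starts raising. The names (`state`, `maybe_crash`) are illustrative, not from tf-keras or the reported setup.

```python
import tensorflow as tf

# Mutable flag the preprocessing function closes over; flip it to make
# the same dataset object start failing mid-experiment.
state = {"fail": False}

def maybe_crash(x):
    if state["fail"]:
        raise ValueError("simulated preprocessing failure")
    return x

ds = tf.data.Dataset.range(5).map(
    lambda x: tf.py_function(maybe_crash, [x], Tout=tf.int64)
)

print(list(ds.as_numpy_iterator()))  # iterates fine while the flag is off

state["fail"] = True
try:
    list(ds.as_numpy_iterator())     # now every element fails
except Exception:
    print("dataset crashed on demand")
```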