keras-team / keras

Deep Learning for humans
http://keras.io/

Obscure validation failure due to `_use_cached_eval_dataset` #20177

Open DLumi opened 2 weeks ago

DLumi commented 2 weeks ago

I'll preface by saying that I encountered this issue with tf_keras == 2.15, but since the evaluation-related source code has hardly changed since v2.15, I feel it's still applicable here.

The issue is that, no matter what, fit() forces evaluate() to use the stored dataset object for the validation step instead of whatever object you supply to fit(). This is super obscure, but it's probably done for performance reasons, so fine. Why is this a problem? If you change something about your dataset mid-training (say you initially forgot to turn on .ignore_errors()) and then pass the new dataset instance to fit(), it completely ignores that fact. And in this particular case, it will fail whenever errors arise in the dataset's preprocessing steps.

Yes, you can work around it with model._eval_data_handler = None, which in turn forces evaluate() to cache the new object, but to figure this out you have to spend some time digging into the source code.
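For illustration, here's a minimal sketch of where that workaround goes (the model and datasets are toy placeholders, not from my actual setup):

import tensorflow as tf

# Toy model and data, just to show where the workaround fits.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

x = tf.random.uniform((64, 4))
y = tf.random.uniform((64, 1))
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
val_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

model.fit(train_ds, validation_data=val_ds, epochs=1)

# Rebuild the validation pipeline, e.g. with error handling added.
new_val_ds = val_ds.ignore_errors()

# If a previous fit() left a cached evaluation handler behind (e.g. it
# crashed during validation), clearing it forces the next fit() to build
# a handler from new_val_ds instead of reusing the stale one.
model._eval_data_handler = None
model.fit(train_ds, validation_data=new_val_ds, epochs=1)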

So what I propose is:
1. a mention of this behavior in fit()'s documentation;
2. an actual public API for either clearing cached validation objects or disabling the caching behavior entirely.

P.S. I'd provide a Colab link, but it turns out that making a tf.data.Dataset that fails exactly when I want it to is actually way harder than it seems.

mattdangerw commented 2 weeks ago

@DLumi I suspect this is actually not an issue on Keras 3. Keras 2 caches an attribute on self, so it makes sense that things might get messed up if a fit() call fails in the middle. But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.

https://github.com/keras-team/keras/blob/d4a51168bfedf69a9aae7ddff289277972dfd85d/keras/src/backend/tensorflow/trainer.py#L391-L397

If you think we could apply the same approach to tf-keras, you are welcome to open a PR there. Otherwise we will probably stick to this being fixed in Keras 3. I will close this for now, but if you can reproduce the bug on Keras 3, please re-open!

(Also, as to why this exists: yes, it's to avoid some overhead when creating the dataset iterator. But Keras 3 handles this much more elegantly than Keras 2.)

google-ml-butler[bot] commented 2 weeks ago

Are you satisfied with the resolution of your issue?

DLumi commented 2 weeks ago

But Keras 3 just passes an additional kwarg, so there shouldn't be anything stateful to mess up.

Uh, I'm pretty sure it's functionally exactly the same as in Keras 2, since I see little to no change in the actual code. Here's the Keras 2 code for comparison: https://github.com/keras-team/tf-keras/blob/c5f97730b2e495f5f56fc2267d22504075e46337/tf_keras/engine/training.py#L2236C1-L2241C55

Maybe I am missing something here, though?

Anyway, it would greatly help if I knew how to create a first-working-then-failing tf.data.Dataset at toy scale. That way I could recreate the setup and potentially give you a definite way to reproduce this with Keras 3.

mattdangerw commented 2 weeks ago

Ah, my bad, I misread the code. Still, there is a key difference between Keras 2 and Keras 3 here. This line:

https://github.com/keras-team/keras/blob/4c71314cfa51e462a3a7ebcbd27dc52d8b788bc2/keras/src/backend/tensorflow/trainer.py#L262

In Keras 3, we always clear the cached dataset at the beginning of fit(), which is not true in Keras 2. So I can see how a crashing fit() could cause an issue in Keras 2, but not in Keras 3.
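To illustrate the difference in pseudocode (class and attribute names here are made up, not the actual Keras internals):

class CachingTrainer:
    def __init__(self):
        self._cached_val_iter = None

    def fit_keras2_style(self, val_data):
        # Rebuild the cache only if it is missing, so a stale iterator left
        # behind by a crashed fit() gets silently reused.
        if self._cached_val_iter is None:
            self._cached_val_iter = iter(val_data)
        # ... run training and validation ...

    def fit_keras3_style(self, val_data):
        # Always drop the cached iterator first, so every fit() builds one
        # from the dataset it was actually given.
        self._cached_val_iter = None
        self._cached_val_iter = iter(val_data)
        # ... run training and validation ...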

As for a crashing dataset, maybe something like this:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.range(100))

# Run the check as eager Python via tf.py_function, so the raise actually
# fires when the element is produced.
@tf.py_function(Tout=tf.int32)
def crasher(x):
    if x > 50:
        raise ValueError
    return x

ds = ds.map(crasher)

# The first elements print fine; iteration errors out once x exceeds 50.
for x in ds:
    print(x)

github-actions[bot] commented 4 days ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.