keras-team / keras

Training hangs at the end of the first epoch when using a PyDataset and workers > 1. #20425

Open HGS-mbayer opened 6 days ago

HGS-mbayer commented 6 days ago

Training using a PyDataset and workers > 1 will hang at the end of the first epoch with Keras 3.6. This issue does not seem to occur with Keras 3.5.

Example Code

Here is a slightly modified version of https://keras.io/examples/vision/mnist_convnet/ to reproduce the issue.

import math

import keras
import numpy as np
from keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 512
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

class Data(keras.utils.PyDataset):
    def __init__(self, x, y, batch_size: int = 2, **kwargs):
        super().__init__(**kwargs)
        self._x = x
        self._y = y
        self._batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self._x) / self._batch_size)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError

        indices = range(len(self._x))[
            index * self._batch_size : (index + 1) * self._batch_size
        ]

        return self._x[indices, ...], self._y[indices, ...]

training_data = Data(x_train, y_train, batch_size=batch_size, workers=8)
validation_data = Data(x_test, y_test, batch_size=batch_size, workers=8)

# This will hang at the end of the first epoch with Keras 3.6.
model.fit(training_data, epochs=epochs, validation_data=validation_data)

Traceback

Here is the traceback I receive when interrupting the process.

Epoch 1/15
117/118 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - accuracy: 0.6114 - loss: 1.2523
Traceback (most recent call last):
  File "example.py", line 75, in <module>
    model.fit(training_data, epochs=epochs, validation_data=validation_data)
  File "...\env\lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
    return fn(*args, **kwargs)
  File "...\env\lib\site-packages\keras\src\backend\torch\trainer.py", line 252, in fit
    for step, data in epoch_iterator.enumerate_epoch():
  File "...\env\lib\site-packages\keras\src\trainers\epoch_iterator.py", line 110, in enumerate_epoch
    for step, data in enumerate(self._get_iterator()):
  File "...\env\lib\site-packages\torch\utils\data\dataloader.py", line 701, in __next__
    data = self._next_data()
  File "...\env\lib\site-packages\torch\utils\data\dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "...\env\lib\site-packages\torch\utils\data\_utils\fetch.py", line 42, in fetch
    data = next(self.dataset_iter)
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\data_adapter_utils.py", line 222, in __iter__
    for batch in self.iterable:
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py", line 257, in _finite_enqueuer_generator
    for i, batch in enumerate(self.enqueuer.get()):
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py", line 637, in get
    value = self.future_queue.get(block=True, timeout=5)
  File "...\AppData\Local\Programs\Python\Python310\lib\queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "...\AppData\Local\Programs\Python\Python310\lib\threading.py", line 324, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt
fchollet commented 1 day ago

Thanks for the report.

This issue appears to have been introduced in https://github.com/keras-team/keras/commit/fd8bbe2284f1ddfbc2578fce9cc5b2af35b7c927

@hertschuh can you take a look? I started debugging it, and here's my reading: the following code

except queue.Empty:
    pass

is reached and leads to an infinite loop. That's because we never get to the exit condition:

if i >= num_batches - 1:
    self.enqueuer.stop()
    return

which happens because num_batches returns a (correct) number that is larger than the actual number of batches that can be drawn during the first epoch.
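
For illustration, here is a minimal, non-Keras sketch of that failure mode, assuming a producer that enqueues one batch fewer than the consumer's num_batches expects. All names and counts are hypothetical, and a poll cap is added so the sketch terminates instead of actually hanging.

import queue
import threading

# Hypothetical numbers: the consumer expects 3 batches, but the producer only
# ever enqueues 2, so the exit condition `i >= num_batches - 1` is never met.
num_batches = 3
batches_produced = 2

future_queue = queue.Queue()

def producer():
    for i in range(batches_produced):
        future_queue.put(f"batch-{i}")

threading.Thread(target=producer).start()

i = 0
empty_polls = 0
while True:
    try:
        value = future_queue.get(block=True, timeout=0.1)
    except queue.Empty:
        # In the real adapter this branch is a bare `pass`, so the loop keeps
        # polling forever; the cap below exists only so this sketch stops.
        empty_polls += 1
        if empty_polls > 5:
            print(f"stalled after consuming {i} of {num_batches} expected batches")
            break
        continue
    print("got", value)
    if i >= num_batches - 1:  # never reached: only 2 batches ever arrive
        break
    i += 1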

fchollet commented 1 day ago

I added a workaround at HEAD to continue training when the issue occurs. It's not a definitive solution, but it should help.
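
(A purely hypothetical sketch of the kind of safeguard such a workaround could take; this is not the actual commit, and enqueuer_running is a placeholder rather than a real Keras API.)

import queue

# Hypothetical sketch only, not the change at HEAD. `enqueuer_running` stands
# in for whatever liveness check the real enqueuer exposes.
def drain(future_queue, num_batches, enqueuer_running, timeout=5):
    """Yield up to num_batches items, ending early if the producer is done."""
    consumed = 0
    while consumed < num_batches:
        try:
            yield future_queue.get(block=True, timeout=timeout)
            consumed += 1
        except queue.Empty:
            if not enqueuer_running() and future_queue.empty():
                # Producer finished short of num_batches: end the epoch
                # instead of polling forever.
                return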