ludwig-ai / ludwig

Low-code framework for building custom LLMs, neural networks, and other AI models
http://ludwig.ai
Apache License 2.0

Ray retraining fails with StopIteration exception when retraining a model with small datasets #3991

Open vijayi1 opened 2 months ago

vijayi1 commented 2 months ago

Describe the bug

When resuming model training (retraining) with Ray on a small dataset, the following exception occurs:

    2024-04-08 13:13:36,849 WARNING worker.py:1866 -- Traceback (most recent call last):
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/data/dataset_pipeline.py", line 226, in iter_batches
        blocks_owned_by_consumer = self._peek()._plan.execute()._owned_by_consumer
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/data/dataset_pipeline.py", line 1319, in _peek
        first_dataset_gen = next(dataset_iter)
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/data/dataset_pipeline.py", line 732, in __next__
        raise StopIteration
    StopIteration

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "python/ray/_raylet.pyx", line 850, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 902, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 857, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 861, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 803, in ray._raylet.execute_task.function_executor
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/_private/function_manager.py", line 674, in actor_method_executor
        return method(__ray_actor, *args, **kwargs)
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 466, in _resume_span
        return method(self, *_args, **_kwargs)
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/train/_internal/worker_group.py", line 31, in __execute
        raise skipped from exception_cause(skipped)
      File "/data/vijayi/dl_venv/lib64/python3.8/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
        train_func(*args, **kwargs)
      File "/data/vijayi/ludwig/ludwig/backend/ray.py", line 501, in <lambda>
        lambda config: train_fn(**config),
      File "/data/vijayi/ludwig/ludwig/backend/ray.py", line 215, in train_fn
        results = trainer.train(train_shard, val_shard, test_shard, return_state_dict=True, **kwargs)
      File "/data/vijayi/ludwig/ludwig/distributed/base.py", line 157, in wrapped
        res = fn(*args, **kwargs)
      File "/data/vijayi/ludwig/ludwig/trainers/trainer.py", line 1038, in train
        batcher.set_epoch(progress_tracker.epoch, progress_tracker.batch_size)
      File "/data/vijayi/ludwig/ludwig/data/dataset/ray.py", line 355, in set_epoch
        self._fetch_next_epoch()
      File "/data/vijayi/ludwig/ludwig/data/dataset/ray.py", line 380, in _fetch_next_epoch
        self._fetch_next_batch()
      File "/data/vijayi/ludwig/ludwig/data/dataset/ray.py", line 389, in _fetch_next_batch
        self._next_batch = next(self.dataset_batch_iter)
      File "/data/vijayi/ludwig/ludwig/data/dataset/ray.py", line 469, in async_read
        raise batch
      File "/data/vijayi/ludwig/ludwig/data/dataset/ray.py", line 454, in producer
        for batch in pipeline.iter_batches(prefetch_blocks=0, batch_size=batch_size, batch_format="pandas"):
    RuntimeError: generator raised StopIteration

The full exception is attached: exception_stack_trace.txt
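The "RuntimeError: generator raised StopIteration" at the bottom of the trace is standard PEP 479 behavior: since Python 3.7, a bare StopIteration that escapes into a generator body is converted into a RuntimeError. A minimal sketch (not Ludwig or Ray code) of the mechanism:

```python
# Minimal sketch (not Ludwig/Ray code): since PEP 479 (Python 3.7+),
# a bare StopIteration escaping inside a generator body is converted
# into RuntimeError("generator raised StopIteration").

def iter_batches():
    # Stand-in for a pipeline generator that peeks an exhausted
    # iterator and lets its StopIteration escape into the generator body.
    next(iter(()))  # raises StopIteration
    yield "batch"

msg = None
try:
    for _ in iter_batches():
        pass
except RuntimeError as exc:
    msg = str(exc)

print(msg)  # → generator raised StopIteration
```

This is why the underlying StopIteration from `DatasetPipeline._peek` surfaces as a RuntimeError in the Ludwig producer loop.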

To Reproduce

Steps to reproduce the behavior:

  1. Clone the ludwig repo, then cd to the examples/mnist/ folder.
  2. Run the attached first_run.py in that folder (it uses the config.yaml file from examples/mnist): first_run.py.txt
  3. Retrain the model by running the attached second_run.py in the same folder: second_run.py.txt

You should see the error when running second_run.py.

Expected behavior

The second training run should complete successfully.

Environment (please complete the following information):

Additional context

vijayi1 commented 2 months ago

I dug into this further. When training is resumed from a non-zero epoch, RayDatasetBatcher (ludwig/data/dataset/ray.py) calls self._fetch_next_epoch() twice, once in the class's init method and again in set_epoch(), without consuming any batches in between. With a small dataset this exhausts the pipeline and raises StopIteration. The following patch fixes the problem, but I'm not sure whether it's the right fix:

    diff --git a/ludwig/data/dataset/ray.py b/ludwig/data/dataset/ray.py
    index 5ad083fa..ba53ad33 100644
    --- a/ludwig/data/dataset/ray.py
    +++ b/ludwig/data/dataset/ray.py
    @@ -352,7 +352,8 @@ class RayDatasetBatcher(Batcher):
         def set_epoch(self, epoch, batch_size):
             self.batch_size = batch_size
             if epoch != self._epoch:
    -            self._fetch_next_epoch()
    +            if self._step or self._last_batch:
    +                self._fetch_next_epoch()
                 self._epoch = epoch
    
         @property
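To illustrate what the guard prevents, here is a simplified sketch (names and structure reduced from RayDatasetBatcher; the `_step` check stands in for the patch's `self._step or self._last_batch` condition): a batcher that prefetches an epoch in `__init__` and again in `set_epoch()` would silently skip an epoch, and exhaust a small pipeline, when no batches were consumed in between.

```python
# Simplified sketch of the double-prefetch bug and the patched guard.
# Not the real RayDatasetBatcher; each list below stands in for one
# epoch's worth of batches from a (finite) dataset pipeline.

class Batcher:
    def __init__(self, epochs):
        self._epochs = iter(epochs)
        self._step = 0        # batches consumed in the current epoch
        self._epoch = 0
        self._fetch_next_epoch()  # first prefetch, done in __init__

    def _fetch_next_epoch(self):
        # Raises StopIteration once the pipeline is exhausted,
        # as seen in the stack trace above.
        self._current = next(self._epochs)

    def set_epoch(self, epoch):
        if epoch != self._epoch:
            # Patched guard: only advance the pipeline if at least
            # one batch of the prefetched epoch was actually consumed.
            if self._step:
                self._fetch_next_epoch()
            self._epoch = epoch

batcher = Batcher(epochs=[["b0"], ["b1"]])
# Resume from epoch 1 without consuming any batches: the guard skips
# the second prefetch, so the already-prefetched epoch is kept.
batcher.set_epoch(1)
print(batcher._current)  # → ['b0']
```

Without the guard, `set_epoch(1)` would advance the iterator a second time, and on a dataset pipeline with no epochs left it would raise StopIteration, matching the reported failure.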