vijayi1 opened this issue 2 months ago
I dug into this further. When training is resumed from a non-zero epoch, `RayDatasetBatcher` (ludwig/data/dataset/ray.py) calls `self._fetch_next_epoch()` twice: once in `__init__` and again in `set_epoch()`, without consuming any batches in between, so one epoch's worth of data is fetched and then discarded. The following patch fixes the problem, but I'm not sure whether it's the right fix:
```diff
diff --git a/ludwig/data/dataset/ray.py b/ludwig/data/dataset/ray.py
index 5ad083fa..ba53ad33 100644
--- a/ludwig/data/dataset/ray.py
+++ b/ludwig/data/dataset/ray.py
@@ -352,7 +352,8 @@ class RayDatasetBatcher(Batcher):
     def set_epoch(self, epoch, batch_size):
         self.batch_size = batch_size
         if epoch != self._epoch:
-            self._fetch_next_epoch()
+            if self._step or self._last_batch:
+                self._fetch_next_epoch()
             self._epoch = epoch

     @property
```
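To illustrate the mechanism, here is a minimal, self-contained sketch (a hypothetical simplification, not Ludwig's actual class) of a batcher that fetches the next epoch's data both in `__init__` and in `set_epoch()`. When training resumes at a non-zero epoch, `set_epoch()` is called before any batch has been read, so the data fetched in `__init__` is silently discarded:

```python
# Toy model of the double-fetch bug (names and structure are illustrative,
# not Ludwig's real implementation).
class ToyBatcher:
    def __init__(self, epochs):
        self._epoch_iter = iter(epochs)  # stand-in for the per-epoch data pipeline
        self._epoch = 0
        self._step = 0                   # batches consumed in the current epoch
        self._last_batch = False
        self._fetch_next_epoch()         # first fetch, as in __init__

    def _fetch_next_epoch(self):
        self.current = next(self._epoch_iter)
        self._step = 0
        self._last_batch = False

    def set_epoch(self, epoch, batch_size):
        self.batch_size = batch_size
        if epoch != self._epoch:
            # Unpatched behavior: fetch unconditionally, even though the
            # epoch fetched in __init__ has not been consumed yet.
            self._fetch_next_epoch()
            self._epoch = epoch


# Resuming at epoch 3: set_epoch() runs before any batch is consumed,
# so "epoch-A" (fetched in __init__) is thrown away.
b = ToyBatcher(epochs=["epoch-A", "epoch-B", "epoch-C"])
b.set_epoch(3, batch_size=32)
print(b.current)  # "epoch-B": epoch-A was skipped without being consumed
```

The patch above guards the second fetch with `if self._step or self._last_batch`, i.e. it only advances when at least one batch of the current epoch has actually been consumed, which avoids the skip on resume. On a small dataset the wasted fetch can exhaust the pipeline entirely, which would explain the exception below.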
**Describe the bug**
When resuming model training (retraining) with Ray on a small dataset, the following exception occurs. The full exception is attached: exception_stack_trace.txt
**To Reproduce**
Steps to reproduce the behavior:
2. Run the attached first_run.py in that folder (it uses the config.yaml file from examples/mnist): first_run.py.txt
3. Run second_run.py; you should see the error.
**Expected behavior**
The second run should complete training successfully.
**Environment (please complete the following information):**

**Additional context**