skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Enable using a generator as data loader #1011

Open BenjaminBossan opened 10 months ago

BenjaminBossan commented 10 months ago

As discussed with @ottonemo

In #835, we made a change so that the data loader is initialized only once per fit call, not once per iteration (i.e. epoch) as before. Although we considered backwards compatibility, we missed one use case: passing a generator (or, more precisely, a generator factory) to iterator_train and iterator_valid.

Now, if a user tries this, training runs fine for the first epoch, but afterwards the generator is exhausted. As a result, the fit loop runs 0 times in each subsequent epoch, which produces strange results.
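The failure mode can be reproduced without skorch or PyTorch. Below is a minimal sketch in plain Python; `batch_generator` and `run_epochs` are hypothetical stand-ins for the user's generator factory and for a fit loop that creates its data loader only once, not actual skorch code:

```python
def batch_generator(n_batches=3):
    """Hypothetical generator factory standing in for iterator_train."""
    for i in range(n_batches):
        yield i


def run_epochs(loader, n_epochs=3):
    """Hypothetical stand-in for a fit loop that reuses one loader object."""
    batches_per_epoch = []
    for _ in range(n_epochs):
        count = 0
        for _batch in loader:
            count += 1
        batches_per_epoch.append(count)
    return batches_per_epoch


# The loader is created once and then iterated over in every epoch.
loader = batch_generator()
counts = run_epochs(loader)
print(counts)  # [3, 0, 0]: only the first epoch sees any batches
```

A generator object can only be consumed once, so every epoch after the first iterates over an already exhausted object.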

This issue should be addressed by two follow-up tasks:

  1. Re-enable passing a generator (factory) to iterator_train and iterator_valid. This can be achieved by checking the type and, if a types.GeneratorType is encountered, wrapping it in a class that makes it an iterator.
  2. Place a guard inside the fit loop that detects if the loop runs 0 times. I would imagine this is never desired and should raise an error, or at the very least give a warning. As is, there is no error and no warning, but the results in the print log are very strange.
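The two tasks above could look roughly like the following. This is only a sketch under assumptions, not the skorch implementation: `ReiterableGenerator`, `make_loader`, `run_epoch`, and `generate_batches` are hypothetical names, and the sketch assumes the loader is obtained by calling the user-supplied factory with the dataset and keyword arguments.

```python
import types


class ReiterableGenerator:
    """Hypothetical wrapper that calls the generator factory on each iter()."""

    def __init__(self, factory, *args, **kwargs):
        self.factory = factory
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        # A fresh generator is created every time iteration starts.
        return self.factory(*self.args, **self.kwargs)


def make_loader(factory, dataset, **kwargs):
    """Hypothetical helper: build the loader, wrapping generators (task 1)."""
    loader = factory(dataset, **kwargs)
    if isinstance(loader, types.GeneratorType):
        # The probe generator is discarded; its body has not started running.
        return ReiterableGenerator(factory, dataset, **kwargs)
    return loader


def run_epoch(loader):
    """Hypothetical epoch loop with a guard against 0 iterations (task 2)."""
    n_batches = 0
    for _batch in loader:
        n_batches += 1  # a real loop would run the training step here
    if n_batches == 0:
        raise RuntimeError("The data loader did not yield any batches.")
    return n_batches


def generate_batches(dataset, batch_size=2):
    """Hypothetical user-supplied generator factory."""
    for i in range(0, len(dataset), batch_size):
        yield dataset[i:i + batch_size]


loader = make_loader(generate_batches, list(range(6)), batch_size=2)
counts = [run_epoch(loader) for _ in range(3)]
print(counts)  # [3, 3, 3]
```

Because the wrapper stores the factory and its arguments instead of a single generator object, each epoch starts from a fresh generator. Whether the guard should raise or only warn is left open in the discussion above; the sketch raises.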