Open mmdanziger opened 7 months ago

Great work! Is there a way to save and restore the state of the ExperimentDataPipe? The use case is to train for N hours, stop, and then resume training at a given point in the iterable. Poking around, I could not see any obvious way to do this in the code. https://github.com/chanzuckerberg/cellxgene-census/blob/3b935f1199c8ac71245884c0cf6e740fbf8545d8/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py#L329
Thanks for the question! There is not currently a way to save and restore the state of the ExperimentDataPipe, but clearly that would be a useful enhancement for checkpointing support. Since you want to checkpoint at hourly intervals, it sounds like you are interested in being able to checkpoint in the middle of an epoch, rather than at the end of an epoch, correct?
I believe this could be accomplished by adding support to save and restore the current epoch, the current batch, and the state of the random number generator (used to perform shuffling). I'll raise this with the team to see if it's something we can consider adding.
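Concretely, the saved state might look something like the following sketch. The class and attribute names here are purely illustrative, not part of the current cellxgene-census API; it just shows what would need to be captured:

```python
import torch

# Illustrative sketch only: this class and its attributes are NOT part of
# the cellxgene-census API; they show what state would need to be saved.
class CheckpointableDataPipe:
    def __init__(self, seed: int = 0):
        self.seed = seed
        self.epoch = 0    # epoch currently being iterated
        self.batch = 0    # batches already yielded within this epoch
        self.generator = torch.Generator().manual_seed(seed)  # shuffle RNG

    def state_dict(self) -> dict:
        # Everything needed to resume mid-epoch.
        return {
            "epoch": self.epoch,
            "batch": self.batch,
            "rng_state": self.generator.get_state(),
        }

    def load_state_dict(self, state: dict) -> None:
        # Restore the epoch/batch position and the shuffle RNG state.
        self.epoch = state["epoch"]
        self.batch = state["batch"]
        self.generator.set_state(state["rng_state"])
```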
Thanks for your comment. Mid-epoch restoration is what we are interested in, including the RNG state, so that you could (at least approximately) stop and resume a training job even when it takes many days and you are bound to a queue with 24-hour max-length jobs. The idea would be to integrate this with the PyTorch Lightning DataModule state_dict support described at https://lightning.ai/docs/pytorch/stable/extensions/datamodules_state.html
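Roughly along these lines, assuming the datapipe grew the (currently hypothetical) `state_dict()`/`load_state_dict()` methods sketched above; `CensusDataModule` is just an illustrative name:

```python
import pytorch_lightning as pl

# Sketch of the Lightning hook-up; assumes the datapipe exposes the
# hypothetical state_dict()/load_state_dict() methods from above.
class CensusDataModule(pl.LightningDataModule):
    def __init__(self, datapipe):
        super().__init__()
        self.datapipe = datapipe

    def state_dict(self) -> dict:
        # Lightning calls this hook when writing a checkpoint.
        return {"datapipe": self.datapipe.state_dict()}

    def load_state_dict(self, state_dict: dict) -> None:
        # Lightning calls this hook when resuming from a checkpoint.
        self.datapipe.load_state_dict(state_dict["datapipe"])
```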
One could implement a simple batch counter and then re-stream that number of batches with the same seed on restore, like the old '90s VHS rewinding machines 😆. Surely there's a better way to save and restore iterator state, but from a quick look at the code, I wasn't able to find one.
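Something like this sketch, where `make_pipe` is a hypothetical factory for a freshly seeded batch iterator; it only works if shuffling is fully deterministic for a given seed, and the rewind cost grows with the number of batches to replay:

```python
import itertools
from typing import Callable, Iterator

def fast_forward(make_pipe: Callable[[int], Iterator], seed: int,
                 batches_seen: int) -> Iterator:
    # "VHS rewind" workaround: rebuild the pipe with the same seed,
    # then skip the batches that were already consumed before the
    # checkpoint. `make_pipe` is a hypothetical factory.
    it = iter(make_pipe(seed))
    # Replay and discard the already-seen batches; the cost is linear
    # in how far into the epoch training had progressed.
    for _ in itertools.islice(it, batches_seen):
        pass
    return it
```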