chanzuckerberg / cellxgene-census

CZ CELLxGENE Discover Census
https://chanzuckerberg.github.io/cellxgene-census/
MIT License

Save and restore state of ExperimentDataPipe to resume iteration #1041

Open: mmdanziger opened this issue 7 months ago

mmdanziger commented 7 months ago

Great work! Is there a way to save and restore the state of the IterableDataPipe? The use case is to train for N hours, stop, and then resume training from the same point in the iterable. Poking around, I couldn't see an obvious way to do this in the code: https://github.com/chanzuckerberg/cellxgene-census/blob/3b935f1199c8ac71245884c0cf6e740fbf8545d8/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py#L329

atolopko-czi commented 7 months ago

Thanks for the question! There is not currently a way to save and restore the state of the ExperimentDataPipe, but clearly that would be a useful enhancement for checkpointing support. Since you want to checkpoint at hourly intervals, it sounds like you are interested in being able to checkpoint in the middle of an epoch, rather than at the end of an epoch, correct?

I believe this could be accomplished by adding support to save and restore the current epoch, the current batch, and the state of the random number generator (used to perform shuffling). I'll raise this with the team to see if it's something we can consider adding.
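A minimal sketch of that approach, assuming the shuffle is a single NumPy permutation of observation indices per epoch and that a batch is a contiguous slice of that permutation. `ResumableShuffledBatches`, `state_dict`, and `load_state_dict` are hypothetical names for illustration, not part of the `cellxgene_census` API, and the real pipe's shuffling is not modeled. The RNG state is captured at the start of each epoch, so a restored iterator can regenerate the same permutation and jump directly to the saved batch without reading the earlier ones.

```python
import copy
from typing import Any, Dict, Iterator

import numpy as np


class ResumableShuffledBatches:
    """Hypothetical shuffled batch iterator whose position can be checkpointed."""

    def __init__(self, n_obs: int, batch_size: int, seed: int = 0) -> None:
        self.n_obs = n_obs
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self.epoch = 0
        self.batch = 0  # index of the next batch to yield within the current epoch
        # RNG state captured at the start of the current epoch
        self._epoch_rng_state = copy.deepcopy(self.rng.bit_generator.state)

    def __iter__(self) -> Iterator[np.ndarray]:
        # Regenerate this epoch's permutation from the saved epoch-start RNG state
        self.rng.bit_generator.state = copy.deepcopy(self._epoch_rng_state)
        perm = self.rng.permutation(self.n_obs)
        n_batches = -(-self.n_obs // self.batch_size)  # ceiling division
        # Resume at the saved batch; earlier batches are skipped by index, not read
        while self.batch < n_batches:
            start = self.batch * self.batch_size
            self.batch += 1
            yield perm[start:start + self.batch_size]
        # Epoch finished: advance counters and record the RNG state for the next epoch
        self.epoch += 1
        self.batch = 0
        self._epoch_rng_state = copy.deepcopy(self.rng.bit_generator.state)

    def state_dict(self) -> Dict[str, Any]:
        # Everything needed to resume: epoch, position in it, and epoch-start RNG state
        return {
            "epoch": self.epoch,
            "batch": self.batch,
            "rng_state": copy.deepcopy(self._epoch_rng_state),
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        self.batch = state["batch"]
        self._epoch_rng_state = copy.deepcopy(state["rng_state"])
```

Saving the epoch-start RNG state, rather than the current one, is what makes the restore exact here, because the whole epoch's permutation is drawn in one call. A pipe that shuffles in chunks would have to checkpoint whatever state its own shuffle depends on.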

mmdanziger commented 6 months ago

Thanks for your comment. Mid-epoch restoring, including RNG state, is what we are interested in. That way you could (at least approximately) stop and resume a training job that takes many days, even when you are bound to a queue with a 24-hour maximum job length. The idea would be to integrate this with PyTorch Lightning DataModule state_dict support, as in https://lightning.ai/docs/pytorch/stable/extensions/datamodules_state.html
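The Lightning page linked above describes `state_dict()` and `load_state_dict()` hooks on a `LightningDataModule`: the dict returned by the first is stored in the checkpoint and handed back to the second on resume. A minimal sketch of that shape, assuming the seed plus the epoch and batch position is enough state. To keep it runnable without Lightning, `ResumableDataModule` is a hypothetical plain class (a real one would subclass `lightning.pytorch.LightningDataModule`), and `train_batches()` is a hypothetical stand-in for `train_dataloader()`. Instead of saving the RNG state, it derives each epoch's batch order from the seed and epoch, so the seed alone reproduces the order.

```python
import random
from typing import Any, Dict, Iterator, List


class ResumableDataModule:
    """Hypothetical plain-Python stand-in for a LightningDataModule with resumable state.

    Only the state_dict()/load_state_dict() hook names come from Lightning.
    """

    def __init__(self, n_batches: int, seed: int = 0) -> None:
        self.n_batches = n_batches
        self.seed = seed
        self.epoch = 0
        self.batch = 0  # next batch index within the current epoch

    def _order(self) -> List[int]:
        # Deterministic per-epoch order; combining seed and epoch into one
        # integer seed is an arbitrary, hypothetical scheme
        order = list(range(self.n_batches))
        random.Random(self.seed * 1_000_003 + self.epoch).shuffle(order)
        return order

    def train_batches(self) -> Iterator[int]:
        order = self._order()
        while self.batch < self.n_batches:
            batch_id = order[self.batch]
            self.batch += 1
            yield batch_id
        # Epoch finished: move to the next one
        self.epoch += 1
        self.batch = 0

    def state_dict(self) -> Dict[str, Any]:
        # In Lightning, this dict would be written into the checkpoint
        return {"seed": self.seed, "epoch": self.epoch, "batch": self.batch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # In Lightning, the stored dict would be passed back here on resume
        self.seed = state_dict["seed"]
        self.epoch = state_dict["epoch"]
        self.batch = state_dict["batch"]
```

Deriving the order from the seed keeps the checkpoint tiny, at the cost of tying reproducibility to that exact seeding scheme.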

One could implement a simple batch counter and then, on restore, re-stream that number of batches with the same seed. Like the old 90s VHS rewinding machines 😆. Surely there's a better way to save and restore iterator state, but from a quick look at the code, I wasn't able to see one.
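The batch-counter idea can be sketched generically: rebuild the iterable with the same seed and discard the first N batches. This assumes the pipe yields an identical sequence every time it is constructed with a given seed; `resume_by_fast_forward` and `make_batches` are hypothetical names, with `make_batches` standing in for constructing the real data pipe.

```python
import itertools
import random
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def resume_by_fast_forward(
    make_iterable: Callable[[int], Iterable[T]], seed: int, batches_done: int
) -> Iterator[T]:
    """Rebuild the iterable with the same seed and skip the batches already consumed.

    Only correct if make_iterable(seed) yields the same sequence on every call.
    The skipped batches are still generated and discarded, so the cost grows
    with batches_done.
    """
    return itertools.islice(make_iterable(seed), batches_done, None)


def make_batches(seed: int) -> Iterator[List[int]]:
    # Hypothetical stand-in for constructing the data pipe with a fixed shuffle seed
    order = list(range(12))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), 4):
        yield order[start:start + 4]
```

This is the VHS rewind in code: simple and independent of the pipe's internals, but each skipped batch is still produced, so restoring late in a long epoch costs roughly as much as re-reading everything before it. Saving the epoch, batch, and RNG state, as suggested earlier in the thread, would let a pipe jump to the batch without touching the skipped data.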