Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Continuing training resets logger step #6435

Closed: wittenator closed this issue 3 years ago

wittenator commented 3 years ago

🐛 Bug

I am running PyTorch Lightning in a federated learning setting. I therefore have several models, and I need to instantiate a Trainer object for the same model multiple times. Every time I do that, the associated logger resets the epoch and logs the metrics on top of each other in the plots. As far as I know, instantiating a new Trainer object to continue training a model is allowed. Is this the expected behaviour, and is there a workaround?

The attached picture shows an example of the logging output of three consecutively called trainers that share a common logger.

Please reproduce using the BoringModel

The Colab currently does not work with lightning and lightning_bolts.

To Reproduce

Create a TestTubeLogger, instantiate multiple Trainers for the same model that all share this logger, and fit the trainers one after another.
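The reported behaviour can be reproduced without Lightning in a small pure-Python simulation, assuming (as the report describes) that every new Trainer restarts its step counter at 0 while the shared logger keeps every point it receives. `fit_with_fresh_trainer` is a hypothetical stand-in for `Trainer(...).fit(model)`, not a Lightning API.

```python
def fit_with_fresh_trainer(log, run_id, num_steps):
    # Hypothetical stand-in for Trainer(...).fit(model): a freshly created
    # Trainer starts counting at step 0, so every run logs steps 0..num_steps-1.
    for step in range(num_steps):
        log.append((step, run_id))


shared_log = []  # plays the role of the logger shared by all trainers
for run_id in range(3):
    fit_with_fresh_trainer(shared_log, run_id, num_steps=4)

steps = [step for step, _ in shared_log]
print(steps)            # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(len(set(steps)))  # 4 distinct x-values for 12 logged points
```

Because all three runs reuse the same x-values, a plotting backend draws them on top of each other, which matches the picture described above.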

Expected behaviour

The x-axis of the logged metrics should not reset between runs; each run should continue from the step where the previous one ended.

Environment

tchaton commented 3 years ago

Dear @wittenator,

Regarding your issue, have you tried re-creating the Trainer with `resume_from_checkpoint`? Here is a code snippet:

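    import os
    import tempfile

    from pytorch_lightning import Trainer

    # BoringModel is the minimal LightningModule from Lightning's test helpers;
    # its import path depends on your Lightning version.

    tmpdir = tempfile.mkdtemp()  # a temporary directory stands in for `tmpdir`

    # First run: train, then save a checkpoint (it includes the progress counters).
    model = BoringModel()
    trainer = Trainer(
        accelerator='ddp_sharded_spawn',
        num_processes=2,
        fast_dev_run=True,
    )

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, 'model.pt')
    trainer.save_checkpoint(checkpoint_path)

    # Second run: a fresh Trainer resumes from the checkpoint instead of step 0.
    model = BoringModel()

    trainer = Trainer(
        accelerator='ddp_sharded_spawn',
        num_processes=2,
        fast_dev_run=True,
        resume_from_checkpoint=checkpoint_path,
    )

    trainer.fit(model)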
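Resuming helps because the checkpoint carries the training progress: Lightning checkpoints store counters such as `global_step` and `epoch`, and the resumed Trainer restores them before continuing. The same idea can be sketched in plain Python with a hypothetical JSON progress file; `save_progress`, `load_progress`, and `train` are illustrative names, not Lightning APIs.

```python
import json
import os
import tempfile


def save_progress(path, global_step, epoch):
    # Persist the progress counters (a real checkpoint would also hold weights).
    with open(path, "w") as f:
        json.dump({"global_step": global_step, "epoch": epoch}, f)


def load_progress(path):
    # Restore the counters; a missing file means training starts from scratch.
    if not os.path.exists(path):
        return 0, 0
    with open(path) as f:
        state = json.load(f)
    return state["global_step"], state["epoch"]


def train(path, num_epochs, steps_per_epoch, log):
    # Hypothetical training run that continues from the saved counters.
    global_step, epoch = load_progress(path)
    for _ in range(num_epochs):
        for _ in range(steps_per_epoch):
            log.append(global_step)  # x-axis value sent to the logger
            global_step += 1
        epoch += 1
    save_progress(path, global_step, epoch)
    return global_step, epoch


ckpt = os.path.join(tempfile.mkdtemp(), "progress.json")
log = []
train(ckpt, num_epochs=1, steps_per_epoch=3, log=log)
train(ckpt, num_epochs=1, steps_per_epoch=3, log=log)
print(log)  # [0, 1, 2, 3, 4, 5]
```

The cost of this approach is the write and read on every hand-over between runs, which is the overhead mentioned in the reply below.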

Lightning is currently in the process of becoming friendlier for federated learning. We would love for you to contribute if you feel like it! See this PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/6212

Best, T.C

wittenator commented 3 years ago

Hi @tchaton ,

Thanks for the tip about reloading! I haven't tried that yet because saving and loading checkpoints adds some overhead. I would like to look into adding persistent step/epoch counting instead. As far as I can see, that should be a fairly small addition: there must already be code that counts steps and epochs in order to report them to the logger for a single trainer, so I would only need to persist and accumulate those counters in the model. Could you point me to the class I should look at?
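A minimal sketch of the in-memory idea, assuming the counter lives on the model object so that it outlives each Trainer. `PersistentStepCounter` and its methods are hypothetical, not part of Lightning; in practice the returned value could, for instance, be passed as the `step` argument of the logger's `log_metrics` call.

```python
class PersistentStepCounter:
    """Hypothetical helper that accumulates steps across consecutive Trainer runs.

    Keep one instance on the model (e.g. as an attribute) so it survives
    when a new Trainer is created for the next round.
    """

    def __init__(self):
        self.offset = 0     # steps completed by all previous runs
        self.run_steps = 0  # steps completed in the current run

    def start_run(self):
        # Fold the finished run into the offset before a new Trainer starts.
        self.offset += self.run_steps
        self.run_steps = 0

    def step(self):
        # Called once per training step; returns the global x-axis value.
        global_step = self.offset + self.run_steps
        self.run_steps += 1
        return global_step


counter = PersistentStepCounter()
log = []
for run_id in range(3):  # three consecutive "trainers", as in the report
    counter.start_run()
    for _ in range(4):
        log.append((counter.step(), run_id))

steps = [s for s, _ in log]
print(steps)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

Unlike the checkpoint route, nothing is written to disk; the trade-off is that the counter is lost if the model object itself is recreated.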

ananthsub commented 3 years ago

@wittenator, #6429 tracks this.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!