Lightning-AI / pytorch-lightning


WandbLogger disables cloud checkpointing in Trainer default_root_dir #16195

Open turian opened 1 year ago

turian commented 1 year ago

### Bug description

Cloud checkpoints are cool! But once you use the WandbLogger, no cloud checkpoints (or anything else, really) are saved to trainer.default_root_dir. The model is checkpointed as a Wandb artifact, which is cool, but I also want it in trainer.default_root_dir's S3 bucket.
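
For contrast, here is a minimal sketch of the behaviour I expect *without* the WandbLogger (the bucket URI is a placeholder, `BoringModel`/`RandomDataset` come from the bug-report template, and s3fs is assumed to be installed for the remote path):

```python
# Minimal sketch without WandbLogger. "s3://my-bucket" is a placeholder;
# BoringModel / RandomDataset are from Lightning's bug-report template.
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset

trainer = Trainer(
    limit_train_batches=1,
    limit_val_batches=1,
    max_epochs=1,
    default_root_dir="s3://my-bucket/wandbtest/",  # placeholder remote path
)
# With the default logger, checkpoints should end up under
# s3://my-bucket/wandbtest/lightning_logs/version_*/checkpoints/
trainer.fit(
    BoringModel(),
    train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2),
    val_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2),
)
```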

The reason I want this:

Related bug: Lightning-AI/pytorch-lightning#16196. See 'More info' at the bottom of this issue.

There are some related issues, but none of them covers this specific problem:

- https://github.com/Lightning-AI/lightning/pull/14325
- https://github.com/Lightning-AI/lightning/issues/5935
- https://github.com/Lightning-AI/lightning/issues/11769
- https://github.com/Lightning-AI/lightning/issues/15539
- https://github.com/Lightning-AI/lightning/issues/2318
- https://github.com/Lightning-AI/lightning/issues/2161

### How to reproduce the bug

Here is a Google Colab that reproduces this and a related bug. I share the code for both because it's easier to configure the AWS credentials once and see both bugs simultaneously.

Copying and pasting the most important bit (see the Colab for the full minimal reproduction):

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loggers import WandbLogger

# BORING_BUCKET is an s3:// URI configured earlier in the Colab.

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    # log_model="all" uploads every checkpoint as a W&B artifact
    logger = WandbLogger(
        project="boringbug",
        log_model="all",
    )

    model = BoringModel()
    trainer = Trainer(
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=logger,
        # Remote (S3) default_root_dir: I expect checkpoints to land here too
        default_root_dir=f"{BORING_BUCKET}/wandbtest/",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

run()
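
An untested workaround sketch that I would expect to help: inside `run()` above, pin the checkpoint location with an explicit `ModelCheckpoint`, so it no longer depends on how the Trainer resolves `default_root_dir` once a logger is attached.

```python
# Untested workaround sketch (replacing the Trainer construction inside run()):
# pass an explicit ModelCheckpoint whose dirpath points at the bucket.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{BORING_BUCKET}/wandbtest/checkpoints/",  # remote fsspec path
    save_top_k=-1,  # keep every checkpoint
)
trainer = Trainer(
    max_epochs=1,
    logger=logger,                    # WandbLogger still uploads artifacts
    callbacks=[checkpoint_callback],
    default_root_dir=f"{BORING_BUCKET}/wandbtest/",
)
```

Even if that works, I'd still expect the default behaviour to respect `default_root_dir`, which is the point of this issue.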

### Error messages and logs

There is no error message, but `{BORING_BUCKET}/wandbtest/` (an S3 location) is empty, and the checkpoint is only in Wandb.
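
For reference, this is roughly how I check the prefix (a small sketch; assumes s3fs is installed and AWS credentials are configured, as in the Colab):

```python
# List everything under the prefix; BORING_BUCKET is the same s3:// URI as above.
import fsspec

fs = fsspec.filesystem("s3")  # requires s3fs and AWS credentials
print(fs.find(f"{BORING_BUCKET}/wandbtest/"))  # -> [] i.e. no checkpoints written
```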

### Environment

### More info

What I really want for Christmas this year, all packaged together:

cc @awaelchli @morganmcg1 @borisdayma @scottire @parambharat @manangoel99

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!