Open · fschlatt opened this issue 1 year ago
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!
I have the same problem.
I can reproduce this error.
I fixed this by using the `ModelCheckpoint` callback of Lightning:
```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

import wandb
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger


class MyDataModule(Dataset):
    def __getitem__(self, index) -> torch.Tensor:
        return torch.rand(1)

    def __len__(self) -> int:
        return 5


class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self.model(batch).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


model = MyLightningModule()
data = MyDataModule()
loader = DataLoader(data, batch_size=1)

logger = WandbLogger(log_model=True)
logger.experiment  # accessing the property initializes the wandb run, so wandb.run is set below

trainer = Trainer(
    logger=logger,
    callbacks=[
        # explicitly point the checkpoint directory at the wandb run's files directory
        ModelCheckpoint(dirpath=Path(wandb.run.dir) / "checkpoints"),
    ],
)
trainer.fit(model, loader)
```
This logs to `/run/files/checkpoints`, which is desired, although I expected this code to checkpoint to `/run/checkpoint`.
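For reference, this is how the resolved checkpoint location can be checked after `trainer.fit` from the snippet above (a small sketch; the exact paths depend on the local wandb setup):

```python
# trainer.checkpoint_callback is the ModelCheckpoint instance registered above;
# dirpath is the directory it resolved for saving checkpoints.
print(trainer.checkpoint_callback.dirpath)

# wandb.run.dir is the run's files directory that wandb uploads from.
print(wandb.run.dir)
```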
My fix has been to use a `CustomWandbLogger`. This logs to `run-*-{run_id}/files/{run_id}/checkpoints`, which isn't perfect but does the trick:
edit: The `DummyExperiment` check is necessary for distributed training. `self.experiment` is a `DummyExperiment` in all processes except rank 0; the correct `save_dir` gets broadcast to all processes internally in Lightning.
```python
from typing import Optional

from lightning import Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.loggers import WandbLogger


class CustomWandbLogger(WandbLogger):
    @property
    def save_dir(self) -> Optional[str]:
        """Gets the save directory.

        Returns:
            The path to the save directory.
        """
        # On non-zero ranks the experiment is a DummyExperiment and has no dir;
        # Lightning broadcasts the rank-0 save_dir to the other processes.
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir


trainer = Trainer(
    logger=CustomWandbLogger(log_model=True),
)
```
Bug description
When using the `WandbLogger` and setting `log_model=True`, the model checkpoint isn't saved in the wandb experiment directory, but in the separate Lightning logs directory.

Current behavior:
Desired behavior:
What version are you seeing the problem on?
master
How to reproduce the bug
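A minimal sketch of a setup that shows the behavior (reusing `model` and `loader` from the reproduction snippet in the comments above; the exact output paths depend on the local wandb configuration):

```python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# No explicit dirpath: ModelCheckpoint resolves its directory from the logger's
# save_dir plus the logger's name and version, not from the wandb run's files directory.
trainer = Trainer(
    logger=WandbLogger(log_model=True),
    callbacks=[ModelCheckpoint()],
)
trainer.fit(model, loader)  # model and loader as defined in the comment above
print(trainer.checkpoint_callback.dirpath)  # ends up outside the wandb files directory
```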
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```
More info
I see two problems that need to be addressed:
1. The `WandbLogger` doesn't use the experiment's files directory as its `save_dir` property. This can be fixed by replacing this line with `self.experiment.dir`.
2. The `ModelCheckpoint` adds the name and version of the logger to the checkpoint path when resolving the checkpoint directory. The culprit is the `__resolve_ckpt_dir` function. I would propose moving the name and version of the logger, as well as the `save_dir` function, to the actual logger, as each logger may have a different strategy for integrating the version and name. For example, wandb does this by default and creates the files directory for saving artifacts. A rough sketch of this idea is below.

cc @awaelchli @morganmcg1 @borisdayma @scottire @parambharat
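As a rough sketch of what this could look like for the `WandbLogger` (the `resolve_ckpt_dir` hook below does not exist in Lightning today and only illustrates moving the path resolution into the logger; the `DummyExperiment` handling from the comment above is omitted for brevity):

```python
import os
from typing import Optional

from lightning.pytorch.loggers import WandbLogger


class ProposedWandbLogger(WandbLogger):
    """Sketch of a logger that owns its checkpoint path resolution."""

    @property
    def save_dir(self) -> Optional[str]:
        # Problem 1: point save_dir at the run's files directory.
        return self.experiment.dir

    def resolve_ckpt_dir(self) -> str:
        # Problem 2 (hypothetical hook): ModelCheckpoint could delegate here instead of
        # composing save_dir/name/version/checkpoints itself. wandb already encodes the
        # run name and id in the run directory, so no extra subfolders are needed.
        return os.path.join(self.experiment.dir, "checkpoints")
```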