RishabhMalviya opened this issue 1 year ago
So, I did some digging around.
Based on my experiments, ModelCheckpoint saves checkpoints to the directory I run the script from no matter what I do, unless I give it a dirpath, in which case it saves the checkpoints under that dirpath.
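For reference, here is a minimal sketch of the two configurations being compared; the monitored metric and the dirpath are placeholders, not taken from the original report:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Case 1: no dirpath given -- in the experiments described above, checkpoints
# end up in the directory the script is launched from.
ckpt_default = ModelCheckpoint(monitor="val_loss")

# Case 2: explicit dirpath -- checkpoints are written under that directory instead.
ckpt_explicit = ModelCheckpoint(dirpath="/tmp/my_checkpoints", monitor="val_loss")

trainer = Trainer(callbacks=[ckpt_default], max_epochs=1)
```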
Based on the source code, this makes sense: ModelCheckpoint sets the value of its internal variable self.dirpath through its __init_ckpt_dir() function: https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L442

The __init_ckpt_dir() function gets called in the __init__() function: https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L246

In any case, it's impossible to get ModelCheckpoint to use a logger's save_dir as the save location. I thought this should be possible when I saw the docstring of __resolve_ckpt_dir(): https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L580
But that resolution has no effect, because the ModelCheckpoint's self.dirpath value is already set during __init__().

I think it might be worthwhile to remove or modify the __resolve_ckpt_dir() function in ModelCheckpoint, since it has no effect. We should also remove any mention of coupling between a logger's save_dir and ModelCheckpoint save locations from the documentation.
I am encountering something similar to this issue as well.
Using MLFlowLogger with log_model=True and a tracking_uri set to some HTTP URL results in the checkpoints being saved both as artifacts in the MLflow tracking server and in subdirectories of whatever directory is configured as the trainer's root. This issue doesn't occur when using save_dir.
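To make the setup concrete, here is a minimal sketch of the configuration being described; the tracking URI and experiment name are placeholders, not from the original comment:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

# Placeholder HTTP tracking server and experiment name.
logger = MLFlowLogger(
    experiment_name="my-experiment",
    tracking_uri="http://localhost:5000",
    log_model=True,
)

# With this configuration, checkpoints were reportedly written both as MLflow
# artifacts on the tracking server and under the trainer's local root directory.
trainer = Trainer(logger=logger, max_epochs=1)
```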
Also having this issue when trying to use the MLFlowLogger and a ModelCheckpoint callback together. Something in the MLFlowLogger overrides the dirpath specified in the provided checkpoint callback, and specifying a save_dir when initializing the logger didn't help either. This is pretty unintuitive behavior; imo it's a bug.
I didn't trace through to see exactly where this happens, but I was able to work around it by modifying the trainer's checkpoint callback directly:
model = LightningModule(...)  # in practice, your LightningModule subclass
trainer = Trainer(
    logger=MLFlowLogger(...),
    callbacks=[ModelCheckpoint(...)],  # callbacks belong to the Trainer, not the module
)
# Override whatever dirpath was set during setup.
trainer.checkpoint_callback.dirpath = my_desired_dirpath
trainer.fit(model=model)
This might be a little bit off-topic, but wouldn't it be better if the MLFlowLogger completely took over the model checkpointing behavior when it is used? ModelCheckpoint is currently designed to save to a file, whereas the MLflow API is agnostic of files. Another way I could imagine is to subclass ModelCheckpoint and change its behavior to log to MLflow.
Something along the lines of this:
class MLflowModelCheckpoint(ModelCheckpoint):
    def __init__(self, ...):
        super().__init__(...)
        if not mlflow.active_run():
            raise Exception("Not inside an active MLflow run")
        # ...

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        super().on_train_start(trainer, pl_module)
        mlflow.register_model(...)

    def _save_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        checkpoint = trainer.get_checkpoint(...)  # this method doesn't currently exist
        mlflow.log_model(checkpoint, ...)
I think the main conflict in using MLflow and Lightning together is that the Trainer is currently not designed in a way that lets it delegate checkpoint saving.
I also want to add that MLFlowLogger currently only copies checkpoints from a file: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/mlflow.py#L335
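As a stopgap under that file-based design, an already-saved checkpoint can be attached to a run manually with mlflow.log_artifact. This is just a sketch of that idea, assuming an active run and a finished ModelCheckpoint callback; it is not something the built-in loggers do for you:

```python
import mlflow

# Assumes an MLflow run is active and the ModelCheckpoint callback has already
# written its best checkpoint to disk during trainer.fit().
best_ckpt_path = trainer.checkpoint_callback.best_model_path
if best_ckpt_path:
    # Upload the checkpoint file as an artifact of the active run.
    mlflow.log_artifact(best_ckpt_path, artifact_path="checkpoints")
```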
Bug description
I am trying to achieve the following behavior:

- ModelCheckpoint callbacks save model checkpoint files to a certain location
- MLFlowLogger (with log_model=True) only references the saved checkpoints

The problem is that no matter what I do, MLFlowLogger tries to save copies of the checkpoints in a new location.

What version are you seeing the problem on?
v2.0
How to reproduce the bug
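The original reproduction snippet did not survive in this copy of the issue; a minimal setup matching the description above might look like the following (the tracking URI, monitored metric, and model/data are placeholders):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

# Placeholder logger configuration; log_model=True asks the logger to log checkpoints.
logger = MLFlowLogger(tracking_uri="http://localhost:5000", log_model=True)

# No dirpath given, so the checkpoint files land in the default location
# (the directory the script is run from, per the report above).
checkpoint_cb = ModelCheckpoint(monitor="val_loss")

trainer = Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=1)
# trainer.fit(model, datamodule=dm)  # model and datamodule are placeholders
```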
Error messages and logs
The above code saves model checkpoints in the tracking_uri location of the MLFlowLogger, even though checkpoints already exist in the directory from which I ran the script (which is where the ModelCheckpoint callbacks save them by default).

Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info
No response