Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

No action for key "ckpt_path" -> ckpt_path not available for linking #18885

Open Toekan opened 1 year ago

Toekan commented 1 year ago

Bug description

Hi,

Thanks for all the hard work on making it possible to configure Lightning experiments through a simple config!

I want to link my ckpt_path to a callback using link_arguments together with LightningCLI. In my case the callback saves a set of predictions and uses the ckpt_path to name the prediction file, but I would have thought that needing ckpt_path elsewhere in config.yaml isn't that uncommon. This is how I implemented the linking:

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("ckpt_path", "trainer.callbacks.init_args.ckpt_path")

cli = MyLightningCLI(
    MyLitModule,
    MyLitDataModule,
    run=True,
)

When running python predict_my_model.py predict --config my_config.yaml I unfortunately get the following error:

ValueError: No action for key "ckpt_path".

Going through the code, it seems that ckpt_path does not have an action attached to it: find_parent_or_child_actions does not find one.


I first raised this, incorrectly, on jsonargparse, where I got the following response:

The problem is not in jsonargparse. The error happens because ckpt_path is added in line cli.py#L497, which is after add_arguments_to_parser gets called (line cli.py#L494). That is, when the link_arguments is run, ckpt_path does not yet exist in the parser.

How can this be fixed? You could override _prepare_subcommand_parser, keeping the same code but moving _add_arguments to after add_method_arguments. Note, though, that this method's name starts with an underscore, so it is not guaranteed to be stable.

There could be other more proper solutions. But maybe this is not the correct place to discuss it. Please create an issue in lightning.
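The ordering problem described in this response can be illustrated with a toy model. This is only a sketch: ToyParser, user_hook and build_parser are hypothetical names made up for illustration, and none of this is the jsonargparse or Lightning API. It only shows why linking from a key fails when the link is requested before the key is registered, and why swapping the order would avoid the error.

```python
class ToyParser:
    """Toy stand-in for an argument parser (NOT the jsonargparse API)."""

    def __init__(self):
        self.keys = set()
        self.links = []

    def add_argument(self, key):
        self.keys.add(key)

    def link_arguments(self, source, target):
        # A link can only be created from a key that is already registered.
        if source not in self.keys:
            raise ValueError(f'No action for key "{source}"')
        self.links.append((source, target))


def user_hook(parser):
    # Plays the role of add_arguments_to_parser in the issue.
    parser.link_arguments("ckpt_path", "callback.ckpt_path")


def build_parser(hook_first):
    parser = ToyParser()
    parser.add_argument("callback.ckpt_path")
    if hook_first:
        # Order reported in the issue: the hook runs before ckpt_path exists.
        user_hook(parser)
        parser.add_argument("ckpt_path")
    else:
        # Suggested order: register ckpt_path first, then run the hook.
        parser.add_argument("ckpt_path")
        user_hook(parser)
    return parser


try:
    build_parser(hook_first=True)
except ValueError as err:
    print("hook first:", err)

print("hook last:", build_parser(hook_first=False).links)
```

Whether the same reordering is safe inside the real _prepare_subcommand_parser is not confirmed in this thread.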


Thanks!

What version are you seeing the problem on?

v2.0

How to reproduce the bug

See code above, can make a more complete example if needed.

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @carmocca @mauvilsa

Toekan commented 1 year ago

Not sure if it is related, but I also don't manage to load a checkpoint when using run=False, as flagged here:

https://github.com/Lightning-AI/lightning/issues/12302

Which is a pretty important use-case I think, considering the checkpoint file has basically replaced the old hparams approach.

mauvilsa commented 1 year ago

> Not sure if it is related, but I also don't manage to load a checkpoint when using run=False, as flagged here:
>
> #12302
>
> Which is a pretty important use-case I think, considering the checkpoint file has basically replaced the old hparams approach.

@Toekan it is not related. In #12302 you can see the explanation and what to do.

Toekan commented 1 year ago

Thanks for the quick response!

Going off-topic a bit here (sorry, feel free to tell me if I should move it :)). Is there no easier way to load back in the whole state of the trainer or the model weights from the checkpoint file?

After reading around and trying things out for hours, the only working way I could come up with was:

cli = LightningCLI(
    MyLitModelModule,
    MyLitDataModule,
    run=False,
)

model = cli.model.load_from_checkpoint(
    "lightning_logs/version_xx/checkpoints/my_checkpoint.ckpt",
    # Here I need to pass in every argument that expects an instantiated class by hand
    model=cli.model.model,
    loss_fn=cli.model.loss_fn,
    activation=cli.model.activation,
    train_metrics=cli.model.train_metrics,
    # ...
)

Is this the easiest way to load the model from a LightningCLI checkpoint? Having to pull every instantiated class from the instantiated cli just to be able to call load_from_checkpoint is obviously a considerably worse experience than what run=True offers.
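The by-hand repetition in the snippet above could be reduced with a small helper. This is a minimal sketch, assuming that every constructor argument is stored on the module under the same attribute name (as model, loss_fn, activation and train_metrics are in the snippet). init_kwargs_from_instance and ToyModule are hypothetical names used only for illustration; they are not part of Lightning.

```python
import inspect


def init_kwargs_from_instance(instance):
    """Collect constructor arguments that the instance stores under the same name."""
    params = inspect.signature(type(instance).__init__).parameters
    kwargs = {}
    for name, param in params.items():
        # Skip self and catch-all *args / **kwargs parameters.
        if name == "self" or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        # Only arguments kept as same-named attributes can be recovered.
        if hasattr(instance, name):
            kwargs[name] = getattr(instance, name)
    return kwargs


class ToyModule:
    """Stand-in for a LightningModule, for illustration only."""

    def __init__(self, model, loss_fn, lr=0.1):
        self.model = model
        self.loss_fn = loss_fn
        # lr is deliberately not stored, so the helper will not return it.


toy = ToyModule(model="net", loss_fn="mse")
print(init_kwargs_from_instance(toy))
```

With a real module the result could then be passed on as keyword arguments, e.g. type(cli.model).load_from_checkpoint(path, **init_kwargs_from_instance(cli.model)), since load_from_checkpoint accepts extra keyword arguments that override the saved hyperparameters. Arguments that are not stored as same-named attributes would still have to be passed by hand.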

I understand the strict distinction you are trying to create between config files for configuration and a new CLI for changes in source code (very happy LightningCLI didn't go down the jinja route), but I find it hard to understand where checkpoints sit in this, or why they have to be linked to trainer commands rather than to the trainer itself.

carmocca commented 9 months ago

I believe #18105 will help here

calvinshopify commented 2 months ago

Hey @mauvilsa any suggestions or resolutions on this one? I am running into the same problem where:

  1. I am linking arguments via parser.link_arguments
  2. Those linked arguments are not included in config.yml or hparams.yml
  3. As a result, attempting to load from a checkpoint misses the linked args

mauvilsa commented 1 month ago

@calvinshopify what version of lightning are you using? #18105, which was included in lightning 2.3, was intended to add support for load_from_checkpoint. If you are using the latest version of lightning, what do you get if you run:

import torch
ckpt = torch.load('path/to/your/saved.ckpt')
print(ckpt['hyper_parameters'])
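
Once the hyper_parameters dict has been printed, the check for missing linked arguments can be made explicit. The following is a sketch only: missing_hparams is a hypothetical helper, and the example dict is a made-up stand-in for ckpt['hyper_parameters'], not output from a real checkpoint.

```python
def missing_hparams(hparams, expected_keys):
    """Return the dotted keys from expected_keys that are absent in a nested dict."""
    missing = []
    for dotted in expected_keys:
        node = hparams
        for part in dotted.split("."):
            if isinstance(node, dict) and part in node:
                node = node[part]
            else:
                missing.append(dotted)
                break
    return missing


# Made-up stand-in for ckpt['hyper_parameters'].
example = {"lr": 0.001, "model": {"init_args": {"hidden": 64}}}

# prints ['num_classes']
print(missing_hparams(example, ["lr", "model.init_args.hidden", "num_classes"]))
```

Any key reported as missing here would be an argument that load_from_checkpoint cannot restore from the checkpoint alone.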

mauvilsa commented 1 month ago

Note that there might be a bug according to #20311