Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

load_from_checkpoint uses default parameters instead of supplied argument (size mismatch) #18587

Closed B-lanc closed 1 year ago

B-lanc commented 1 year ago

Bug description

If any of a model's `__init__` parameters has a default value, `load_from_checkpoint` uses that default instead of the argument the model was actually constructed with, resulting in a size mismatch when the weights are copied.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

~~~python
import lightning as L
import torch.nn as nn

class a(L.LightningModule):
    def __init__(self, testing=True):
        super().__init__()
        if testing:
            self.model = nn.Linear(2, 4)
        else:
            self.model = nn.Linear(4, 8)

model = a(False)

trainer = L.Trainer()
trainer.strategy.connect(model)
trainer.save_checkpoint("bugtesting.ckpt")
model.load_from_checkpoint("bugtesting.ckpt")  # comment this line and uncomment the two below for more weird behavior

# model2 = a(False)
# model2.load_from_checkpoint("bugtesting.ckpt")  # This should error out, but it doesn't...
~~~

### Error messages and logs


~~~
size mismatch for model.weight: copying a param with shape torch.Size([8, 4]) from checkpoint, the shape in current model is torch.Size([4, 2]).
size mismatch for model.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([4]).
~~~

### Environment

<details>
  <summary>Current environment</summary>

- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule

- PyTorch Lightning Version (e.g., 1.5.0): 2.0.9

- Lightning App Version (e.g., 0.5.2):

- PyTorch Version (e.g., 2.0): 1.13.1

- Python version (e.g., 3.9): 3.10.8

- OS (e.g., Linux): Linux

- CUDA/cuDNN version:

- GPU models and configuration:

- How you installed Lightning(conda, pip, source): pip

- Running environment of LightningApp (e.g. local, cloud): local



</details>

### More info

Using `load_state_dict` instead avoids all of these issues.
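For reference, the `load_state_dict` workaround boils down to the sketch below, written in plain PyTorch. The class here is a stand-in for the repro's module, and the manually built checkpoint dict mimics how a Lightning checkpoint keeps the weights under a `"state_dict"` key:

```python
import torch
import torch.nn as nn

# Stand-in for the module from the repro: the argument controls layer sizes.
class A(nn.Module):
    def __init__(self, testing=True):
        super().__init__()
        self.model = nn.Linear(2, 4) if testing else nn.Linear(4, 8)

src = A(testing=False)
# A Lightning checkpoint stores the weights under the "state_dict" key.
torch.save({"state_dict": src.state_dict()}, "bugtesting.ckpt")

# Workaround: construct the model yourself with the intended arguments,
# then copy only the weights, bypassing load_from_checkpoint entirely.
dst = A(testing=False)
dst.load_state_dict(torch.load("bugtesting.ckpt")["state_dict"])
```

Because the model is instantiated by hand, the default value of `testing` never enters the picture.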
B-lanc commented 1 year ago

A correction to the last two lines of "How to reproduce the bug": as currently written, they would still error out. The case that does not error out is:

~~~python
model = a(True)

trainer = L.Trainer()
trainer.strategy.connect(model)
trainer.save_checkpoint("bugtesting.ckpt")
model2 = a(False)
model2.load_from_checkpoint("bugtesting.ckpt")  # This should error out, but it doesn't...
~~~
awaelchli commented 1 year ago

@B-lanc You are running into this issue here: https://github.com/Lightning-AI/lightning/issues/18169

You are calling load_from_checkpoint on an instance, but this is not the intended use. Change your code to this:

~~~python
model = a.load_from_checkpoint("bugtesting.ckpt", testing=False)
~~~

Here `a` is the class (consider following the standard practice of naming classes in upper case), and `load_from_checkpoint` is a classmethod.

Docs: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#load-from-checkpoint

B-lanc commented 1 year ago

Sorry for the lowercase class name; I only used it for testing this "bug". I see, thanks for the help! I couldn't find anything in the docs about supplying the model arguments to load_from_checkpoint.