Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Differences in layer ordering cause errors when resuming from checkpoints #17025

Open Erotemic opened 1 year ago

Erotemic commented 1 year ago

Bug description

When resuming from a checkpoint (or in multi-GPU training with DDP, or anything else where you might construct multiple instances of your model), if the model layers are constructed in different orders, the optimizer state ends up with its per-parameter entries in a different order from the model's parameters. When the shapes disagree, this raises an explicit size-mismatch error. When the shapes happen to agree, the problem goes unnoticed (except perhaps for less optimal training dynamics).

Not sure if this is really a torch issue or a lightning issue. I'm submitting here first.

How to reproduce the bug

I have a MWE that reproduces the issue:

https://gist.github.com/Erotemic/dfaadf5cf9fa4910beb901ae6c93867b

It's about 400 lines that try to capture both the error and the use case that uncovered it. I needed it just to figure out what was going on, so I'm sure the example could be more minimal. I've highlighted the part that causes the error. If you sort the set to make the construction order consistent, the error goes away, so there is a user-side fix, but I'm not sure whether it also warrants a fix on the torch or Lightning side.

The idea is that we have multiple image sensors that observe different bands at different resolutions, and we put them all into a single network. I define a ModuleDict of stems that normalize the number of channels for each type of input so I can concatenate them into tokens for a transformer. The issue was that I was iterating over a set to construct the entries in the ModuleDict, and set iteration order is non-deterministic between Python runs. It does seem strange that neither Lightning nor torch can deal with this, though. The modules are keyed, so it should be possible to order them.

Error messages and logs

Epoch 1:  14%|████████████▊                      | 15/110 [00:00<00:06, 15.63it/s, loss=0.894, v_num=4, train_loss=0.955]
Traceback (most recent call last):
  File "/home/joncrall/code/watch/dev/mwe/lightning_cli_ckpt_path_error.py", line 437, in <module>
    main()
  File "/home/joncrall/code/watch/dev/mwe/lightning_cli_ckpt_path_error.py", line 380, in main
    MWE_LightningCLI(

...

    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 119, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/adam.py", line 234, in step
    adam(params_with_grad,
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/adam.py", line 300, in adam
    func(params,
  File "/home/joncrall/.pyenv/versions/3.10.10/envs/pyenv3.10.10/lib/python3.10/site-packages/torch/optim/adam.py", line 363, in _single_tensor_adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

Environment

Current environment

```
(pyenv3.10.10) joncrall@toothbrush:~/code/watch/dev/mwe$ pyversion pytorch_lightning
python -c "import pytorch_lightning; print('pytorch_lightning.__version__ = ' + str(pytorch_lightning.__version__))"
pytorch_lightning.__version__ = 1.9.4
(pyenv3.10.10) joncrall@toothbrush:~/code/watch/dev/mwe$ pyversion torch
python -c "import torch; print('torch.__version__ = ' + str(torch.__version__))"
torch.__version__ = 1.13.1+cu117
(pyenv3.10.10) joncrall@toothbrush:~/code/watch/dev/mwe$ python --version
Python 3.10.10
```


awaelchli commented 1 year ago

Hey @Erotemic

Is the issue fixed if you simply make self.knonw_sensorchan a list instead of a set? I can see that your known_modalities is already a list.

The non-deterministic iteration over a hash set that you describe is normal Python behavior. Sets are unordered, so we can't expect them to have a "natural" order (string hashing is salted with a seed that gets randomized every Python run).
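To see this concretely, the sketch below runs the same set literal in fresh interpreters with different `PYTHONHASHSEED` values. Because string hashing is salted per process, the iteration order of a set of strings can change between runs, while `sorted()` always yields one order. The sensor/channel names are hypothetical placeholders, not taken from the gist.

```python
import os
import subprocess
import sys

# Hypothetical sensor:channel keys, similar in spirit to the stems described in the issue.
SET_CODE = "print(list({'S2:red', 'S2:nir', 'L8:red', 'WV:pan', 'WV:blue'}))"
SORTED_CODE = "print(sorted({'S2:red', 'S2:nir', 'L8:red', 'WV:pan', 'WV:blue'}))"


def run_with_hash_seed(code, seed):
    """Run `code` in a fresh interpreter with a fixed PYTHONHASHSEED and return its stdout."""
    env = dict(os.environ, PYTHONHASHSEED=str(seed))
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


seeds = range(8)
set_orders = {run_with_hash_seed(SET_CODE, s) for s in seeds}
sorted_orders = {run_with_hash_seed(SORTED_CODE, s) for s in seeds}

# The raw set usually shows several distinct orders; the sorted version shows exactly one.
print(f"distinct set orders:    {len(set_orders)}")
print(f"distinct sorted orders: {len(sorted_orders)}")
```

Fixing `PYTHONHASHSEED` would also make set order repeatable, but sorting the keys does not depend on how the process was launched.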

Erotemic commented 1 year ago

Yes. You can work around this issue by ensuring that you are always creating things in the same order.

But I'm wondering whether the user should need to worry about that. Ideally, if a user did something like this in a model:

self.layers = torch.nn.ModuleDict()
if ascending:
    self.layers['layer1'] = torch.nn.Conv2d(3, 5, 1, 1)
    self.layers['layer2'] = torch.nn.Conv2d(5, 7, 1, 1)
else:
    self.layers['layer2'] = torch.nn.Conv2d(5, 7, 1, 1)
    self.layers['layer1'] = torch.nn.Conv2d(3, 5, 1, 1)

and they did a run with ascending=True, then restarted with ckpt_path but set ascending=False, Lightning (or torch) would be able to compensate for the change in ordering, especially because the keys still uniquely identify which layer is which.

Granted, this is a contrived example, but consider the case where a user builds such a module dictionary and uses a set to remove duplicates. (This is basically how I got myself into this mess.) I'm aware that set ordering is arbitrary, but I didn't think the order in which I defined things would matter. Nor does it seem like it should.
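A user-side way to keep the deduplication without the hidden ordering dependence, sketched under the assumption that the keys are strings: `dict.fromkeys` drops duplicates while keeping first-seen order (dicts preserve insertion order since Python 3.7), and `sorted(set(...))` gives a canonical order that does not depend on the input order at all. The key names and the commented ModuleDict loop are hypothetical.

```python
# Hypothetical sensor/channel codes with duplicates, e.g. gathered from a dataset.
observed = ["S2:red", "L8:red", "S2:red", "WV:pan", "L8:red", "S2:nir"]

# Order of first appearance: stable as long as `observed` itself is stable.
first_seen = list(dict.fromkeys(observed))

# Canonical order: identical regardless of input order or hash seed.
canonical = sorted(set(observed))

print(first_seen)  # ['S2:red', 'L8:red', 'WV:pan', 'S2:nir']
print(canonical)   # ['L8:red', 'S2:nir', 'S2:red', 'WV:pan']

# A model would then register its stems in that order (hypothetical names):
# for key in canonical:
#     self.stems[key] = torch.nn.Conv2d(in_channels[key], hidden_dim, 1)
```

Sorting is the safer of the two when the keys come from an unordered source, since it yields the same registration order in every process.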

Perhaps users just need to be careful when defining their models, so at the very least this is a gotcha. I would probably still classify this as a bug, but it's certainly non-critical.

Here is a more minimal example that reproduces the contrived example:

import torch
import torch.nn
import pytorch_lightning as pl
from torch.utils.data import Dataset
from pytorch_lightning.cli import LightningCLI

class SimpleModel(pl.LightningModule):
    def __init__(self, ascending=False):
        super().__init__()
        self.layers = torch.nn.ModuleDict()
        if ascending:
            self.layers['layer1'] = torch.nn.Conv2d(3, 5, 1, 1)
            self.layers['layer2'] = torch.nn.Conv2d(5, 7, 1, 1)
        else:
            self.layers['layer2'] = torch.nn.Conv2d(5, 7, 1, 1)
            self.layers['layer1'] = torch.nn.Conv2d(3, 5, 1, 1)

    def forward(self, inputs):
        x = inputs
        x = self.layers['layer1'](x)
        x = self.layers['layer2'](x)
        return x

    def forward_step(self, batch):
        """
        Generic forward step used for test / train / validation
        """
        batch = torch.stack(batch, dim=0)
        x = self.forward(batch)
        loss = x.sum()
        return loss

    def training_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch)
        return outputs

    def validation_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch)
        return outputs

class SimpleDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, index):
        return torch.rand(3, 10, 10)

class SimpleDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=1, num_workers=0):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage):
        self.train_dataset = SimpleDataset()
        self.vali_dataset = SimpleDataset()

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.vali_dataset, shuffle=False)

    def _make_dataloader(self, dataset, shuffle=False):
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle, pin_memory=True,
            collate_fn=lambda x: x
        )
        return loader

def main():
    LightningCLI(
        model_class=SimpleModel,
        datamodule_class=SimpleDataModule,
    )

if __name__ == '__main__':
    """
    CommandLine:
        cd ~/code/watch/dev/mwe/

        DEFAULT_ROOT_DIR=./mwe_train_dir

        python lightning_ckpt_mwe.py fit --config "
            model:
                ascending: True
            data:
                num_workers: 8
                batch_size: 2
            optimizer:
              class_path: torch.optim.Adam
              init_args:
                lr: 1e-7
            trainer:
              default_root_dir     : $DEFAULT_ROOT_DIR
              accelerator          : gpu
              devices              : 0,
              max_epochs: 100
        "

        CKPT_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/checkpoints/*.ckpt'))[-1])")
        echo "CKPT_FPATH = $CKPT_FPATH"

        # Even though the model is "the same", the ordering of layers is different,
        # and that causes an error
        python lightning_ckpt_mwe.py fit --config "
            model:
                ascending: False
            data:
                num_workers: 8
                batch_size: 2
            optimizer:
              class_path: torch.optim.Adam
              init_args:
                lr: 1e-7
            trainer:
              default_root_dir     : $DEFAULT_ROOT_DIR
              accelerator          : gpu
              devices              : 0,
              max_epochs: 100
        " --ckpt_path="$CKPT_FPATH"
    """
    main()

Lastly, I want to highlight that if, instead of choosing 3, 5, and 7 as my in/out channels, I had just used 3 everywhere, the shapes would fit and the user would not see any error. Lightning would happily load the checkpoint's optimizer momentums into the wrong layers, and training would "resume" from this borked state: at worst producing a poorly fit model, and at best silently sabotaging reproducibility. But again, it's a niche case that a user can avoid if they are aware of it.
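The silent case can be checked by hand. With both layers as `Conv2d(3, 3, 1, 1)`, every weight is `(3, 3, 1, 1)` and every bias is `(3,)`, so a shape-only comparison passes even though every parameter receives another layer's state. The sketch assumes positional matching of saved to current parameters (as in the MWE's failure) and uses the MWE's parameter names; the variable names are illustrative.

```python
# Parameter names from the MWE, in the order each run registers them.
saved_order = ["layers.layer1.weight", "layers.layer1.bias",
               "layers.layer2.weight", "layers.layer2.bias"]
current_order = ["layers.layer2.weight", "layers.layer2.bias",
                 "layers.layer1.weight", "layers.layer1.bias"]

# Both layers are Conv2d(3, 3, 1, 1): weight (3, 3, 1, 1), bias (3,).
shapes = {name: (3, 3, 1, 1) if name.endswith("weight") else (3,) for name in saved_order}

# Positional matching: the i-th saved state is attached to the i-th current parameter.
pairs = list(zip(saved_order, current_order))
shape_check_passes = all(shapes[saved] == shapes[cur] for saved, cur in pairs)
misassigned = [(saved, cur) for saved, cur in pairs if saved != cur]

print(shape_check_passes)  # True: nothing errors
print(len(misassigned))    # 4: every parameter got another layer's state
```

Only a comparison by name catches this, and the index-keyed optimizer state does not carry names.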

awaelchli commented 1 year ago

Perhaps users just need to be careful when defining their models, so at the very least this is a gotcha. I would probably still classify this as a bug, but it's certainly non-critical.

Lightning doesn't have any logic that would change the way sets get hashed, and it cannot enforce a fixed ordering. I don't see how this could be a bug in Lightning or how Lightning could "fix" this.

I think that if a user is making use of sets, but then at the same time needs to rely on a fixed ordering, then they are using sets wrong.

I think in your contrived example, you should switch the two layers, otherwise I don't think anything would change:

self.layers = torch.nn.ModuleDict()
if ascending:
    self.layers['layer1'] = torch.nn.Conv2d(3, 5, 1, 1)
    self.layers['layer2'] = torch.nn.Conv2d(5, 7, 1, 1)
else:
    self.layers['layer2'] = torch.nn.Conv2d(3, 5, 1, 1)
    self.layers['layer1'] = torch.nn.Conv2d(5, 7, 1, 1)

You are basically saying that you train the model with ascending=True, save a checkpoint, then make a new model with ascending=False. Now, when loading the weights, you expect Lightning to load them by magically renaming the keys?

Erotemic commented 1 year ago

I think in your contrived example, you should switch the two layers, otherwise I don't think anything would change:

It does change. That's the point!

The state dict (in terms of keys/values) is the same, and the state dict in the checkpoint loads just fine. It's the optimizer state that is loaded in a way that depends on the order in which the keys were inserted into the module's state dictionary, even though the key/value pairs are unchanged. (I'm actually not 100% sure about the key-insertion explanation, but something in the optimizer depends only on ordering and not on the key names.) Running the CommandLine section in the MWE demonstrates this.
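A pure-Python sketch of that mechanism. It assumes PyTorch's behavior that `Optimizer.state_dict()` stores per-parameter state under integer positions (the indices listed in `param_groups`), and that `load_state_dict()` attaches saved index `i` to whichever parameter is `i`-th now. The two helpers are hypothetical stand-ins, not torch APIs; names and shapes mirror the MWE's `SimpleModel`.

```python
# Names/shapes as named_parameters() would list them for the MWE:
# layer1 = Conv2d(3, 5, 1, 1), layer2 = Conv2d(5, 7, 1, 1).
ascending = [
    ("layers.layer1.weight", (5, 3, 1, 1)),
    ("layers.layer1.bias", (5,)),
    ("layers.layer2.weight", (7, 5, 1, 1)),
    ("layers.layer2.bias", (7,)),
]
# Same modules and state_dict keys, but layer2 is registered first.
descending = ascending[2:] + ascending[:2]


def save_optimizer_state(named_params):
    """Stand-in for Optimizer.state_dict(): state keyed by position; names are dropped."""
    state = {idx: {"exp_avg_shape": shape} for idx, (_, shape) in enumerate(named_params)}
    return {"state": state, "param_groups": [{"params": list(range(len(named_params)))}]}


def load_optimizer_state(saved, named_params):
    """Stand-in for Optimizer.load_state_dict(): saved index i goes to the i-th current param."""
    current_names = [name for name, _ in named_params]
    return {current_names[idx]: entry for idx, entry in saved["state"].items()}


# The key sets are identical, so the model weights load fine ...
assert {n for n, _ in ascending} == {n for n, _ in descending}

# ... but the positional optimizer state lands on the wrong parameters.
loaded = load_optimizer_state(save_optimizer_state(ascending), descending)
mismatches = []
for name, shape in descending:
    got = loaded[name]["exp_avg_shape"]
    if got != shape:
        mismatches.append(name)
    print(f"{name:22s} param={shape} exp_avg={got} {'ok' if got == shape else 'MISMATCH'}")
```

The printed pairs follow the same pattern as the params/exp_avgs shapes shown later in this thread.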

I think that if a user is making use of sets, but then at the same time needs to rely on a fixed ordering, then they are using sets wrong.

Yes, but the issue I'm trying to highlight is that it is non-obvious that anything here should be order-dependent. When I realized what it was, I fixed the issue on my side. I'm raising it here because there might be something in the way Lightning saves/loads optimizer state that could be changed to remove the ordering dependency.

Perhaps I need to do some more work to find a more targeted example that demonstrates the difference with respect to how optimizer state is loaded; I'm not sure when I will be able to get to that. In the meantime, here is what I've noticed when I embed into `File "torch/optim/adam.py", line 363, in _single_tensor_adam` and look at the shapes of params, grads, exp_avgs, etc.:

The exp_avgs are loaded incorrectly, whereas grads and params are loaded correctly.

In [2]:     print([p.shape for p in params])
   ...:     print([p.shape for p in grads])
   ...:     print([p.shape for p in exp_avgs])
   ...: 
[torch.Size([7, 5, 1, 1]), torch.Size([7]), torch.Size([5, 3, 1, 1]), torch.Size([5])]
[torch.Size([7, 5, 1, 1]), torch.Size([7]), torch.Size([5, 3, 1, 1]), torch.Size([5])]
[torch.Size([5, 3, 1, 1]), torch.Size([5]), torch.Size([7, 5, 1, 1]), torch.Size([7])]

I can probably make this more obvious by hooking into an on_load_checkpoint callback.

awaelchli commented 1 year ago

@Erotemic Do you have a concrete suggestion how this can be "detected" in Lightning? I am not convinced that this can be addressed in Lightning at all. I am also not convinced that this can be addressed in PyTorch itself.

Erotemic commented 1 year ago

Concrete... unfortunately no. I would have submitted a PR if I knew how to fix it.

I'm not convinced it's fixable either, at least with the current way optimizers save their state, which I believe (correct me if I'm wrong) is entirely order-based. I do have some less concrete ideas. One would need to associate the layer names used in the model state dictionary with the entries in the optimizer state dictionary. Perhaps that could be done in Strategy.optimizer_state, but it might be a heavy lift. Thinking out loud, a named tuple might provide a backwards-compatible mechanism that associates a layer name with each entry while preserving order-only access to the state. Another possibility is to augment the state with an additional field that records the associated layer names at save time, so that at load time the order can be checked and corrected.
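The last idea (record names at save time, check and reorder at load time) can be sketched on plain dicts. The `param_names` field and both helper functions are hypothetical, not part of torch or Lightning. The sketch assumes a single param group covering all parameters in `named_parameters()` order, and uses strings in place of tensors so the remapping is easy to read.

```python
def attach_param_names(optim_state, param_names):
    """Copy an index-keyed optimizer state dict and add a hypothetical 'param_names' field."""
    return {**optim_state, "param_names": list(param_names)}


def reorder_to(optim_state, current_names):
    """Re-key per-parameter state so that index i refers to current_names[i].

    Assumes one param group that covers all parameters in named_parameters() order.
    """
    current_names = list(current_names)
    saved_names = optim_state["param_names"]
    if saved_names == current_names:
        return optim_state  # ordering unchanged, nothing to do
    if sorted(saved_names) != sorted(current_names):
        raise ValueError("saved and current parameter names differ; cannot remap by name")
    new_index = {name: i for i, name in enumerate(current_names)}
    state = {new_index[saved_names[i]]: entry for i, entry in optim_state["state"].items()}
    return {
        "state": dict(sorted(state.items())),
        "param_groups": optim_state["param_groups"],
        "param_names": current_names,
    }


names = ["layers.layer1.weight", "layers.layer1.bias",
         "layers.layer2.weight", "layers.layer2.bias"]
saved = {
    "state": {i: {"exp_avg": f"momentum of {n}"} for i, n in enumerate(names)},
    "param_groups": [{"lr": 1e-7, "params": [0, 1, 2, 3]}],
}
checkpoint_entry = attach_param_names(saved, names)

# The resumed run registers layer2 first.
current = names[2:] + names[:2]
fixed = reorder_to(checkpoint_entry, current)
print(fixed["state"][0])  # {'exp_avg': 'momentum of layers.layer2.weight'}
# The extra field is not something torch expects; strip it before load_state_dict():
# fixed.pop("param_names")
```

Raising on a name mismatch, rather than guessing, keeps the check from hiding a genuinely changed architecture.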

At the very least, now that the issue exists, the bug should be more searchable than it was before, and hopefully it helps someone out.

Erotemic commented 1 year ago

@awaelchli I did come up with a concrete way of detecting this for the AdamW optimizer. Unfortunately, it doesn't look like the information needed to fix it is available.

In load_optimizer_state_dict if I add the following code:

            from rich import print
            # Index-keyed per-parameter state of the first optimizer in the checkpoint
            optim_state = optimizer_states[0]['state']

            # Parameters of the *new* model, in the order the optimizer will index them
            new_named_model_params = list(self.model.named_parameters())

            for layer_index, layer_optim_state in optim_state.items():
                print('--------------------------')
                new_layer_name, new_param = new_named_model_params[layer_index]
                optim_state_shape = {k: v.shape for k, v in layer_optim_state.items()}
                # AdamW's first-moment buffer must have the same shape as its parameter
                optim_state_val = layer_optim_state['exp_avg']
                is_ok = new_param.shape == optim_state_val.shape

                print(f'layer_index={layer_index}')
                print(f'new_layer_name={new_layer_name}')
                print(f'new_param.shape={new_param.shape}')
                print(f'optim_state_shape={optim_state_shape}')
                if is_ok:
                    print('[green] ok')
                else:
                    print('[red] not ok')

It checks the shape of each AdamW state tensor against what should be the corresponding model parameter, and prints "not ok" if there is a mismatch.

I think we can also use the ordering of checkpoint['state_dict'] to recover the old ordering and correct for it. I used the following code and was able to resume from a checkpoint where my model had this error. The code is inefficient and ugly, but it is a concrete starting point for how this could be corrected inside Lightning (at least for AdamW; it probably needs work to handle all optimizers).

            # Parameter names in the order the *new* model registers them;
            # the optimizer's integer indices refer to this order
            new_param_keys = [name for name, _ in self.model.named_parameters()]
            new_index_of = {name: idx for idx, name in enumerate(new_param_keys)}

            # The checkpoint's state_dict preserves the *old* registration order.
            # Keep only entries that are parameters (drops buffers such as BatchNorm stats).
            old_state_keys = list(checkpoint['state_dict'].keys())
            old_param_keys = [key for key in old_state_keys if key in new_index_of]

            # Map each old parameter index to the index of the same-named parameter now
            old_to_new = {
                old_param_index: new_index_of[param_key]
                for old_param_index, param_key in enumerate(old_param_keys)
            }

            # NOTE: only the per-parameter 'state' is re-keyed; 'param_groups' is copied
            # as-is. This assumes each optimizer has a single param group built from
            # model.parameters(), so its 'params' list is simply [0, 1, ..., N - 1].
            fixed_optimizer_states = []
            for old_opt_state in optimizer_states:
                fixed_opt_state = {
                    old_to_new[old_param_index]: layer_opt_state
                    for old_param_index, layer_opt_state in old_opt_state['state'].items()
                }
                fixed_optimizer_states.append({
                    'state': dict(sorted(fixed_opt_state.items())),
                    'param_groups': old_opt_state['param_groups'],
                })
            optimizer_states = fixed_optimizer_states

Of course this would be so much easier if optimizers were just saved using layer names instead of integer indexes.