Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

DeepSpeed Zero 2 Fails to Load All Checkpoint Parameters #15694

Closed: kelvins64 closed this issue 1 year ago

kelvins64 commented 1 year ago

Bug description

With certain models, DeepSpeed ZeRO Stage 2 fails to save the checkpoint in a way that can be fully reloaded after conversion to the Lightning format.

In the provided example, several parameters do not appear in the param_shapes value of the Zero checkpoint (which the generated reconstruction script uses to build the state dict), despite appearing in the module value of the Zero checkpoint.
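
For reference, the mismatch can be inspected directly in the saved checkpoint directory (a minimal sketch; the file name below assumes the default Lightning/DeepSpeed output layout and may differ between versions):

import torch

# Hypothetical path; the exact file name inside the checkpoint directory can differ between DeepSpeed versions.
states_file = 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt/checkpoint/mp_rank_00_model_states.pt'
ckpt = torch.load(states_file, map_location='cpu')

module_keys = set(ckpt['module'])
shapes = ckpt['param_shapes']
# 'param_shapes' is a flat dict in some DeepSpeed versions and a list of per-group dicts in others.
shape_keys = set(shapes) if isinstance(shapes, dict) else {k for group in shapes for k in group}

# Parameters present in 'module' but absent from 'param_shapes' (the tied embeddings in this report).
print(module_keys - shape_keys)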

How to reproduce the bug

import os
from typing import Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
import argparse
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

class TextDataset(Dataset):
    def __init__(self, model_name, length):
        self.model_name = model_name
        self.length = length

        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        self.data = self.tokenizer(
            [f'Hello world {i}!' for i in range(length)],
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, index):
        return {
            'input_ids': self.data['input_ids'][index],
            'attention_mask': self.data['attention_mask'][index],
            'labels': self.data['input_ids'][index] # Have the target text be the input text
        }

    def __len__(self):
        return self.length

class BoringModel(LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch):
        return self.model(**batch)[0] # Return loss

    def training_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# Start new code
def run(str_args: Union[str, None] = None):
    # Parse args
    model_name = 'facebook/mbart-large-50'

    pl.seed_everything(42)
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args() if str_args is None else parser.parse_args(str_args.split())

    # Build data
    train_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    val_data = DataLoader(TextDataset(model_name, 64), batch_size=2)

    # Build and train model
    model = BoringModel(model_name)

    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        deterministic=True
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # Convert checkpoint
    if trainer.is_global_zero:
        zero_ckpt_dir = os.path.join(os.getcwd(), 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt')
        ckpt_path = zero_ckpt_dir[:-4] + 'pth'
        convert_zero_checkpoint_to_fp32_state_dict(zero_ckpt_dir, ckpt_path)

        # Attempt to load checkpoint
        model.load_from_checkpoint(ckpt_path, model_name=model_name)

if __name__ == "__main__":
    run('--accelerator gpu --devices 0,1 --strategy deepspeed_stage_2')

Error messages and logs

Running the above code, we encounter the following error:

Error(s) in loading state_dict for BoringModel:
    Missing key(s) in state_dict: "model.model.encoder.embed_tokens.weight", "model.model.decoder.embed_tokens.weight", "model.lm_head.weight"

Environment

- PyTorch Lightning version: 1.8.1
- PyTorch version: 1.13.0
- Python version: 3.10.4
- Transformers version: 4.24.0
- DeepSpeed version: 0.7.5
- OS: Linux

More info

cc @awaelchli

yakazimir commented 1 year ago

Any updates on this, or quick workarounds?

kelvins64 commented 1 year ago

Any updates on this, or quick workarounds?

You can try using this code to convert the DeepSpeed checkpoint to a Lightning checkpoint while patching in all parameters that aren't loaded from the DeepSpeed checkpoint.

I can't guarantee that the parameters which aren't processed by the convert_zero_checkpoint_to_fp32_state_dict method are correct, since I don't know the details of DeepSpeed --> fp32 conversion. In practice, though, my checkpoints loaded using this method haven't run into any issues.

import os
import re
import torch
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_model_state_file,
    get_optim_files,
    ds_checkpoint_dir
)

DS_PARAM_REGEX = r'_forward_module\.(.+)'

def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
    '''
    Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching
    in parameters which are improperly loaded by the DeepSpeed conversion utility.
    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, will be
        placed in the same directory as the DeepSpeed checkpoint directory with the same name but
        a .pt extension.
    Returns: path to the converted checkpoint.
    '''
    if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
        raise ValueError(
            'deepspeed_ckpt_path should point to the checkpoint directory'
            ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
        )

    # Convert state dict to PyTorch format
    if not pl_ckpt_path:
        pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt' # .ckpt --> .pt

    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)

    # Patch in missing parameters that failed to be converted by DeepSpeed utility
    pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(pl_ckpt, pl_ckpt_path)

    return pl_ckpt_path

def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
    '''
    Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32_checkpoint
    into the fp32 state dict.
    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    fp32_ckpt_path: Path to the reconstructed fp32 PyTorch Lightning checkpoint.
    '''
    # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt
    ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
    ds_sd = ds_ckpt['module']

    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']

    for k, v in ds_sd.items():
        match = re.match(DS_PARAM_REGEX, k)
        if match is None:
            print(f'Failed to extract parameter name from DeepSpeed key {k}')
            continue
        param_name = match.group(1)

        v = v.to(torch.float32)
        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from the DeepSpeed state dict to the fp32 state dict')
            fp32_sd[param_name] = v
        else:
            # Sanity check: parameters present in both should agree (fp16 -> fp32 rounding tolerance).
            assert torch.allclose(v, fp32_sd[param_name], atol=1e-2)

    return fp32_ckpt
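
A minimal usage sketch (the paths below are hypothetical; point them at your own run):

# Hypothetical checkpoint directory written by the Trainer.
ds_ckpt_dir = 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt'
pl_ckpt_path = convert_deepspeed_checkpoint(ds_ckpt_dir)
model = BoringModel.load_from_checkpoint(pl_ckpt_path, model_name='facebook/mbart-large-50')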

yakazimir commented 1 year ago

thank you @kelvins64 , I will try this out.

francescocarzaniga commented 1 year ago

(Quoting @kelvins64's workaround above.)

I can confirm this does work.

tbright17 commented 1 year ago

Confirm this works. Great work. Thanks

awaelchli commented 1 year ago

Thanks @yakazimir for pointing me to this issue. I looked into it and found that the problem lies in DeepSpeed: when saving a checkpoint, DeepSpeed is not able to identify the shared parameters, and when converting/loading the checkpoint it doesn't reconstruct them properly, which leads to the error about missing keys.

I boiled this down to a reproducible script with DeepSpeed and submitted a ticket and a PR with the fix. If my PR gets merged, the workaround posted here won't be necessary anymore.
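
For context, the keys reported missing above are exactly the weights MBart ties to its shared embedding, which is why the handling of shared parameters matters here. A quick check (a sketch, assuming transformers is installed):

from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')
shared = model.model.shared.weight
# The three "missing" parameters all share storage with the shared embedding.
for tied in (model.model.encoder.embed_tokens.weight,
             model.model.decoder.embed_tokens.weight,
             model.lm_head.weight):
    print(tied.data_ptr() == shared.data_ptr())  # True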

For reference, my investigation was done with DeepSpeed master (0.9.5dev) and Lightning master (2.1.0dev), starting from the script below, which is based on the original submission with minor modifications for the newer API:

import os

import torch
import shutil
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import lightning.pytorch as pl
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

class TextDataset(Dataset):
    def __init__(self, model_name, length):
        self.model_name = model_name
        self.length = length

        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        self.data = self.tokenizer(
            [f'Hello world {i}!' for i in range(length)],
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, index):
        return {
            'input_ids': self.data['input_ids'][index],
            'attention_mask': self.data['attention_mask'][index],
            'labels': self.data['input_ids'][index] # Have the target text be the input text
        }

    def __len__(self):
        return self.length

class BoringModel(LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch):
        return self.model(**batch)[0] # Return loss

    def training_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def run():
    if os.path.exists("lightning_logs"):
        shutil.rmtree("lightning_logs")

    model_name = 'facebook/mbart-large-50'
    pl.seed_everything(42)

    train_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    val_data = DataLoader(TextDataset(model_name, 64), batch_size=2)

    model = BoringModel(model_name)
    trainer = Trainer(
        accelerator="cuda",
        devices=2,
        strategy="deepspeed_stage_2",
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        deterministic=True
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    from pprint import pprint

    # Convert checkpoint
    if trainer.is_global_zero:
        pprint(trainer.strategy.config)

        zero_ckpt_dir = os.path.join(os.getcwd(), 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt')
        ckpt_path = zero_ckpt_dir[:-4] + 'pth'
        convert_zero_checkpoint_to_fp32_state_dict(zero_ckpt_dir, ckpt_path)

        # Attempt to load checkpoint
        model.load_from_checkpoint(ckpt_path, model_name=model_name, strict=True)

if __name__ == "__main__":
    run()

awaelchli commented 1 year ago

Fix was merged in deepspeed: https://github.com/microsoft/DeepSpeed/pull/3825
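
With a DeepSpeed build that includes that PR, the reproduction script above should load the converted checkpoint without missing keys. A minimal sanity check on the converted file (a sketch, reusing ckpt_path from the script above):

import torch

state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
# The previously missing tied weights should now be present.
for key in ('model.model.encoder.embed_tokens.weight',
            'model.model.decoder.embed_tokens.weight',
            'model.lm_head.weight'):
    assert key in state_dict, f'missing {key}'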

luccachiang commented 8 months ago

(Quoting the workaround and confirmation above.)

I had to change DS_PARAM_REGEX = r'_forward_module\.(.+)' to DS_PARAM_REGEX = r'module\.(.+)', but this is great!
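
The key prefix depends on the Lightning/DeepSpeed versions in use, so if the regex does not match, it can help to inspect the raw keys first (a sketch; the path is hypothetical):

import torch

ds = torch.load('path/to/epoch=0-step=1.ckpt/checkpoint/mp_rank_00_model_states.pt', map_location='cpu')
# Shows whether keys start with '_forward_module.', 'module.', or something else.
print(list(ds['module'])[:3])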