Closed kelvins64 closed 1 year ago
Any updates on this, or quick workarounds?
Any updates on this, or quick workarounds?
You can try using this code to convert the DeepSpeed checkpoint to a Lightning checkpoint while patching in all parameters that aren't loaded from the DeepSpeed checkpoint.
I can't guarantee that the parameters which aren't processed by the `convert_zero_checkpoint_to_fp32_state_dict`
method are correct, since I don't know the details of the DeepSpeed → fp32 conversion. In practice, though, the checkpoints I have loaded using this method haven't run into any issues.
import os
import re
from typing import Optional

import torch
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_model_state_file,
    get_optim_files,
    ds_checkpoint_dir,
)
# Extracts the Lightning state_dict parameter name from a key of the DeepSpeed "module"
# state dict. Depending on the Lightning/DeepSpeed versions in use, keys carry either a
# "_forward_module." or a "module." prefix; both are accepted, and group(1) is the name
# with the prefix removed. Used with re.match, so the prefix must be at the start of the key.
DS_PARAM_REGEX = r'(?:_forward_)?module\.(.+)'
def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: Optional[str] = None) -> str:
    """Create a PyTorch Lightning checkpoint from a DeepSpeed checkpoint directory.

    The directory is first converted with ``convert_zero_checkpoint_to_fp32_state_dict``.
    Parameters that the DeepSpeed conversion utility fails to reconstruct are then patched in
    from the raw DeepSpeed model state by ``_merge_deepspeed_weights``.

    Args:
        deepspeed_ckpt_path: Path to the checkpoint directory written by DeepSpeed
            (e.g. ``"last.ckpt"`` or ``"epoch=4-step=39150.ckpt"``). A trailing path
            separator is tolerated.
        pl_ckpt_path: Destination of the reconstructed Lightning checkpoint. If not given, the
            checkpoint is written next to the DeepSpeed directory with the same name but a
            ``.pt`` extension. If a file already exists at this path, the DeepSpeed conversion
            is skipped and that file is patched and overwritten.

    Returns:
        Path to the converted checkpoint.

    Raises:
        ValueError: If ``deepspeed_ckpt_path`` is not an existing directory whose name ends
            in ``.ckpt``.
    """
    # Normalise so that e.g. "last.ckpt/" still passes the ".ckpt" suffix check below.
    deepspeed_ckpt_path = os.path.normpath(deepspeed_ckpt_path)
    if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
        raise ValueError(
            'deepspeed_ckpt_path should point to the checkpoint directory output by DeepSpeed'
            f' (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt"), got {deepspeed_ckpt_path!r}.'
        )

    if not pl_ckpt_path:
        # "<name>.ckpt" -> "<name>.pt", placed next to the DeepSpeed checkpoint directory.
        pl_ckpt_path = os.path.splitext(deepspeed_ckpt_path)[0] + '.pt'

    # Convert the ZeRO shards to a single fp32 state dict, unless a previous run already did.
    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)

    # Patch in parameters the DeepSpeed conversion utility failed to reconstruct.
    pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(pl_ckpt, pl_ckpt_path)
    return pl_ckpt_path
def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str, atol: float = 1e-2) -> dict:
    """Merge tensors present in the DeepSpeed checkpoint but missing from the fp32 checkpoint.

    Each key of the DeepSpeed ``module`` state dict is mapped to a Lightning ``state_dict`` name
    with ``DS_PARAM_REGEX``, and each value is cast to float32. Names missing from the fp32
    ``state_dict`` are added. Names present in both are checked to agree within ``atol``.

    Args:
        deepspeed_ckpt_path: Path to the DeepSpeed checkpoint directory.
        fp32_ckpt_path: Path to the fp32 checkpoint produced by
            ``convert_zero_checkpoint_to_fp32_state_dict``.
        atol: Absolute tolerance passed to ``torch.allclose`` for names present in both
            checkpoints. The loose default presumably accounts for the DeepSpeed module
            weights being stored in reduced precision — confirm for your setup.

    Returns:
        The loaded fp32 checkpoint dict, with its ``state_dict`` patched in place.

    Raises:
        ValueError: If a tensor present in both checkpoints differs in shape or beyond ``atol``.
    """
    # Locate the DeepSpeed model-state file the same way
    # pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict does:
    # the file name depends on the ZeRO stage recorded in the first optimizer shard.
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    # Only the ZeRO stage is needed from the optimizer shard; release it before loading more.
    del optim_state
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    ds_sd = torch.load(deepspeed_model_file, map_location='cpu')['module']
    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']

    for ds_key, tensor in ds_sd.items():
        match = re.match(DS_PARAM_REGEX, ds_key)
        if match is None:
            # The key does not match DS_PARAM_REGEX, so it cannot be mapped to a state_dict name.
            print(f'Failed to extract parameter from DeepSpeed key {ds_key}')
            continue
        param_name = match.group(1)
        tensor = tensor.to(torch.float32)

        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
            fp32_sd[param_name] = tensor
            continue

        # Present in both checkpoints: verify they agree. An explicit exception (rather than
        # `assert`) keeps the check active under `python -O`, and the shape test avoids
        # allclose silently broadcasting or failing with an opaque error.
        existing = fp32_sd[param_name]
        if existing.shape != tensor.shape or not torch.allclose(
            tensor, existing.to(torch.float32), atol=atol
        ):
            raise ValueError(
                f'Parameter {param_name} differs between the DeepSpeed checkpoint and '
                f'{fp32_ckpt_path} (shapes {tuple(tensor.shape)} vs {tuple(existing.shape)}, '
                f'atol={atol}).'
            )
    return fp32_ckpt
Thank you @kelvins64, I will try this out.
Any updates on this, or quick workarounds?
You can try using this code to convert the DeepSpeed checkpoint to a Lightning checkpoint while patching in all parameters that aren't loaded from the DeepSpeed checkpoint.
I can't guarantee that the parameters which aren't processed by the
`convert_zero_checkpoint_to_fp32_state_dict`
method are correct, since I don't know the details of DeepSpeed --> fp32 conversion. In practice, though, my checkpoints loaded using this method haven't run into any issues.import os import torch from pytorch_lightning.utilities.deepspeed import ( convert_zero_checkpoint_to_fp32_state_dict, get_model_state_file, get_optim_files, ds_checkpoint_dir ) DS_PARAM_REGEX = r'_forward_module\.(.+)' def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None): ''' Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching in parameters which are improperly loaded by the DeepSpeed conversion utility. deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder. pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, will be placed in the same directory as the DeepSpeed checkpoint directory with the same name but a .pt extension. Returns: path to the converted checkpoint. ''' if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)): raise ValueError( 'args.ckpt_dir should point to the checkpoint directory' ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").' ) # Convert state dict to PyTorch format if not pl_ckpt_path: pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt' # .ckpt --> .pt if not os.path.exists(pl_ckpt_path): convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path) # Patch in missing parameters that failed to be converted by DeepSpeed utility pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path) torch.save(pl_ckpt, pl_ckpt_path) return pl_ckpt_path def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str): ''' Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32_checkpoint into the fp32 state dict. deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder. 
fp32_ckpt_path: Path to the reconstructed ''' # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path) optim_files = get_optim_files(checkpoint_dir) optim_state = torch.load(optim_files[0], map_location='cpu') zero_stage = optim_state["optimizer_state_dict"]["zero_stage"] deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage) # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu') ds_sd = ds_ckpt['module'] fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu') fp32_sd = fp32_ckpt['state_dict'] for k, v in ds_sd.items(): try: match = re.match(DS_PARAM_REGEX, k) param_name = match.group(1) except: print(f'Failed to extract parameter from DeepSpeed key {k}') continue v = v.to(torch.float32) if param_name not in fp32_sd: print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd') fp32_sd[param_name] = v else: assert torch.allclose(v, fp32_sd[param_name], atol=1e-2) return fp32_ckpt
# NOTE(review): the two lines above are a quoted copy of the converter script, collapsed onto
# single lines by the page scrape; it is not runnable as-is. The quoted code calls re.match
# without importing `re`, and its bare `except:` also swallows KeyboardInterrupt/SystemExit.
I can confirm this does work (though it's missing an `import re` at the beginning).
Confirmed, this works. Great work, thanks!
Thanks @yakazimir for pointing me to this issue. I looked into it and found that the problem lies in DeepSpeed. When saving a checkpoint, DeepSpeed does not identify the shared parameters, and when converting/loading the checkpoint it does not reconstruct them properly, which leads to the missing-keys error.
I boiled this down to a reproducible script with DeepSpeed and submitted a ticket and a PR with the fix. If my PR gets merged, the workaround posted here won't be necessary anymore.
For reference, my investigation was done with DeepSpeed master (0.9.5dev) and Lightning master (2.1.0dev) starting from this script based on the original submission but with minor modifications to fit the newer API:
import os
import torch
import shutil
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
class TextDataset(Dataset):
    """Toy dataset of ``length`` tokenized sentences of the form ``"Hello world {i}!"``.

    Every item is a dict with ``input_ids``, ``attention_mask`` and ``labels``. The labels are
    the input ids themselves, so the model is asked to reproduce its input text.
    """

    def __init__(self, model_name, length):
        self.model_name = model_name
        self.length = length
        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        sentences = [f'Hello world {n}!' for n in range(length)]
        # padding='longest' gives every row the same length, so items can be batched.
        self.data = self.tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        encoded = self.data
        sample = dict(
            input_ids=encoded['input_ids'][index],
            attention_mask=encoded['attention_mask'][index],
        )
        # The target text is the input text.
        sample['labels'] = encoded['input_ids'][index]
        return sample
class BoringModel(LightningModule):
    """Minimal LightningModule around a pretrained MBart model; ``forward`` returns the loss."""

    def __init__(self, model_name):
        super().__init__()
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch):
        outputs = self.model(**batch)
        # The batches built by TextDataset include `labels`, so the first output is the loss.
        return outputs[0]

    def training_step(self, batch, batch_idx):
        train_loss = self(batch)
        self.log("train_loss", train_loss)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        self.log("valid_loss", self(batch))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
def run():
    """Train MBart for one step with DeepSpeed ZeRO stage 2, convert the checkpoint, and reload it.

    The reload uses ``strict=True``, so it fails if the converted checkpoint is missing keys.
    That missing-keys failure is the bug this script reproduces.

    Returns:
        The reloaded model on the global-zero process, ``None`` on the other processes.
    """
    from pprint import pprint

    # Start from a clean log directory so each run writes a fresh checkpoint.
    if os.path.exists("lightning_logs"):
        shutil.rmtree("lightning_logs")

    model_name = 'facebook/mbart-large-50'
    pl.seed_everything(42)
    train_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    val_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    model = BoringModel(model_name)
    trainer = Trainer(
        accelerator="cuda",
        devices=2,
        strategy="deepspeed_stage_2",
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        deterministic=True,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # Convert and reload only once, on the main process.
    if not trainer.is_global_zero:
        return None
    pprint(trainer.strategy.config)
    # Ask the checkpoint callback where the (directory-shaped) DeepSpeed checkpoint was saved,
    # instead of hard-coding "lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt".
    zero_ckpt_dir = trainer.checkpoint_callback.best_model_path
    ckpt_path = zero_ckpt_dir[:-4] + 'pth'  # "<name>.ckpt" -> "<name>.pth"
    convert_zero_checkpoint_to_fp32_state_dict(zero_ckpt_dir, ckpt_path)
    # load_from_checkpoint is a classmethod that returns a new model rather than loading into
    # an existing instance: call it on the class and keep the result.
    return BoringModel.load_from_checkpoint(ckpt_path, model_name=model_name, strict=True)


if __name__ == "__main__":
    run()
The fix was merged in DeepSpeed: https://github.com/microsoft/DeepSpeed/pull/3825
Any updates on this, or quick workarounds?
You can try using this code to convert the DeepSpeed checkpoint to a Lightning checkpoint while patching in all parameters that aren't loaded from the DeepSpeed checkpoint. I can't guarantee that the parameters which aren't processed by the
`convert_zero_checkpoint_to_fp32_state_dict`
method are correct, since I don't know the details of DeepSpeed --> fp32 conversion. In practice, though, my checkpoints loaded using this method haven't run into any issues.import os import torch from pytorch_lightning.utilities.deepspeed import ( convert_zero_checkpoint_to_fp32_state_dict, get_model_state_file, get_optim_files, ds_checkpoint_dir ) DS_PARAM_REGEX = r'_forward_module\.(.+)' def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None): ''' Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching in parameters which are improperly loaded by the DeepSpeed conversion utility. deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder. pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, will be placed in the same directory as the DeepSpeed checkpoint directory with the same name but a .pt extension. Returns: path to the converted checkpoint. ''' if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)): raise ValueError( 'args.ckpt_dir should point to the checkpoint directory' ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").' ) # Convert state dict to PyTorch format if not pl_ckpt_path: pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt' # .ckpt --> .pt if not os.path.exists(pl_ckpt_path): convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path) # Patch in missing parameters that failed to be converted by DeepSpeed utility pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path) torch.save(pl_ckpt, pl_ckpt_path) return pl_ckpt_path def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str): ''' Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32_checkpoint into the fp32 state dict. deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder. 
fp32_ckpt_path: Path to the reconstructed ''' # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path) optim_files = get_optim_files(checkpoint_dir) optim_state = torch.load(optim_files[0], map_location='cpu') zero_stage = optim_state["optimizer_state_dict"]["zero_stage"] deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage) # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu') ds_sd = ds_ckpt['module'] fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu') fp32_sd = fp32_ckpt['state_dict'] for k, v in ds_sd.items(): try: match = re.match(DS_PARAM_REGEX, k) param_name = match.group(1) except: print(f'Failed to extract parameter from DeepSpeed key {k}') continue v = v.to(torch.float32) if param_name not in fp32_sd: print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd') fp32_sd[param_name] = v else: assert torch.allclose(v, fp32_sd[param_name], atol=1e-2) return fp32_ckpt
# NOTE(review): the two lines above are a quoted copy of the converter script, collapsed onto
# single lines by the page scrape; it is not runnable as-is. The quoted code calls re.match
# without importing `re`, and its bare `except:` also swallows KeyboardInterrupt/SystemExit.
I can confirm this does work (though it's missing an `import re` at the beginning).
I had to change `DS_PARAM_REGEX = r'_forward_module\.(.+)'` to `DS_PARAM_REGEX = r'module\.(.+)'`, but this is great!
Bug description
Using DeepSpeed ZeRO stage 2 with certain models fails to properly save and reload the model checkpoint after conversion to the Lightning format.
In the provided example, several parameters do not appear in the `param_shapes` value of the ZeRO checkpoint (which the generated reconstruction script uses to build the state dict), despite appearing in the `module` value of the ZeRO checkpoint.

How to reproduce the bug
Error messages and logs
Running the above code, we encounter the error message
Environment
More info
cc @awaelchli