Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with :zap: Pytorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

HFSaveCheckpoint does not work with deepspeed #273

Open jessecambon opened 1 year ago

jessecambon commented 1 year ago

🐛 Bug

HFSaveCheckpoint does not save a Hugging Face checkpoint when the model is trained with DeepSpeed, and no message or warning indicates that the HF checkpoint was not saved.

To Reproduce

Pass the HFSaveCheckpoint plugin to the Trainer when training with DeepSpeed. I encountered this in both a multi-node (Azure) and a single-node (local) environment.

Code sample

import os

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoTokenizer

from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint
from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

model_arch = "prajjwal1/bert-tiny"

seed_everything(102938, workers=True)
tokenizer = AutoTokenizer.from_pretrained(model_arch)

dm = TextClassificationDataModule(
    tokenizer=tokenizer,
    train_val_split=0.01,  # Split used for creating a validation set out of the training set
    dataset_name="glue",
    dataset_config_name="cola",
    batch_size=1,
    num_workers=os.cpu_count(),
    padding="max_length",
    truncation=True,
    max_length=512,
)

model = TextClassificationTransformer(pretrained_model_name_or_path=model_arch)

# Keep only the best checkpoint according to validation loss
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")

trainer = Trainer(
    accelerator="gpu",
    plugins=HFSaveCheckpoint(model=model),  # expected to also write an HF checkpoint
    callbacks=[checkpoint_callback],
    logger=False,
    enable_checkpointing=True,
    log_every_n_steps=10,
    limit_train_batches=30,
    limit_val_batches=10,
    max_epochs=2,
    strategy="deepspeed_stage_3",
)

trainer.fit(model, dm)

print(f"Best model path: {checkpoint_callback.best_model_path}")

Expected behavior

Either the HF checkpoint is saved properly, or a warning is raised when it cannot be.
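A minimal sketch of the kind of warning the expected behavior asks for. The helper name `warn_if_sharded` is hypothetical (it is not part of lightning-transformers), and it assumes that with a DeepSpeed stage 3 strategy Lightning writes each checkpoint as a directory of shards rather than a single file.

```python
import os
import warnings


def warn_if_sharded(checkpoint_path: str) -> bool:
    """Warn and return True if checkpoint_path looks like a sharded (directory) checkpoint.

    Hypothetical helper: assumes a regular Lightning checkpoint is a single file,
    while a DeepSpeed stage 3 checkpoint is a directory containing shard files.
    """
    if os.path.isdir(checkpoint_path):
        warnings.warn(
            f"{checkpoint_path} is a directory of DeepSpeed shards, not a single checkpoint "
            "file, so no Hugging Face checkpoint was exported. Consolidate it first, e.g. "
            "with convert_zero_checkpoint_to_fp32_state_dict.",
            UserWarning,
        )
        return True
    return False
```

In the repro script above, it could be called as `warn_if_sharded(checkpoint_callback.best_model_path)` right after `trainer.fit(model, dm)`.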

Environment

Lightning transformers 0.2.1

* CUDA:
        - GPU:
                - Quadro T2000 with Max-Q Design
        - available:         True
        - version:           11.3
* Packages:
        - numpy:             1.22.3
        - pyTorch_debug:     False
        - pyTorch_version:   1.11.0
        - pytorch-lightning: 1.6.4
        - tqdm:              4.64.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.13
        - version:           #44-Ubuntu SMP Wed Jun 22 14:20:53 UTC 2022

SeanNaren commented 1 year ago

Thanks for the report!

This is tricky because, with DeepSpeed stage 3, checkpoints are saved as shards rather than as a single checkpoint.

One thing you might benefit from is saving normally with Lightning, then using https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/deepspeed.py#L52 to combine the shards into a single checkpoint. After that, load the combined checkpoint into the TextClassificationTransformer and call the save_hf_checkpoint function.

It would be nice to have an automated system to do this, but I'm worried that adding too much automation will cause overhead. Let me know if this solution works for you in the meantime!

jessecambon commented 1 year ago

@SeanNaren, thanks, that solution worked nicely. I added this code to the end of the script in the original post:

from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

aggregated_checkpoint_path = "checkpoints/aggregated"
hf_checkpoint_path = "checkpoints/hf_checkpoint"

# Convert sharded checkpoint files into a single checkpoint file
# https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/deepspeed.py#L52
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_callback.best_model_path,
    aggregated_checkpoint_path,
)

# Load best model from aggregated checkpoint file
best_model = TextClassificationTransformer.load_from_checkpoint(aggregated_checkpoint_path)

# Save model and tokenizer to HF checkpoint
best_model.model.save_pretrained(hf_checkpoint_path)
tokenizer.save_pretrained(hf_checkpoint_path)

jessecambon commented 1 year ago

@SeanNaren, one follow-up on this: although the code above worked in a local environment (1 GPU on 1 node), I run into this error when convert_zero_checkpoint_to_fp32_state_dict runs on a 2-node setup with 2 GPUs per node on Azure:

ValueError: Expected 4 of '*_optim_states.pt' under <path/to/dir> but found 2 files. Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes.
total 36

It looks like it is only seeing the checkpoint files from one of the two nodes. I tried inserting trainer.strategy.barrier() immediately after trainer.fit() and only running the checkpoint conversion code on the root process (trainer.is_global_zero == True), but this didn't fix the issue.
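The numbers in the error follow from the world size: 2 nodes × 2 GPUs per node = 4 ranks, and DeepSpeed expects one `*_optim_states.pt` file per rank, so finding 2 suggests only one node's shards are visible. A minimal pre-check sketch that counts the shard files before attempting the conversion; the helper names are hypothetical, and the recursive search is there because the exact subdirectory layout Lightning uses is an assumption here.

```python
import glob
import os


def count_optim_shards(checkpoint_dir: str) -> int:
    """Count '*_optim_states.pt' files anywhere under checkpoint_dir."""
    # '**' with recursive=True also matches zero directories, so files placed
    # directly in checkpoint_dir or in a nested tag subdirectory are both found.
    pattern = os.path.join(checkpoint_dir, "**", "*_optim_states.pt")
    return len(glob.glob(pattern, recursive=True))


def check_all_shards_present(checkpoint_dir: str, num_nodes: int, gpus_per_node: int) -> int:
    """Raise if fewer/more optimizer shards are present than there were ranks."""
    expected = num_nodes * gpus_per_node  # one optimizer shard per rank
    found = count_optim_shards(checkpoint_dir)
    if found != expected:
        raise FileNotFoundError(
            f"expected {expected} '*_optim_states.pt' files under {checkpoint_dir} "
            f"but found {found}; shards from some nodes are probably missing"
        )
    return found
```

For the Azure setup described here, `check_all_shards_present(path, num_nodes=2, gpus_per_node=2)` expects 4 files, matching the number in the error message.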

Is there a way to run the checkpoint conversion function in a multi-node environment, or should I run it in a separate script/process?

jessecambon commented 1 year ago

There is probably a more elegant solution, but what ended up working for me was to have the root process of each node upload the files to Azure immediately after trainer.fit():

# Run one upload process per node to make sure we capture all the sharded checkpoint files
# (`datastore` and `checkpoints_upload_path` are defined elsewhere in the script, not shown)
if os.environ.get("LOCAL_RANK") == "0":
    print(f"Uploading best model checkpoints (sharded) to {checkpoints_upload_path}")
    datastore.upload(
        src_dir=checkpoint_callback.best_model_path,
        target_path=checkpoints_upload_path,
        overwrite=True,
    )

In a separate process I can then download all these sharded checkpoint files from Azure and run the code above to convert them to a single checkpoint file.
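The per-node guard in this workaround hinges on local versus global rank: `trainer.is_global_zero` is true on exactly one process in the whole job, whereas `LOCAL_RANK == 0` holds on one process per node, which is presumably why only the latter reached every node's shards. A minimal sketch of that guard as a reusable helper; the name `is_node_leader` is hypothetical, and it assumes the launcher exports `LOCAL_RANK` (torchrun does, for example). Unlike the string comparison above, it treats a missing variable as local rank 0, a design choice suited to single-process runs rather than documented Lightning behavior.

```python
import os
from typing import Mapping, Optional


def is_node_leader(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True on the first process of each node (local rank 0).

    Assumption: the launcher sets LOCAL_RANK; if it is missing, the process is
    treated as the only one on its node.
    """
    env = os.environ if env is None else env
    return int(env.get("LOCAL_RANK", "0")) == 0
```

With this helper, the upload guard would read `if is_node_leader():`.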

omerarshad commented 1 year ago

I'm getting this error while loading a language model:

RuntimeError: Error(s) in loading state_dict for LanguageModelingTransformer:
    Missing key(s) in state_dict: "model.lm_head.weight". 

Any leads?