Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Models trained with FSDP + Thunder don't work with litgpt chat #895

Open mpatel31415 opened 3 months ago

mpatel31415 commented 3 months ago

I was able to train a Llama-3-8B model with Thunder for a few steps and then save it. However, when I later try to use litgpt generate or litgpt chat with the saved checkpoint, I get an error about a size mismatch. When I run the training in eager mode, everything works.

🐛 Bug

To Reproduce

  1. Please extract this archive, Meta-Llama-3-8B-tuned.zip, and put all the files into a directory of your choice (let's call it CHECKPOINT_DIR). Here is the license.

    These are Llama-3-8B configuration files (no weights); they can also be downloaded by running: litgpt download meta-llama/Meta-Llama-3-8B

  2. Copy the benchmarking script thunder/benchmarks/benchmark_litgpt.py from this repo and add model saving at line 622:

      # These names come from the surrounding benchmark_litgpt.py scope:
      # `torch_dist` is torch.distributed, `benchmark` holds the (FSDP-wrapped)
      # model, and `global_rank` is this process's rank.
      torch_dist.barrier()                    # wait until all ranks finish the last step
      states = benchmark.model.state_dict()   # collect this process's view of the weights
      if global_rank == 0:
          torch.save(states, "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth")

To be sure the version of the script is the same, I'm also attaching the full modified file (it's Python code, but I can only attach .txt files here): benchmark_litgpt.txt

Let's assume it's located in the SCRIPT_DIR directory.

  3. Start a docker container on a node with 8x H100:

    docker run --pull=always --gpus all --ipc=host \
      --ulimit memlock=-1 --ulimit stack=67108864 -it \
      -v ${CHECKPOINT_DIR}:/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned \
      -v ${SCRIPT_DIR}:/repro \
      INTERNAL_IMAGE   # NVIDIA internal container from 20240731
  4. Install a recent litgpt version:

    python -m pip install litgpt==0.4.5

For Eager

5E. Run training in eager mode (on dummy data, so the output won't make sense, but this makes the reproduction instructions easier to run):

 torchrun --standalone --max-restarts=0 --nproc-per-node=8  /repro/benchmark_litgpt.py  --model_name Llama-3-8B --max_iters 10 --warmup_iters 2 --distributed_mode fsdp --shard_mode zero3 --bucketing_mode block

You should see a new file, lit_model.pth, in the checkpoint directory.

6E. Try to chat with the saved model:

litgpt chat /lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned

It should run but return garbage.

For Thunder

5T. You can remove lit_model.pth (it will be overwritten anyway) and then run:

torchrun --standalone --max-restarts=0 --nproc-per-node=8  /repro/benchmark_litgpt.py  --model_name Llama-3-8B --max_iters 10 --warmup_iters 2 --distributed_mode fsdp --shard_mode zero3 --bucketing_mode block --compile thunder

6T. Try to chat with the saved model:

litgpt chat /lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned

There is an error:

{'access_token': None, 'checkpoint_dir': PosixPath('/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned'), 'compile': False, 'max_new_tokens': 50, 'multiline': False, 'precision': None, 'quantize': None, 'temperature': 0.8, 'top_k': 200, 'top_p': 1.0}
Traceback (most recent call last):
  File "/usr/local/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/litgpt/chat/base.py", line 258, in main
    load_checkpoint(fabric, model, checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/litgpt/utils.py", line 362, in load_checkpoint
    model.load_state_dict(state_dict, strict=strict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2542, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GPT:
    size mismatch for lm_head.weight: copying a param with shape torch.Size([16032, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
    size mismatch for transformer.wte.weight: copying a param with shape torch.Size([16032, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
    size mismatch for transformer.h.0.norm_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([4096]).
    ...

Complete output
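Note that the shapes in the error are exactly one eighth of the expected ones (16032 × 8 = 128256 and 512 × 8 = 4096), which is consistent with each rank's FSDP shard being saved instead of the full parameters. A minimal sketch to inspect the saved checkpoint directly (assuming the container paths used above):

    import torch

    # Load the checkpoint written by the Thunder run and print a few parameter shapes.
    ckpt = torch.load(
        "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth",
        map_location="cpu",
    )
    for name in ("lm_head.weight", "transformer.wte.weight", "transformer.h.0.norm_1.weight"):
        print(name, tuple(ckpt[name].shape))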

Expected behavior

We should be able to run a model trained with Thunder using the litgpt commands.

Environment

nvidia-smi output: [screenshot attached in the original issue]

Version of packages:

lightning-thunder    0.2.0.dev0          /opt/pytorch/lightning-thunder
lightning-utilities  0.11.6
litgpt               0.4.5
nvfuser              0.2.8+gitaf62096    /opt/pytorch/nvfuser
pytorch-lightning    2.3.3
torch                2.5.0a0+git83db609
torchmetrics         1.4.0.post0
torchvision          0.19.0a0+d23a6e1

crcrpar commented 3 months ago

https://github.com/Lightning-AI/lightning-thunder/issues/564 could be related

mpatel31415 commented 3 months ago

FYI: I wanted to be sure the checkpoint-saving code is correct in eager mode, so I ran it on each rank and then compared the parameter shapes from the state_dict with those of the original model (lit_model.pth, before it was wrapped with FSDP), and compared the values between ranks (to check that they were synchronized). Both the shapes and the values are equal.
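A sketch of that check (hypothetical helper; it assumes torch.distributed is already initialized, that `states` is the state_dict collected in the saving code above, and that the tensors live on a device supported by the process-group backend):

    import torch
    import torch.distributed as dist

    def check_state_dict(states, original_ckpt_path):
        # Reference shapes come from the original (pre-FSDP) lit_model.pth.
        reference = torch.load(original_ckpt_path, map_location="cpu")
        for name, param in states.items():
            # 1) Shapes must match the original checkpoint.
            assert param.shape == reference[name].shape, (
                f"{name}: {tuple(param.shape)} vs {tuple(reference[name].shape)}"
            )
            # 2) Values must be identical on every rank: broadcast rank 0's
            #    copy and compare it with the local one.
            from_rank0 = param.detach().clone()
            dist.broadcast(from_rank0, src=0)
            assert torch.equal(param, from_rank0), f"{name} differs between ranks"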

mpatel31415 commented 3 months ago

Small update after discussion with @carmocca about saving checkpoints from Thunder FSDP:

I tried to use the save and get_model_state_dict functions provided by Thunder and then convert the distributed checkpoint into a torch.save checkpoint using dcp_to_torch_save, but I also get a shape error when later trying to use the output with litgpt chat.

Below is the code I used (it should be possible to use it in place of the code provided in the original description):

from thunder.distributed.checkpoint import save, get_model_state_dict, StateDictOptions
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Collect this rank's (sharded) model state dict.
options = StateDictOptions(full_state_dict=False, cpu_offload=False)
state_dict = get_model_state_dict(model, options, rank)

# Write a distributed (DCP) checkpoint, then convert it into a single
# torch.save file on rank 0.
dcp_path = "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/distributed_ckp"
save(state_dict, dcp_path)
torch_dist.barrier()
if rank == 0:
    dcp_to_torch_save(dcp_path, "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth")

The only option that could make it work now is to train the model with Fabric FSDP, but I haven't tested it yet.
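For comparison, the equivalent round trip with the plain PyTorch distributed checkpoint APIs (a sketch only; it uses torch.distributed.checkpoint from the torch version listed in the environment above, and `model`, `dcp_path`, and the output path are the same as in the snippet above):

    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # Collect each rank's sharded state dict and write a DCP checkpoint.
    options = StateDictOptions(full_state_dict=False, cpu_offload=False)
    state_dict = get_model_state_dict(model, options=options)
    dcp.save(state_dict, checkpoint_id=dcp_path)
    dist.barrier()

    # Convert the DCP checkpoint into a single torch.save file on rank 0.
    if dist.get_rank() == 0:
        dcp_to_torch_save(dcp_path, "lit_model.pth")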

tfogal commented 3 months ago

triage review:

mpatel31415 commented 2 months ago

Hi! Is there any update on this? From the Slack discussion and my understanding, there were 3 options for me to make progress:

  1. Save a distributed checkpoint in Thunder, convert it to a "torch save" checkpoint using the PyTorch function (this should be possible because the distributed Thunder checkpoint is expected to be equivalent to a distributed PyTorch checkpoint), and use the resulting "torch save" checkpoint with litgpt chat. This option has also failed. Should I wait for it to be resolved?
  2. Train with Fabric, and everything should work. But do we expect that Thunder will work with litgpt only when Fabric is used? If so, or if this is the best solution for now, I can change the demo code to use Fabric.
  3. Write my own "chat" script that loads the distributed Thunder checkpoint.

Please let me know which direction is the best to follow from your perspective.