Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4-bit quantization, LoRA and LLaMA-Adapter fine-tuning, and pre-training. Apache 2.0-licensed.
Apache License 2.0
5.99k stars, 520 forks

How to use the DeepSpeed ZeRO-3 offload strategy correctly? (Parameter Duplication Issue) #84

[Open] KzZheng opened this issue 1 year ago

KzZheng commented 1 year ago

Hi, I'm wondering how to correctly write the code for the DeepSpeed ZeRO-3 offload strategy. Currently, my code looks like this:

import lightning as L
from lightning.fabric.strategies import DeepSpeedStrategy

# ZeRO stage 3, offloading both optimizer states and parameters to CPU
deep_speed = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
)
fabric = L.Fabric(accelerator="gpu", devices=num_devices, strategy=deep_speed)

However, the parameters seem to be duplicated on every GPU. I attached a screenshot of GPU utilization after model, optimizer = fabric.setup(model, optimizer):

[Screenshot: GPU utilization after fabric.setup()]

As I understand it, the parameters should be partitioned across the devices, right?
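For a sense of scale, here is back-of-the-envelope arithmetic for the parameters alone, assuming a 7B-parameter model stored in 16-bit precision on 8 GPUs (illustrative numbers, not measurements from the screenshot). It ignores gradients, optimizer states, activations, and CPU offloading; with offload_parameters=True the partitions would live in CPU memory rather than on the GPU. The helper name param_memory_gb is hypothetical.

```python
def param_memory_gb(num_params: int, bytes_per_param: int, world_size: int, sharded: bool) -> float:
    """Approximate per-GPU memory (GB) taken by model parameters only.

    Replicated: every GPU holds a full copy.
    Sharded (ZeRO-3 style partitioning): each GPU holds about 1/world_size of it.
    """
    total_bytes = num_params * bytes_per_param
    per_gpu_bytes = total_bytes / world_size if sharded else total_bytes
    return per_gpu_bytes / 1e9


if __name__ == "__main__":
    n = 7_000_000_000  # ~7B parameters (illustrative)
    print(f"replicated: {param_memory_gb(n, 2, 8, sharded=False):.2f} GB per GPU")  # 14.00
    print(f"sharded:    {param_memory_gb(n, 2, 8, sharded=True):.2f} GB per GPU")   # 1.75
```

If every GPU shows roughly the full-model footprint, that points to replication rather than partitioning.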

awaelchli commented 1 year ago

For ZeRO-3 with DeepSpeed, you should wrap the model initialization in this context manager:

with fabric.sharded_model():
    model = ...

Perhaps you forgot this?

KzZheng commented 1 year ago

Thanks for your reply! Since I'm a beginner with Fabric and DeepSpeed, I'm not sure how to add this context manager correctly. Taking lit-llama as an example, should I write it like this?

import torch
from lit_llama.lora import lora, mark_only_lora_as_trainable
from lit_llama.model import LLaMA

with fabric.sharded_model():
    with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA(config)

    checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

    # strict=False because the LoRA weights are not contained in the checkpoint state
    model.load_state_dict(checkpoint, strict=False)
    mark_only_lora_as_trainable(model)

I tried this, but I got an error when loading the state dict:

[Screenshot: error raised by load_state_dict()]

I also tried moving load_state_dict() outside of the fabric.sharded_model() block, but the error is the same.

Can you provide me with some hints or code references? Thanks!

awaelchli commented 1 year ago

Hmm, yes, I see. Loading the checkpoint into a DeepSpeed-sharded model needs a bit more work. Ideally we would use fabric.load() here, but that requires the checkpoint to be a DeepSpeed checkpoint. I need to think about how we could detect and properly load that.
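A rough sketch of such a check, assuming a DeepSpeed checkpoint is a directory holding per-rank files named like the zero_pp_rank_*_mp_rank_00_model_states.pt files quoted later in this thread. The function name and the heuristic are hypothetical; this is not how Fabric actually detects DeepSpeed checkpoints.

```python
import re
import tempfile
from pathlib import Path
from typing import Union

# Hypothetical heuristic: per-rank ZeRO shard files such as
# "zero_pp_rank_0_mp_rank_00_model_states.pt"
SHARD_RE = re.compile(r"zero_pp_rank_\d+_mp_rank_\d+_model_states\.pt")


def looks_like_deepspeed_checkpoint(path: Union[str, Path]) -> bool:
    """True if `path` is a directory containing at least one ZeRO shard file."""
    p = Path(path)
    if not p.is_dir():
        # A single file such as state_dict.pth is treated as a plain PyTorch checkpoint
        return False
    # rglob also searches subdirectories (e.g. a per-step tag folder)
    return any(SHARD_RE.fullmatch(f.name) for f in p.rglob("*.pt"))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        tag_dir = Path(d) / "step_100"  # hypothetical tag folder name
        tag_dir.mkdir()
        (tag_dir / "zero_pp_rank_0_mp_rank_00_model_states.pt").touch()
        print(looks_like_deepspeed_checkpoint(d))  # True
        print(looks_like_deepspeed_checkpoint(Path(d) / "state_dict.pth"))  # False
```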

timothylimyl commented 1 year ago

I'm facing the same issue with LoRA and DeepSpeed: a bunch of size mismatch errors.
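A rough intuition for the size mismatch errors: with ZeRO-3, each rank keeps only a slice of every flattened parameter, so a full-size tensor from the original lit-llama checkpoint does not match what the local module holds. The sketch below is a simplified model of that split (equal chunks via ceiling division), not DeepSpeed's actual partitioning code, which also handles padding and alignment; the exact shapes reported in the error messages may differ. shard_range and the 4096 x 4096 example are illustrative.

```python
from typing import Tuple


def shard_range(numel: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Half-open [start, end) range of a flattened parameter owned by `rank`.

    Simplified model: chunks of ceil(numel / world_size) elements; trailing
    chunks are truncated here rather than padded.
    """
    chunk = (numel + world_size - 1) // world_size  # integer ceiling division
    start = min(rank * chunk, numel)
    end = min(start + chunk, numel)
    return start, end


if __name__ == "__main__":
    numel = 4096 * 4096  # one 4096 x 4096 weight matrix (illustrative)
    for rank in range(8):
        start, end = shard_range(numel, 8, rank)
        print(rank, end - start)  # 2097152 elements per rank, not 16777216
```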

HeorhiiS commented 1 year ago

Facing the same issue. Should the existing LLaMA checkpoint first be converted to a DeepSpeed checkpoint?

alexgshaw commented 1 year ago

Any updates on this?

scvance commented 1 year ago

I got the model to run by first converting the weights to DeepSpeed checkpoints and then loading the model from those checkpoints.

I set up the DeepSpeed strategy as follows:

from lightning.fabric.strategies import DeepSpeedStrategy
deep_off = DeepSpeedStrategy(config="deep_config.json")

This was the config I used (deep_config.json):

{
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
} 
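If you would rather generate the config than hand-edit JSON, here is a minimal sketch with the standard json module that writes the same settings as deep_config.json above. DEEP_CONFIG and write_config are hypothetical names. As far as I know, Lightning's DeepSpeedStrategy also accepts a dict directly via config=, which would make the file optional, but check the docs for your Lightning version.

```python
import json
import tempfile
from pathlib import Path

# Same settings as deep_config.json above, as a Python dict
DEEP_CONFIG = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}


def write_config(path: Path) -> Path:
    """Serialize DEEP_CONFIG as JSON at `path` and return the path."""
    path.write_text(json.dumps(DEEP_CONFIG, indent=4))
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        cfg_path = write_config(Path(d) / "deep_config.json")
        loaded = json.loads(cfg_path.read_text())
        print(loaded["zero_optimization"]["stage"])  # 3
```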

I then created the Fabric object as follows:

import lightning as L

fabric = L.Fabric(
    accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=deep_off
)

Then I loaded the checkpoints as follows:

import re

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam
from lit_llama.model import LLaMA

# One DeepSpeed ZeRO shard file per rank (8 GPUs here)
checkpoint_paths = [
    "zero_pp_rank_0_mp_rank_00_model_states.pt",
    "zero_pp_rank_1_mp_rank_00_model_states.pt",
    "zero_pp_rank_2_mp_rank_00_model_states.pt",
    "zero_pp_rank_3_mp_rank_00_model_states.pt",
    "zero_pp_rank_4_mp_rank_00_model_states.pt",
    "zero_pp_rank_5_mp_rank_00_model_states.pt",
    "zero_pp_rank_6_mp_rank_00_model_states.pt",
    "zero_pp_rank_7_mp_rank_00_model_states.pt",
]
merged_checkpoint = {}
for checkpoint_path in checkpoint_paths:
    match = re.search(r"rank_(\d+)", checkpoint_path)
    rank_num = int(match.group(1))
    # Each process only loads the shard file that matches its own global rank
    if fabric.global_rank == rank_num:
        checkpoint = torch.load(checkpoint_path)
        # Drop entries that were saved as None
        checkpoint = {k: v for k, v in checkpoint.items() if v is not None}
        for key, value in checkpoint.items():
            if key not in merged_checkpoint:
                merged_checkpoint[key] = value
            else:
                try:
                    merged_checkpoint[key] += value
                except TypeError:
                    merged_checkpoint[key].update(value)
checkpoint = merged_checkpoint

# with fabric.device:
with fabric.init_module():
    # Build with float16 as the default dtype, cast to bfloat16, then restore the float32 default
    torch.set_default_tensor_type(torch.HalfTensor)
    model = LLaMA(config).bfloat16()
    torch.set_default_tensor_type(torch.FloatTensor)
    # strict=False: missing and unexpected keys are ignored (shape mismatches still raise)
    model.load_state_dict(checkpoint, strict=False)

optimizer = DeepSpeedCPUAdam(model.parameters(), lr=learning_rate)
model, optimizer = fabric.setup(model, optimizer)
train(fabric, model, optimizer, train_data, val_data, out_dir)
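Because the loop only calls torch.load when rank_num equals fabric.global_rank, each process reads exactly one file. Keys within that single file are unique, so the += / update() merge branch never runs, and merged_checkpoint is just that file with its None entries dropped. The plain-Python sketch below isolates that rank-to-file mapping, assuming 8 ranks as in the list above; shard_path_for_rank and rank_from_path are hypothetical helper names. Under the same assumption, the loop could likely be replaced by a single torch.load(shard_path_for_rank(fabric.global_rank)).

```python
import re
from typing import Optional

# File-name pattern from the checkpoint_paths list above
TEMPLATE = "zero_pp_rank_{rank}_mp_rank_00_model_states.pt"


def shard_path_for_rank(rank: int) -> str:
    """The one shard file that a process with this global rank ends up loading."""
    return TEMPLATE.format(rank=rank)


def rank_from_path(path: str) -> Optional[int]:
    """Same parsing as the loop above: the first 'rank_<digits>' in the name."""
    match = re.search(r"rank_(\d+)", path)
    return int(match.group(1)) if match else None


if __name__ == "__main__":
    paths = [shard_path_for_rank(r) for r in range(8)]
    for global_rank in range(8):
        # Mirrors `if fabric.global_rank == rank_num` in the original loop
        selected = [p for p in paths if rank_from_path(p) == global_rank]
        print(global_rank, selected)  # exactly one file per rank
```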

You also have to comment out the following line in the train function (and dedent the block under it), because it doesn't work with DeepSpeed:

# with fabric.no_backward_sync(model, enabled=is_accumulating):

This should work, but I'm sure there is a better way to do it.
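With that line commented out, the accumulation schedule itself is unchanged; only the hint to skip gradient synchronization on accumulating iterations goes away. A small sketch of that schedule, assuming the loop computes is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0 and calls optimizer.step() only when it is False (an assumption; the train function isn't shown in this thread). optimizer_step_iters is a hypothetical helper.

```python
from typing import List


def optimizer_step_iters(max_iters: int, gradient_accumulation_iters: int) -> List[int]:
    """Iterations at which optimizer.step() would run under the assumed schedule."""
    steps = []
    for iter_num in range(max_iters):
        # Forward/backward would run on every iteration; this flag only gates the step
        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
        if not is_accumulating:
            steps.append(iter_num)
    return steps


if __name__ == "__main__":
    print(optimizer_step_iters(8, 4))  # [3, 7]
```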

HeorhiiS commented 1 year ago

@scvance I'll check it out. Was it a full model checkpoint or a LoRA one?

scvance commented 1 year ago

@HeorhiiS It was the full 7B model. Note that it trained more slowly than the normal model.

WilliamGazeley commented 11 months ago

Have there been any updates on this? I'm also trying to figure out how to use DeepSpeed properly (with Mistral 7B in my case), but I can't find any examples of using it with Fabric.

WilliamGazeley commented 11 months ago

@scvance Any chance you could upload the full script you used to make this work?

qgzang commented 3 months ago

mark