For ZeRO-3 with DeepSpeed, you need to wrap the model initialization in the context manager:
with fabric.sharded_model():
    model = ...
Perhaps you forgot this?
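For reference, a rough end-to-end sketch of where that context manager sits (the Fabric and strategy arguments here are illustrative placeholders, not the exact lit-llama settings):

import lightning as L
from lightning.fabric.strategies import DeepSpeedStrategy

fabric = L.Fabric(devices=8, precision="bf16-true", strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

# Parameters created inside this context are set up for ZeRO-3 sharding
# instead of being fully materialized on every GPU.
with fabric.sharded_model():
    model = LLaMA(config)

model = fabric.setup(model)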
Thanks for your reply! Since I'm a beginner with Fabric and DeepSpeed, I'm not sure how to add this context manager correctly. Taking lit-llama as an example, should I write it like this?
with fabric.sharded_model():
    with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA(config)
    checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")
    # strict=False because missing keys due to LoRA weights not contained in checkpoint state
    model.load_state_dict(checkpoint, strict=False)
    mark_only_lora_as_trainable(model)
I tried this, but I encountered an error when loading the state dict.
I also tried moving load_state_dict() outside of fabric.sharded_model(), but the issue is the same.
Can you provide me with some hints or code references? Thanks!
Hmm, yes, I see. A bit more work is needed here to load the checkpoint into a DeepSpeed-sharded model. Ideally we would use fabric.load() here, but for that the checkpoint would have to be a DeepSpeed checkpoint. I need to think about how we could detect and properly load that.
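For reference, if the checkpoint were already a DeepSpeed checkpoint, loading could look roughly like this (a sketch only; the directory name is hypothetical and assumes the checkpoint was written by fabric.save() under the DeepSpeed strategy):

state = {"model": model, "optimizer": optimizer}
fabric.load("checkpoints/lit-llama/7B/deepspeed", state)  # each rank loads its own shard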
I am facing the same issue for LoRA with DeepSpeed: a bunch of size-mismatch errors.
Facing the same issue. Should there be a conversion from the existing LLaMA checkpoint to a DeepSpeed checkpoint?
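One way such a conversion could look, roughly (untested sketch; it assumes the full model fits in host memory once and reuses the lora/LLaMA setup from earlier in the thread):

# Build and load the full model outside the sharding context first,
# then let Fabric/DeepSpeed partition it and write a DeepSpeed checkpoint.
with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
    model = LLaMA(config)
checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")
model.load_state_dict(checkpoint, strict=False)

model = fabric.setup(model)  # parameters get partitioned under ZeRO-3
fabric.save("checkpoints/lit-llama/7B/deepspeed", {"model": model})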
Any updates on this?
I was able to get the model to run by first converting the weights to DeepSpeed checkpoints and then loading the model from those checkpoints.
I set the DeepSpeed strategy as follows:
from lightning.fabric.strategies import DeepSpeedStrategy

deep_off = DeepSpeedStrategy(config="deep_config.json")
This is the config I used (deep_config.json):
{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
I then created Fabric with the following:
import lightning as L

fabric = L.Fabric(
    accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=deep_off
)
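(I assume fabric.launch() is called right after this so that the worker processes exist and fabric.global_rank is valid in the loading loop below.)

fabric.launch()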
Then I loaded the checkpoints as follows
import re

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

checkpoint_paths = [
    "zero_pp_rank_0_mp_rank_00_model_states.pt",
    "zero_pp_rank_1_mp_rank_00_model_states.pt",
    "zero_pp_rank_2_mp_rank_00_model_states.pt",
    "zero_pp_rank_3_mp_rank_00_model_states.pt",
    "zero_pp_rank_4_mp_rank_00_model_states.pt",
    "zero_pp_rank_5_mp_rank_00_model_states.pt",
    "zero_pp_rank_6_mp_rank_00_model_states.pt",
    "zero_pp_rank_7_mp_rank_00_model_states.pt",
]

merged_checkpoint = {}
for checkpoint_path in checkpoint_paths:
    # Each rank only loads the shard that matches its global rank.
    match = re.search(r"rank_(\d+)", checkpoint_path)
    rank_num = int(match.group(1))
    if fabric.global_rank == rank_num:
        checkpoint = torch.load(checkpoint_path)
        checkpoint = {k: v for k, v in checkpoint.items() if v is not None}
        for key, value in checkpoint.items():
            if key not in merged_checkpoint:
                merged_checkpoint[key] = value
            else:
                try:
                    merged_checkpoint[key] += value
                except TypeError:
                    merged_checkpoint[key].update(value)
checkpoint = merged_checkpoint

# with fabric.device:
with fabric.init_module():
    torch.set_default_tensor_type(torch.HalfTensor)
    model = LLaMA(config).bfloat16()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

optimizer = DeepSpeedCPUAdam(model.parameters(), lr=learning_rate)
model, optimizer = fabric.setup(model, optimizer)

train(fabric, model, optimizer, train_data, val_data, out_dir)
And then you have to comment out the following line in the train function, because it doesn't work with DeepSpeed:
# with fabric.no_backward_sync(model, enabled=is_accumulating):
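For illustration, the accumulation step without that context manager might look roughly like this (names such as gradient_accumulation_iters, loss_fn, and targets are placeholders, not the exact lit-llama code):

# DeepSpeed handles gradient synchronization itself, so the loop simply
# skips the optimizer step until enough micro-batches have been accumulated.
is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
logits = model(input_ids)
loss = loss_fn(logits, targets)
fabric.backward(loss / gradient_accumulation_iters)
if not is_accumulating:
    optimizer.step()
    optimizer.zero_grad()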
This should work, but I'm sure there is a better way to do it.
@scvance I'll check it out, was it a full model checkpoint or a LoRA one?
@HeorhiiS It was a full 7B model. Note that it trained slower than the normal model.
Have there been any updates on this? I'm also looking at how to use DeepSpeed properly (with Mistral 7B in my case), but can't seem to find examples of usage with fabric.
@scvance Any chance you could upload the full script you used to make this work?
Hi, I wonder how to write the code for using the DeepSpeed ZeRO-3 offload strategy correctly. Currently, my code looks like this:
However, it seems the parameters are duplicated on every GPU. I attached a screenshot of the GPU utilization after
model, optimizer = fabric.setup(model, optimizer)
According to my understanding, the parameters should be distributed across the different devices, right?
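If it helps, the advice earlier in this thread is to build the model inside the sharding context manager so ZeRO-3 can partition the parameters at construction time instead of replicating them on every GPU. A minimal sketch (MyModel is a placeholder; on newer Fabric versions the context manager may be fabric.init_module() rather than fabric.sharded_model()):

from deepspeed.ops.adam import DeepSpeedCPUAdam

with fabric.sharded_model():
    model = MyModel()  # parameters are created already partitioned, not replicated per GPU

optimizer = DeepSpeedCPUAdam(model.parameters(), lr=learning_rate)
model, optimizer = fabric.setup(model, optimizer)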