foundation-model-stack / fms-hf-tuning

🚀 Collection of tuning recipes with HuggingFace SFTTrainer and PyTorch FSDP.
Apache License 2.0

fix: remove lm_head post processing #333

Closed Abhishek-TAMU closed 2 months ago

Abhishek-TAMU commented 2 months ago

Description of the change

Removes the lm_head hack that was added to work around an lm_head issue; that issue is now fixed in newer vLLM versions, starting with v0.5.4.
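For context, a rough sketch of the kind of checkpoint post-processing being removed is shown below. This is not the repository's actual code; it assumes the hack stripped a tied `lm_head.weight` tensor from a saved safetensors checkpoint so that older vLLM releases (< 0.5.4) could load the model. The path and function name are placeholders.

```python
# Hypothetical illustration only -- not the actual hack removed by this PR.
from safetensors.torch import load_file, save_file

def strip_lm_head(checkpoint_path: str) -> None:
    tensors = load_file(checkpoint_path)
    # Drop lm_head.weight when it merely duplicates the tied input embeddings.
    tensors.pop("lm_head.weight", None)
    save_file(tensors, checkpoint_path)
```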

Related issue number

#1166

How to verify the PR

Ran LoRA and full fine-tuning of the granite-3b and llama-8b models without the lm_head removal and confirmed that inference runs on the resulting checkpoints (see the sketch below).
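A minimal sketch of such an inference check with vLLM (>= 0.5.4), assuming the tuned checkpoint was saved to `./tuned-model`; the path and prompt are placeholders, and a LoRA adapter would additionally need `enable_lora=True` plus a `LoRARequest`:

```python
from vllm import LLM, SamplingParams

# Load the tuned checkpoint directly, with no lm_head post-processing applied.
llm = LLM(model="./tuned-model")
outputs = llm.generate(["Hello, how are you?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```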

Was the PR tested

anhuong commented 2 months ago

After testing, I found that the accelerate version is not working as expected.

The new logic introduced in get_state_dict also removes the top-level FSDP wrapper from the model. Since FSDP keeps parameters flattened, all parameters managed directly by the top-level wrapper remain flattened when model.state_dict is called. The child FSDP wrappers still protect their parameters, because when the state_dict call recurses into them they use the FSDP version of state_dict to unwrap them.

This results in the following error:

size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([62915840]) from checkpoint, the shape in current model is torch.Size([49152, 2560]).
size mismatch for model.norm.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([2560]).
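For comparison, this is a sketch of how full (unflattened) parameters are normally gathered from an FSDP-wrapped model before saving, using PyTorch's FULL_STATE_DICT context manager. It assumes `model` is still the top-level FSDP module; if that wrapper is stripped before calling `state_dict`, the parameters it manages directly stay in FSDP's flattened 1-D form, which matches the shapes in the error above.

```python
import torch.distributed.fsdp as fsdp

def full_state_dict(model: fsdp.FullyShardedDataParallel):
    # Gather full, unflattened parameters on rank 0, offloaded to CPU.
    cfg = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with fsdp.FullyShardedDataParallel.state_dict_type(
        model, fsdp.StateDictType.FULL_STATE_DICT, cfg
    ):
        return model.state_dict()
```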