CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Question about saving peft checkpoint #565

Open nhanph opened 1 year ago

nhanph commented 1 year ago

🐛 Describe the bug

From my understanding, when saving checkpoints for peft models (see here), trlx removes pytorch_model.bin before calling save_pretrained, which, in my opinion, makes the removal useless.

Is this intentional, or should we move the removal code to after save_pretrained is called?

Here is an example of a directory resulting from save_pretrained:

adapter_config.json  adapter_model.bin  optimizer.bin  pytorch_model.bin  random_states_0.pkl  special_tokens_map.json  spiece.model  tokenizer_config.json  tokenizer.json
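
For clarity, here is a rough sketch of the sequence I mean, paraphrasing the linked code; the function and variable names are illustrative, not the actual trlx implementation:

import os

def save_checkpoint(accelerator, model, dst_dir):
    # 1. Accelerate writes the full training state, including a pytorch_model.bin
    #    that holds the entire model state dictionary.
    accelerator.save_state(dst_dir)

    # 2. For peft models, that pytorch_model.bin is then removed...
    full_state_path = os.path.join(dst_dir, "pytorch_model.bin")
    if os.path.exists(full_state_path):
        os.remove(full_state_path)

    # 3. ...and only afterwards is save_pretrained called, which can write the
    #    file again, so at first glance the earlier removal looks redundant.
    model.save_pretrained(dst_dir)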

Which trlX version are you using?

0.7.0

Additional system and package information

Python 3.10.12

maxreciprocate commented 1 year ago

Hello @nhanph!

No, the removal is not useless. If you check the contents of pytorch_model.bin immediately after this line: https://github.com/CarperAI/trlx/blob/bcd237f1e94c84c5c9f5a4086bab34c0946e3fa7/trlx/trainer/accelerate_base_trainer.py#L309 you will see that it contains the whole model state dictionary, which is not needed. After deleting it and recreating it with model.save_pretrained with heads_only=True, only the value heads are kept there.

>>> list(before_deletion_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias', 'base_model.model.transformer.wte.weight', 'base_model.model.transformer.wpe.weight', 'base_model.model.transformer.h.0.ln_1.weight', 'base_model.model.transformer.h.0.ln_1.bias', 'base_model.model.transformer.h.0.attn.c_attn.weight', 'base_model.model.transformer.h.0.attn.c_attn.bias', 'base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.0.attn.c_proj.weight', 'base_model.model.transformer.h.0.attn.c_proj.bias', 'base_model.model.transformer.h.0.ln_2.weight', 'base_model.model.transformer.h.0.ln_2.bias', 'base_model.model.transformer.h.0.mlp.c_fc.weight', 'base_model.model.transformer.h.0.mlp.c_fc.bias', 'base_model.model.transformer.h.0.mlp.c_proj.weight', 'base_model.model.transformer.h.0.mlp.c_proj.bias', 'base_model.model.transformer.h.1.ln_1.weight', 'base_model.model.transformer.h.1.ln_1.bias', 'base_model.model.transformer.h.1.attn.c_attn.weight', 'base_model.model.transformer.h.1.attn.c_attn.bias', 'base_model.model.transformer.h.1.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.1.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.1.attn.c_proj.weight', 'base_model.model.transformer.h.1.attn.c_proj.bias', 'base_model.model.transformer.h.1.ln_2.weight', 'base_model.model.transformer.h.1.ln_2.bias', 'base_model.model.transformer.h.1.mlp.c_fc.weight', 'base_model.model.transformer.h.1.mlp.c_fc.bias']
>>> list(after_save_pretrained_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias']
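
For reference, this is roughly how the comparison above can be reproduced; the checkpoint path is just an example:

import torch

# State dict written by accelerator.save_state, before the deletion:
before = torch.load("ckpts/checkpoint_01000/pytorch_model.bin", map_location="cpu")
print(len(before))            # thousands of entries: the whole (peft-wrapped) model

# ... after deleting the file and calling model.save_pretrained(dst_dir, heads_only=True) ...
after = torch.load("ckpts/checkpoint_01000/pytorch_model.bin", map_location="cpu")
print(list(after.keys()))     # only the v_head.* entries remain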
nhanph commented 1 year ago

Thank you @maxreciprocate, I get the point about saving the model's value heads now.

My original question comes from an observation when running the ILQL training script: I see a pytorch_model.bin whose size is comparable to that of the original model, so I suspect the base model is also being saved. Is the heads_only flag set to true anywhere during checkpointing when peft_config is used? I cannot find it being set.
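
A quick way to check this suspicion (the checkpoint path below is illustrative):

import os
import torch

path = "ckpts/ilql_checkpoint/pytorch_model.bin"
print(f"{os.path.getsize(path) / 1e6:.1f} MB")   # comparable to the base model's size
state_dict = torch.load(path, map_location="cpu")
# If heads_only had taken effect, every key would start with "v_head";
# any remaining base_model.* keys mean the base model was saved too.
print([k for k in state_dict if not k.startswith("v_head")][:5])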