Open nhanph opened 1 year ago
Hello @nhanph!
No, the removal is not useless. If you check the contents of pytorch_model.bin immediately after this line: https://github.com/CarperAI/trlx/blob/bcd237f1e94c84c5c9f5a4086bab34c0946e3fa7/trlx/trainer/accelerate_base_trainer.py#L309 you will see that it contains the whole model state dictionary, which is not needed. After deleting it and recreating it with model.save_pretrained with heads_only=True, only the value heads will be kept there.
>>> list(before_deletion_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias', 'base_model.model.transformer.wte.weight', 'base_model.model.transformer.wpe.weight', 'base_model.model.transformer.h.0.ln_1.weight', 'base_model.model.transformer.h.0.ln_1.bias', 'base_model.model.transformer.h.0.attn.c_attn.weight', 'base_model.model.transformer.h.0.attn.c_attn.bias', 'base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.0.attn.c_proj.weight', 'base_model.model.transformer.h.0.attn.c_proj.bias', 'base_model.model.transformer.h.0.ln_2.weight', 'base_model.model.transformer.h.0.ln_2.bias', 'base_model.model.transformer.h.0.mlp.c_fc.weight', 'base_model.model.transformer.h.0.mlp.c_fc.bias', 'base_model.model.transformer.h.0.mlp.c_proj.weight', 'base_model.model.transformer.h.0.mlp.c_proj.bias', 'base_model.model.transformer.h.1.ln_1.weight', 'base_model.model.transformer.h.1.ln_1.bias', 'base_model.model.transformer.h.1.attn.c_attn.weight', 'base_model.model.transformer.h.1.attn.c_attn.bias', 'base_model.model.transformer.h.1.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.1.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.1.attn.c_proj.weight', 'base_model.model.transformer.h.1.attn.c_proj.bias', 'base_model.model.transformer.h.1.ln_2.weight', 'base_model.model.transformer.h.1.ln_2.bias', 'base_model.model.transformer.h.1.mlp.c_fc.weight', 'base_model.model.transformer.h.1.mlp.c_fc.bias']
>>> list(after_save_pretrained_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias']
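The effect shown in the REPL output can be illustrated with a minimal sketch. `filter_heads_only` is a hypothetical helper written for this example, not trlx's actual implementation; the key names mirror the state-dict keys printed above:

```python
def filter_heads_only(state_dict):
    # Keep only the value-head parameters; the base-model weights are
    # dropped, since they are restored from the base checkpoint and the
    # peft adapter files instead.
    return {k: v for k, v in state_dict.items() if k.startswith("v_head.")}

# Toy state dict mimicking the key names from the REPL output above
# (string placeholders stand in for tensors).
full_state_dict = {
    "v_head.0.weight": "tensor",
    "v_head.0.bias": "tensor",
    "base_model.model.transformer.wte.weight": "tensor",
    "base_model.model.transformer.h.0.ln_1.weight": "tensor",
}

heads_only = filter_heads_only(full_state_dict)
print(sorted(heads_only))
# → ['v_head.0.bias', 'v_head.0.weight']
```

This is why re-saving with heads_only=True after the deletion leaves a pytorch_model.bin that is tiny compared to the original full state dictionary.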
Thank you @maxreciprocate, I get the point about saving the model's value head now.
My original question comes from an observation while running the ILQL training script: I see a pytorch_model.bin with a size comparable to the original model, so I suspect the base model is also being saved. Is the heads_only flag set to true anywhere during checkpointing when using peft_config? I cannot find it set anywhere.
🐛 Describe the bug
From my understanding, when saving checkpoints for peft models (see here), trlx removes pytorch_model.bin before calling save_pretrained, which makes the removal useless in my opinion. Is this intentional, or should the removal code be moved to after save_pretrained is called? Here is an example of a directory resulting from save_pretrained:
Which trlX version are you using?
0.7.0
Additional system and package information
3.10.12