mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Composer LoRA weights conversion #1325

Closed zhao-lun closed 2 days ago

zhao-lun commented 3 days ago

Environment

composer: 0.23.4
llm-foundry: v0.9.0

Example config with LoRA

  .... # mpt-7b-arc-easy--gpu.yaml
  peft_config:
    r: 64
    peft_type: LORA
    task_type: CAUSAL_LM
    lora_alpha: 128
    lora_dropout: 0.05
    target_modules:
      - Wqkv
  ....
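
For reference, this peft_config block maps roughly onto a standard Hugging Face peft LoraConfig; a minimal sketch, with the mapping inferred from the key names above:

from peft import LoraConfig

# Rough peft-API equivalent of the YAML peft_config above.
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["Wqkv"],
    task_type="CAUSAL_LM",
)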

Fine-tuning

composer train.py finetune_example/mpt-7b-arc-easy--gpu.yaml

After fine-tuning

 ls 
 50M Jul  1 08:27 ep1-ba12-rank0.pt
 17 Jul  1 08:27 latest-rank0.pt -> ep1-ba12-rank0.pt

Inference

Attempting to convert the weights to an HF/PEFT-compatible format:

python3 inference/convert_composer_to_hf.py --composer_path test/checkpoints/latest-rank0.pt --hf_output_path test_inf/ 

.....
  File "/llm-foundry/llmfoundry/utils/registry_utils.py", line 161, in construct_from_registry
    constructed_item = registered_constructor(**kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: MultiheadAttention.__init__() got an unexpected keyword argument 'prefix_lm'

Would appreciate any advice/pointers on how to convert it.
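
For reference, one way to check what the checkpoint actually stores is to load it on CPU and print the saved Hugging Face metadata. A minimal debugging sketch; the exact nesting of the Composer state dict is an assumption and may differ between Composer versions:

import torch

# Load the Composer checkpoint on CPU and print any stored Hugging Face
# metadata (e.g. a model config that still carries the removed prefix_lm key).
ckpt = torch.load("test/checkpoints/latest-rank0.pt", map_location="cpu")
hf_meta = ckpt.get("state", {}).get("integrations", {}).get("huggingface", {})
print(hf_meta)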

Expectation

Able to run inference with the adapters or the merged weights:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model, then attach the converted LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b")
peft_model_id = "xx"  # path or hub id of the converted adapter
model = PeftModel.from_pretrained(base_model, peft_model_id)
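
If merged weights are preferred over a separate adapter, peft can also fold the LoRA deltas back into the base model. A minimal sketch continuing from the snippet above; the output directory name is a placeholder:

# Merge the LoRA weights into the base model so inference no longer needs the
# peft wrapper, then save the merged model for plain transformers loading.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("mpt-7b-arc-easy-merged")
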
dakinggg commented 2 days ago

I suspect it will work if you use an LLM Foundry commit from before we removed prefix LM from the code (https://github.com/mosaicml/llm-foundry/pull/1065). Could you give that a try?

dakinggg commented 2 days ago

More generally, I recommend adding the hf_checkpointer to your callbacks so that HF checkpoints are produced during training, instead of trying to convert after the fact.

zhao-lun commented 2 days ago

@dakinggg Thanks a lot!

Adding the following section allows adapter weight generation:

callbacks:
  hf_checkpointer: 
    save_folder: ./{run_name}/checkpoints
    save_interval: "1ep"
$ ls
README.md  adapter_config.json  adapter_model.safetensors  special_tokens_map.json  tokenizer.json  tokenizer_config.json