pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Error running inference with a LoRA-fine-tuned Llama3-8B in a Python script #1387

Open fangzhouli opened 3 weeks ago

fangzhouli commented 3 weeks ago

I have a LoRA-fine-tuned Llama3-8B model. Since I have many prompts, I would like to write a Python script that generates outputs for all of them without reloading the model for every prompt, as happens when I use the CLI script.

The following generation_config.yaml works for me with a single prompt:

model:
  _component_: torchtune.models.llama3.llama3_8b

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /to/my/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /to/my/checkpoints/
  checkpoint_files: [
    meta_model_0.pt,
  ]
  output_dir: /to/my/checkpoints/
  model_type: LLAMA3
...
prompt: 'Tell me a joke.'
...

Here is my Python script for loading the state_dict:

from torchtune.utils import FullModelTorchTuneCheckpointer
from torchtune.models.llama3 import llama3_8b

model = llama3_8b()
checkpointer = FullModelTorchTuneCheckpointer(
    checkpoint_dir="/to/my/checkpoints",
    checkpoint_files=["meta_model_0.pt"],
    output_dir="/to/my/checkpoints",
    model_type='LLAMA3',
)
ckpt_dict = checkpointer.load_checkpoint()
model_state_dict = ckpt_dict['model']
model.load_state_dict(model_state_dict)

This raises the following error:

RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Missing key(s) in state_dict: "layers.0.sa_norm.scale", "layers.0.attn.q_proj.weight", "layers.0.attn.k_proj.weight", "layers.0.attn.v_proj.weight", "layers.0.attn.output_proj.weight", "layers.0.mlp_norm.scale", "layers.0.mlp.w1.weight", "layers.0.mlp.w2.weight", "layers.0.mlp.w3.weight", "layers.1.sa_norm.scale", "layers.1.attn.q_proj.weight", "layers.1.attn.k_proj.weight", "layers.1.attn.v_proj.weight", "layers.1.attn.output_proj.weight", ...
        Unexpected key(s) in state_dict: "layers.0.attention_norm.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight", ...
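The two key lists in this error differ only in naming convention: the unexpected keys use Meta's original parameter names, while the missing keys use torchtune's names. The toy sketch below makes that concrete. The rename table is a hypothetical, abbreviated subset written for illustration (torchtune's own conversion lives in `torchtune.models.convert_weights`), and `diff_keys` only mimics the missing/unexpected report that `load_state_dict` produces in strict mode.

```python
import re

# Illustrative per-layer rename table (Meta name -> torchtune name), written
# for this example only; not torchtune's actual mapping object.
META_TO_TUNE_LAYER_KEYS = {
    "attention_norm.weight": "sa_norm.scale",
    "attention.wq.weight": "attn.q_proj.weight",
    "attention.wk.weight": "attn.k_proj.weight",
    "attention.wv.weight": "attn.v_proj.weight",
    "attention.wo.weight": "attn.output_proj.weight",
    "ffn_norm.weight": "mlp_norm.scale",
    "feed_forward.w1.weight": "mlp.w1.weight",
    "feed_forward.w2.weight": "mlp.w2.weight",
    "feed_forward.w3.weight": "mlp.w3.weight",
}

_LAYER_KEY = re.compile(r"^layers\.(\d+)\.(.+)$")


def meta_key_to_tune(key: str) -> str:
    """Rename one Meta-format layer key to its torchtune-format equivalent."""
    match = _LAYER_KEY.match(key)
    if match is None:
        return key  # non-layer keys are left untouched in this toy version
    layer, suffix = match.groups()
    return f"layers.{layer}.{META_TO_TUNE_LAYER_KEYS[suffix]}"


def diff_keys(expected: set, loaded: set):
    """Return (missing, unexpected), like a strict load_state_dict report."""
    return sorted(expected - loaded), sorted(loaded - expected)


# Keys as they appear in the error message, for layer 0 only.
meta_keys = {f"layers.0.{k}" for k in META_TO_TUNE_LAYER_KEYS}
tune_keys = {f"layers.0.{v}" for v in META_TO_TUNE_LAYER_KEYS.values()}

missing, unexpected = diff_keys(tune_keys, meta_keys)
print(len(missing), len(unexpected))  # nothing lines up before renaming

renamed = {meta_key_to_tune(k) for k in meta_keys}
print(diff_keys(tune_keys, renamed))  # everything lines up after renaming
```

Every expected key is reported missing and every loaded key unexpected, which is exactly the pattern in the traceback: the weights are all present, just under the other format's names.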
felipemello1 commented 3 weeks ago

Edit: wrong answer (see the discussion below).

Hey @fangzhouli, good question! I haven't had a chance to use the generation script yet, but I noticed that you are instantiating the llama3_8b model rather than the LoRA version. Please check the config you used to train your model and make sure the generation config instantiates the same model. You will see that the model builder there probably has a different name and some additional LoRA-specific arguments: https://github.com/pytorch/torchtune/blob/e568b671bf4dd043649611c3d0905967257ccac2/recipes/configs/llama3_1/8B_lora.yaml#L27

For the Python script, you would probably need something like: model = lora_llama3_1_8b(lora_attn_modules=['q_proj', 'v_proj'], **other_lora_args)

fangzhouli commented 3 weeks ago

Hi @felipemello1, thank you for your reply! I was originally instantiating torchtune.models.llama3.lora_llama3_8b for generation. With that, both the CLI generation script and my Python script failed with the same RuntimeError when loading the state_dict.

I am following this post (https://github.com/pytorch/torchtune/issues/1188), and my understanding is that after fine-tuning with torchtune.models.llama3.lora_llama3_8b, the output meta_model_0.pt already contains the original weights merged with the adapter weights, so I should use torchtune.models.llama3.llama3_8b directly for inference.

My current workaround, since loading through the config works, is to construct the InferenceRecipe from my config.yaml in Python and reuse it for all prompts. However, I would like to know whether there is a cleaner way to do this.

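from typing import List

from omegaconf import DictConfig

from torchtune import config
from torchtune.config._utils import _merge_yaml_and_cli_args
from torchtune.utils.argparse import TuneRecipeArgumentParser

class InferenceRecipe:
    # Copied from torchtune's generate recipe; only `generate` is overridden
    # so that it takes a list of prompts instead of a single one.
    def generate(self, cfg: DictConfig, prompts: List[str]) -> List[str]:
        ...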

yaml_args, cli_args = TuneRecipeArgumentParser().parse_known_args(
    ['--config', '/to/my/custom_generation_config.yaml'],
)
cfg = _merge_yaml_and_cli_args(yaml_args, cli_args)
config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
recipe = InferenceRecipe(cfg=cfg)
recipe.setup(cfg=cfg)

prompts = [...]  # Multiple prompts.
outputs = recipe.generate(cfg=cfg, prompts=prompts)
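The workaround pays off because the expensive step (building the model and loading the checkpoint in `setup`) happens once, and every prompt reuses the loaded model. Stripped of torchtune, the pattern is "load once, loop many". In the sketch below, `load_model` and `FakeModel` are hypothetical stand-ins for the config-driven setup, and `functools.lru_cache` ensures the load runs at most once per config path.

```python
from functools import lru_cache
from typing import List

LOAD_CALLS = 0  # counts how many times the expensive load actually runs


class FakeModel:
    """Hypothetical stand-in for a loaded model; returns a canned reply."""

    def generate(self, prompt: str) -> str:
        return f"response to: {prompt}"


@lru_cache(maxsize=None)
def load_model(config_path: str) -> FakeModel:
    """Hypothetical loader; a real one would build the model and load weights."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return FakeModel()


def generate_all(config_path: str, prompts: List[str]) -> List[str]:
    model = load_model(config_path)  # served from the cache after the first call
    return [model.generate(p) for p in prompts]


outputs = generate_all("custom_generation_config.yaml", ["Tell me a joke.", "Hi!"])
more = generate_all("custom_generation_config.yaml", ["One more."])
print(outputs, more, LOAD_CALLS)
```

Caching on the config path keeps the call sites simple: any function can ask for the model without worrying about whether it has already been loaded.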
felipemello1 commented 3 weeks ago

Oh, I see! @RdoubleA, do you mind taking a look when you have a chance?

ebsmothers commented 3 weeks ago

@fangzhouli thanks for creating the issue. Can you try using FullModelMetaCheckpointer instead of FullModelTorchTuneCheckpointer in your script?

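Each checkpointer is tied to a specific checkpoint format: it reads checkpoints in that format and writes its output in the same format, so the input and output formats always line up. Since you used the MetaCheckpointer for fine-tuning (if I understand correctly), your fine-tuned checkpoint is also in the Meta format (hence the name meta_model_0.pt), so you should use the same checkpointer in your generation script. That is what the error shows: the unexpected keys (e.g. layers.0.attention.wq.weight) are Meta-format names, while the missing keys (e.g. layers.0.attn.q_proj.weight) are the torchtune-format names the model expects. FullModelTorchTuneCheckpointer loads the file without converting it, whereas FullModelMetaCheckpointer converts the Meta keys to torchtune's format on load.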
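A toy model of this round-trip behaviour, with hypothetical names and a two-key rename table throughout (these are not torchtune's classes): `ToyMetaCheckpointer` converts Meta keys to torchtune keys on load and back on save, so the saved checkpoint ends up in the same format as the input. `ToyTorchTuneCheckpointer` does no conversion, so a Meta-format file handed to it reaches the model with the wrong key names.

```python
from typing import Dict

# Tiny illustrative rename table (Meta name -> torchtune name); hypothetical subset.
META_TO_TUNE = {
    "layers.0.attention.wq.weight": "layers.0.attn.q_proj.weight",
    "layers.0.attention_norm.weight": "layers.0.sa_norm.scale",
}
TUNE_TO_META = {v: k for k, v in META_TO_TUNE.items()}


class ToyMetaCheckpointer:
    """Hypothetical: Meta format on disk, torchtune format in memory."""

    def load_checkpoint(self, on_disk: Dict[str, float]) -> Dict[str, float]:
        return {META_TO_TUNE[k]: v for k, v in on_disk.items()}

    def save_checkpoint(self, in_memory: Dict[str, float]) -> Dict[str, float]:
        return {TUNE_TO_META[k]: v for k, v in in_memory.items()}


class ToyTorchTuneCheckpointer:
    """Hypothetical: assumes the file is already in torchtune format."""

    def load_checkpoint(self, on_disk: Dict[str, float]) -> Dict[str, float]:
        return dict(on_disk)

    def save_checkpoint(self, in_memory: Dict[str, float]) -> Dict[str, float]:
        return dict(in_memory)


model_keys = set(TUNE_TO_META)  # the names a torchtune model expects
meta_file = {"layers.0.attention.wq.weight": 1.0, "layers.0.attention_norm.weight": 2.0}

meta_ckpt = ToyMetaCheckpointer()
state = meta_ckpt.load_checkpoint(meta_file)
print(set(state) == model_keys)                        # keys match the model
print(meta_ckpt.save_checkpoint(state) == meta_file)   # saved back in Meta format

wrong = ToyTorchTuneCheckpointer().load_checkpoint(meta_file)
print(set(wrong) == model_keys)                        # Meta keys reach the model
```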

Also, your usage of llama3_8b instead of lora_llama3_8b is correct here: the fine-tuning script will merge the weights back into the original model before saving so you don’t need to use a LoRA model builder for generation.
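Why a plain llama3_8b builder is enough: merging folds each low-rank update into its frozen base weight, so the saved model has the same parameter names and shapes as the original. The sketch below assumes the standard LoRA formulation, W_merged = W + (alpha / rank) * B @ A (torchtune's recipes do this with a `get_merged_lora_ckpt` helper from its PEFT utilities). It uses plain Python lists with made-up toy values to check that a merged linear layer gives the same output as the unmerged base-plus-adapter computation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain-Python matrix product x @ y."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def add(x: Matrix, y: Matrix, scale: float = 1.0) -> Matrix:
    """Elementwise x + scale * y."""
    return [[a + scale * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]


def merge_lora(w: Matrix, a: Matrix, b: Matrix, alpha: float, rank: int) -> Matrix:
    """Fold the low-rank update into the base weight: W + (alpha / rank) * B @ A."""
    return add(w, matmul(b, a), scale=alpha / rank)


# Toy shapes: W is (out=2, in=3), A is (rank=1, in=3), B is (out=2, rank=1).
w = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
a = [[1.0, 2.0, 0.0]]
b = [[0.5], [-1.0]]
alpha, rank = 2.0, 1
x = [[1.0], [1.0], [1.0]]  # one input column vector

# Unmerged LoRA forward: W x + (alpha / rank) * B (A x)
lora_out = add(matmul(w, x), matmul(b, matmul(a, x)), scale=alpha / rank)
# Merged forward: an ordinary linear layer with the folded weight
merged_out = matmul(merge_lora(w, a, b, alpha, rank), x)
print(lora_out, merged_out)  # [[6.0], [-6.0]] [[6.0], [-6.0]]
```

Because the merged weight has exactly the shape of W, the checkpoint loads into the non-LoRA builder, and the adapter-specific parameters (A and B) no longer need to exist at inference time.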

Sorry for any confusion here! I think we can make this a bit clearer in our documentation, especially the bit about which checkpointer to use and why.