fangzhouli opened this issue 3 weeks ago (status: Open)
Edit: this answer is wrong; see the follow-up below.
Hey @fangzhouli, good question! I haven't had a chance to use the generation script, but I noticed that you are trying to instantiate the llama3_8b model and not the LoRA version. Please take a look at the config you used to train your model and make sure that generate instantiates the same model. You will see that the model probably has a different name and some different sections: https://github.com/pytorch/torchtune/blob/e568b671bf4dd043649611c3d0905967257ccac2/recipes/configs/llama3_1/8B_lora.yaml#L27
For the Python script, you would probably have to do something like:

```python
model = lora_llama3_1_8b(lora_attn_modules=['q_proj', 'v_proj'], **other_lora_args)
```
Hi @felipemello1, thank you for your reply! I was originally instantiating `torchtune.models.llama3.lora_llama3_8b` for generation. That resulted in both the CLI generation script and my Python script failing with the same runtime error when loading the state_dict.

I am following this post (https://github.com/pytorch/torchtune/issues/1188) and am under the impression that after fine-tuning with `torchtune.models.llama3.lora_llama3_8b`, the output `meta_model_0.pt` already has the original and adapter weights merged, and thus I should use `torchtune.models.llama3.llama3_8b` directly for inference.

My current workaround is to call the `InferenceRecipe` constructor and load my model by reading my `config.yaml`, since that worked. But I would like to know if there is a less hacky way to do it:
```python
from typing import List

from omegaconf import DictConfig

from torchtune import config
from torchtune.config._utils import _merge_yaml_and_cli_args
from torchtune.utils.argparse import TuneRecipeArgumentParser


class InferenceRecipe:
    # Only overwrite the `generate` function.
    def generate(self, cfg: DictConfig, prompts: List[str]) -> List[str]:
        ...


yaml_args, cli_args = TuneRecipeArgumentParser().parse_known_args(
    ['--config', '/to/my/custom_generation_config.yaml'],
)
cfg = _merge_yaml_and_cli_args(yaml_args, cli_args)
config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
recipe = InferenceRecipe(cfg=cfg)
recipe.setup(cfg=cfg)
prompts = [...]  # Multiple prompts.
outputs = recipe.generate(cfg=cfg, prompts=prompts)
```
Oh, I see! @RdoubleA, do you mind taking a look when you have a chance?
@fangzhouli thanks for creating the issue. Can you try using the `MetaCheckpointer` instead of the `TorchTuneCheckpointer` in your script?
The way the checkpointers work is that they ensure the input format and the output format line up whenever you use them. So since you used the `MetaCheckpointer` in your fine-tuning (if I understand correctly), your fine-tuned checkpoint will also be in the Meta format (hence why it's saved as `meta_model_0.pt`). You should therefore use this same checkpointer in your generation script.
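For example, the checkpointer section of the generation config would mirror the one from fine-tuning. This is a hypothetical fragment; the component path, directory, and file names here are assumptions that you should replace with the values from your own training config:

```yaml
# Hypothetical generation config fragment (adjust paths to your run).
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /path/to/finetune/output_dir
  checkpoint_files: [meta_model_0.pt]
  model_type: LLAMA3
```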
Also, your usage of `llama3_8b` instead of `lora_llama3_8b` is correct here: the fine-tuning script merges the LoRA weights back into the original model before saving, so you don't need a LoRA model builder for generation.
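To illustrate why the merged checkpoint loads into the plain model, here is a toy sketch (my own illustration, not torchtune's actual implementation) of the merge that happens at save time:

```python
import torch

# Toy illustration of a LoRA merge: the low-rank update
# (alpha / rank) * B @ A is folded into the base weight, so the saved
# state_dict keeps the original parameter names and shapes, and the
# plain (non-LoRA) model builder can load it. Toy dimensions only.
def merge_lora_weight(w_base: torch.Tensor,
                      lora_a: torch.Tensor,
                      lora_b: torch.Tensor,
                      alpha: float,
                      rank: int) -> torch.Tensor:
    return w_base + (alpha / rank) * (lora_b @ lora_a)

w = torch.zeros(8, 8)  # pretend base attention weight
a = torch.ones(2, 8)   # LoRA A: rank x in_features
b = torch.ones(8, 2)   # LoRA B: out_features x rank
merged = merge_lora_weight(w, a, b, alpha=16.0, rank=2)
print(merged.shape)  # same shape as the base weight: torch.Size([8, 8])
```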
Sorry for any confusion here! I think we can make this a bit clearer in our documentation, especially the bit about which checkpointer to use and why.
I have a fine-tuned LoRA-Llama3-8b model. Since I have many prompts, I would like to write a script that generates outputs for all of them without reloading the model through the CLI script each time.
The following `generation_config.yaml` for a single prompt worked for me with the CLI script. Here is my Python script for loading the state_dict:
This raises the following error: