YuanGongND / ltu

Code, Dataset, and Pretrained Models for Audio and Speech Large Language Model "Listen, Think, and Understand".

Question about model loading in inference #22


dingdongwang commented 9 months ago

Hi, I have another question, this time about the model-related configuration settings for batch inference after fine-tuning.

In the inference_batch.py script for LTU-AS, shown below:

def main(
    load_8bit: bool = False,
    base_model: str = "../../pretrained_mdls/vicuna_ltuas/",
    prompt_template: str = "alpaca_short"
):
    eval_mdl_path = '../../pretrained_mdls/ltuas_long_noqa_a6.bin'

Should I update eval_mdl_path to the path of my fine-tuned checkpoint to run batch inference, while keeping base_model: str = "../../pretrained_mdls/vicuna_ltuas/" unchanged? Is my understanding correct?
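For reference, the change I have in mind looks like this (the checkpoint path is a hypothetical placeholder for my own fine-tuning output, not an actual file in the repo):

def main(
    load_8bit: bool = False,
    base_model: str = "../../pretrained_mdls/vicuna_ltuas/",  # unchanged
    prompt_template: str = "alpaca_short"
):
    # hypothetical path to my own fine-tuned checkpoint
    eval_mdl_path = '../../exp/my_finetune_exp/checkpoint.bin'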

Besides, does the following code merge the components of the LLM with the fine-tuned checkpoint? In other words, does the code segment below integrate the LLM and the fine-tuned parameters into a single new model?

    if eval_mdl_path != 'vicuna':
        state_dict = torch.load(eval_mdl_path, map_location='cpu')
        miss, unexpect = model.load_state_dict(state_dict, strict=False)
        print('unexpect', unexpect)

Thank you so much for your patience and detailed responses all the time!

YuanGongND commented 9 months ago

Should I update eval_mdl_path to the path of my fine-tuned checkpoint to run batch inference, while keeping base_model: str = "../../pretrained_mdls/vicuna_ltuas/" unchanged? Is my understanding correct?

Besides, does the following code merge the components of the LLM with the fine-tuned checkpoint? In other words, does the code segment below integrate the LLM and the fine-tuned parameters into a single new model?

This is correct. base_model = "../../pretrained_mdls/vicuna_ltuas/" loads the LLM and the pretrained audio encoder, which actually happens here:

https://github.com/YuanGongND/ltu/blob/8c8f92446a8121fc78d2f7dece2a6e08dc2061b2/src/ltu_as/inference_batch.py#L68

The state_dict = torch.load(eval_mdl_path, map_location='cpu') then loads the fine-tuned audio encoder, the linear projection layer, and the LoRA adapters on top of that.
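To make the two-stage scheme concrete, here is a minimal sketch of the pattern (it uses a generic HF AutoModelForCausalLM as a stand-in; the actual LTU-AS model class and its wrapping in inference_batch.py differ):

import torch
from transformers import AutoModelForCausalLM

# Stage 1: HF loading instantiates the model from its config and loads
# the base weights (in LTU-AS: the LLM plus the pretrained audio encoder).
model = AutoModelForCausalLM.from_pretrained("../../pretrained_mdls/vicuna_ltuas/")

# Stage 2: overlay the fine-tuned weights (audio encoder, projection
# layer, LoRA adapters). strict=False lets parameters that are absent
# from the checkpoint keep their stage-1 values.
state_dict = torch.load("../../pretrained_mdls/ltuas_long_noqa_a6.bin", map_location="cpu")
miss, unexpect = model.load_state_dict(state_dict, strict=False)
print("missing:", len(miss), "unexpected:", len(unexpect))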

print('unexpect', unexpect)

This should not print any unexpected keys (the unexpect list should be empty).
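If you want an explicit guard instead of reading the printout, something like this works (a sketch; unexpect is the second list returned by load_state_dict above):

# every key in the fine-tuned checkpoint should match a model parameter;
# unexpected keys usually indicate a base_model / checkpoint mismatch
assert len(unexpect) == 0, f"unexpected keys: {unexpect}"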

If you do not fully understand how HF loading works, I recommend keeping this loading scheme. HF loading is not just loading a checkpoint; from_pretrained also instantiates the model from its config, among other things.

-Yuan

dingdongwang commented 9 months ago

Got it, thank you so much!