huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Why do the implementation behaviors of official llava and transformers differ? #30415

Closed bleedingfight closed 6 months ago

bleedingfight commented 6 months ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

My model is the official LLaVA (vicuna + CLIP + mm_projector). Conversion script:

import argparse                                                                                                                                                          
import os                                                                                                                                                                

import torch                                                                                                                                                             
from huggingface_hub import hf_hub_download                                                                                                                              

from transformers import (                                                                                                                                               
    AddedToken,                                                                                                                                                          
    AutoConfig,                                                                                                                                                          
    AutoTokenizer,                                                                                                                                                       
    CLIPImageProcessor,                                                                                                                                                  
    LlavaConfig,                                                                                                                                                         
    LlavaForConditionalGeneration,                                                                                                                                       
    LlavaProcessor,                                                                                                                                                      
)                                                                                                                                                                        

# Rename keys from the original LLaVA checkpoint to the layout expected by
# LlavaForConditionalGeneration: the mm_projector becomes multi_modal_projector
# and the LLaMA weights move under language_model.
KEYS_TO_MODIFY_MAPPING = {
    "model.vision_tower.": "",                                                                                                                                           
    "model.mm_projector": "multi_modal_projector",                                                                                                                       
    "model": "model.model",                                                                                                                                              
    "vision_model.model": "vision_model",                                                                                                                                
    "lm_head": "language_model.lm_head",                                                                                                                                 
    "model.model": "language_model.model",                                                                                                                               
    "multi_modal_projector.0": "multi_modal_projector.linear_1",                                                                                                         
    "multi_modal_projector.2": "multi_modal_projector.linear_2",                                                                                                         
}                                                                                                                                                                        

def convert_state_dict_to_hf(state_dict):                                                                                                                                
    new_state_dict = {}                                                                                                                                                  
    for key, value in state_dict.items():                                                                                                                                
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():                                                                                                    
            if key_to_modify in key:                                                                                                                                     
                key = key.replace(key_to_modify, new_key)                                                                                                                

        new_state_dict[key] = value                                                                                                                                      
    return new_state_dict
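
# Example renames produced by the mapping above (illustrative, based on LLaVA-1.5 key names):
#   "model.mm_projector.0.weight"            -> "multi_modal_projector.linear_1.weight"
#   "model.layers.0.self_attn.q_proj.weight" -> "language_model.model.layers.0.self_attn.q_proj.weight"
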
def convert_llava_llama_to_hf(                                                                                                                                           
    text_model_id, vision_model_id, output_hub_path, old_state_dict_id                                                                                                   
):
    torch.set_default_dtype(torch.float16) 
    text_config = AutoConfig.from_pretrained(text_model_id)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(
        AddedToken("<image>", special=True, normalized=False), special_tokens=True
    )
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)

    processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) 

    # vision_config = CLIPVisionConfig.from_pretrained(vision_model_id)

    config = LlavaConfig(text_config=text_config)
    # config.pad_token_id = 32001

    model = LlavaForConditionalGeneration(config)

    # Pad to 64 for performance reasons
    pad_shape = 64

    state_dict_path = os.path.join(old_state_dict_id, "model_state_dict.bin")
    if not os.path.exists(state_dict_path):
        state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin") 

    state_dict = torch.load(state_dict_path, map_location="cpu")
    state_dict = convert_state_dict_to_hf(state_dict)
    # Replace the LLM model's weights with the converted state dict
    model.load_state_dict(state_dict, strict=True, assign=True)

    # Fit a Gaussian to the existing token embeddings; the rows for the newly
    # added <image> and <pad> tokens are sampled from it below.
    pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
    mu = torch.mean(pre_expansion_embeddings, dim=0).float()
    n = pre_expansion_embeddings.size()[0] 
    sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        mu, covariance_matrix=1e-5 * sigma 
    )

    # We added the <image> and <pad> tokens above, so resize the embeddings
    # (vocab_size + 2, padded to a multiple of 64)
    model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
    model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
        tuple(
            (
                dist.sample()
                for _ in range(
                    model.language_model.model.embed_tokens.weight.data[32000:].shape[0]
                )
            )
        ),
        dim=0,
    )
    model.language_model.lm_head.weight.data[32000:] = torch.stack(
        tuple(
            (
                dist.sample()
                for _ in range(
                    model.language_model.lm_head.weight.data[32000:].shape[0]
                )
            )
        ),
        dim=0,
    )
    # Save locally by default; set USE_LOCAL=False in the environment to push to the Hub
    is_local = os.environ.get("USE_LOCAL", "True").strip().lower() not in ("false", "0", "no")
    if not is_local:
        model.push_to_hub(output_hub_path) 
        processor.push_to_hub(output_hub_path)
    else:
        model.save_pretrained(output_hub_path)
        processor.save_pretrained(output_hub_path)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_model_id",
        help="Hub location of the text model",
    )
    parser.add_argument(
        "--vision_model_id",
        help="Hub location of the vision model",
    )
    parser.add_argument(
        "--output_hub_path",
        help="Location on the hub of the converted model",
    )
    parser.add_argument(
        "--old_state_dict_id",
        help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
    )
    args = parser.parse_args()
    convert_llava_llama_to_hf(
        args.text_model_id,
        args.vision_model_id,
        args.output_hub_path,
        args.old_state_dict_id,
    )

if __name__ == "__main__":
    main()

I only modified the script slightly so that I can save the model locally instead of pushing it to the HF Hub.

The inference code may look a bit long, but it is simple: it just loads the model and calls generate to produce the output. USETRANSFORMER switches between the two inference backends (transformers vs. the official llava code); a minimal sketch of the transformers path is shown after the output below. The official llava output is:

B
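
For reference, the transformers-side inference is roughly the following (the model path, image, prompt, and generation settings are placeholders, not the exact values from my run):

import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, LlavaProcessor

model_path = "./llava-converted"  # output_hub_path from the conversion script above (placeholder)
model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")
processor = LlavaProcessor.from_pretrained(model_path)

prompt = "A chat between a curious user and an artificial intelligence assistant. ... USER: <image>\nQUESTION ASSISTANT:"
image = Image.open("example.jpg")

inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(processor.batch_decode(output, skip_special_tokens=True)[0])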


I tried debugging the code, and the two generation pipelines look like this:

transformers:

prompt --> tokens ---------\
                            +--> greedy_search (language_model) --> one_token, two_token, ...
image  --> image embedding /                 ^                              |
                                             '--- each new token fed back --'
output = [prompt tokens, one_token, two_token, ...] --> decode

llava official:

prompt --> tokens ---------\
                            +--> greedy_search --> one_token, two_token, three_token, ...
image  --> image embedding /            ^                        |
                                        '-- each new token fed back --'
output = [one_token, two_token, ...] --> decode   (prompt tokens are dropped before decoding)
This is transformers:
![Screenshot_20240423_174550](https://github.com/huggingface/transformers/assets/11495161/50ab6265-e8c7-4d04-9e16-baf291bb7237)
llava official:
![Screenshot_20240423_175715](https://github.com/huggingface/transformers/assets/11495161/fa0ade4d-15dc-4398-b06d-bde62689e07a)

In both cases input_ids is what goes into the model, but in transformers the returned sequence still starts with the prompt's input_ids, so the decoded output includes the prompt.
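
In other words, the difference is only whether the prompt tokens are sliced off before decoding. Reusing the placeholder names from the sketch above, something like:

# transformers: generate() returns the prompt ids followed by the new tokens,
# so decoding the whole sequence reproduces the prompt as well.
full_output = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(full_output, skip_special_tokens=True)[0])  # "... ASSISTANT: B"

# llava official: only the tokens generated after the prompt are decoded.
new_tokens = full_output[:, inputs["input_ids"].shape[-1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])   # "B"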

### Expected behavior

transformers:"chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER:  \n下面的
文章描述了一个实验。阅读文章,然后按照以下说明进行操作。\n\nMadelyn在雪板的底部涂上了一层薄蜡,然后直接下坡滑行。然后,她去掉了蜡,再次直接下坡滑行。她重复了这个过程四次,每次都交替使用薄蜡或不使用薄蜡滑行。她的朋友Tucker计时每次滑行的时间。Madelyn和Tucker计算了使用薄蜡滑行和不使用薄蜡滑行时直接下坡所需的平均时间。\n图:滑雪板下坡。\n麦德琳和塔克的实验能最好回答哪个问题?\nA. 当麦德琳的雪板上有一层薄蜡或一层厚蜡时,它是否能在较短的时间内滑下山坡?\nB. 当麦德琳的雪板上有一层蜡或没有蜡时,它是否能在较短
的时间内滑下山坡?\n请直接回答选项字母。 ASSISTANT: B"
llava:B
amyeroberts commented 6 months ago

cc @younesbelkada

zucchini-nlp commented 6 months ago

@bleedingfight hey!

Yes, transformers currently returns prompt + generated text as the output of generate for generative models. If you want only the generated part, you can:

out = model.generate(**inputs)
out_wo_prompt = out[:, inputs.input_ids.shape[-1]:]
print(tokenizer.batch_decode(out_wo_prompt, skip_special_tokens=True))
bleedingfight commented 6 months ago

@zucchini-nlp ok, thanks