dvlab-research / MGM

Official repo for "Mini-Gemini: Mining the Potential of Multi-modality Vision Language Models"
Apache License 2.0
3.1k stars 277 forks

I get this error when finetuning llama3: WARNING: tokenization mismatch: 156 vs. 161. (ignored) #126

Open shidingz opened 1 month ago

shidingz commented 1 month ago

When I run the script scripts/llama3/train/stage_2_full_v8b_672_hr_1536.sh, I get this error: WARNING: tokenization mismatch: 156 vs. 161. (ignored)

shidingz commented 1 month ago

I found that the llama3 template has some problems: multi-turn conversations trigger WARNING: tokenization mismatch. Inside the function def preprocess_llama_3(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:, this part of the code seems incorrect.

Current code (includes a bos offset for all rounds):

    cur_len = 1
    target[:cur_len] = IGNORE_INDEX
    for i, rou in enumerate(re_rounds):
        if rou == "":
            break

        parts = rou.split(sep)
        if len(parts) != 2:
            print(f"WARNING: parts!=: {parts}")
            break
        parts[0] += sep

        # include <bos> for all rounds
        if has_image:
            round_len = len(tokenizer_image_token(rou, tokenizer)) - 1
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
        else:
            round_len = len(tokenizer(rou).input_ids) - 1
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

        # include <|eot_id|> for all rounds
        round_len += 1
        instruction_len += 1

        target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
        cur_len += round_len

    target[cur_len:] = IGNORE_INDEX

The template does not add a bos token, so why set cur_len = 1 and target[:cur_len] = IGNORE_INDEX? And in round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 and instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2, what are the -1 and -2 for? Based on my understanding of this template, it should be modified like this:

Proposed fix (no bos offset for any round):

    cur_len = 0
    # target[:cur_len] = IGNORE_INDEX
    for i, rou in enumerate(re_rounds):
        if rou == "":
            break

        parts = rou.split(sep)
        if len(parts) != 2:
            print(f"WARNING: parts!=: {parts}")
            break
        parts[0] += sep

        # not include <bos> for all rounds
        if has_image:
            round_len = len(tokenizer_image_token(rou, tokenizer))
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
        else:
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids)

        # include <|eot_id|> for all rounds
        round_len += 1
        instruction_len += 1

        target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
        cur_len += round_len

    target[cur_len:] = IGNORE_INDEX
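The off-by-one arithmetic can be illustrated with a toy sketch (a hypothetical whitespace tokenizer that, like the llama3 chat template, emits no bos token; not the real tokenizer). Subtracting 1 per round for a bos that is never there makes the masked length drift away from the true token count by one per round, which is the kind of gap the "tokenization mismatch: 156 vs. 161" warning reports:

```python
IGNORE_INDEX = -100

def toy_tokenize(text):
    # hypothetical tokenizer with NO <bos> token, mirroring the
    # llama3 chat template discussed above
    return text.split()

def mask_targets(rounds, bos_offset):
    """Return (cur_len, total_len).

    cur_len is where the masking loop ends up; total_len is the true
    token count of the full conversation. preprocess_llama_3 prints
    the 'tokenization mismatch' warning when the two disagree.
    bos_offset = 1 models the original code, 0 the proposed fix.
    """
    total_len = sum(len(toy_tokenize(rou)) for rou in rounds)
    cur_len = bos_offset              # cur_len = 1 in the original code
    for rou in rounds:
        round_len = len(toy_tokenize(rou)) - bos_offset
        cur_len += round_len          # drifts by bos_offset per round
    return cur_len, total_len

rounds = ["user: hi assistant: hello", "user: how assistant: fine"]
print(mask_targets(rounds, bos_offset=1))  # lengths disagree -> warning
print(mask_targets(rounds, bos_offset=0))  # lengths agree
```

With bos_offset=1 the sketch undercounts by one token per round (two rounds, off by two); with bos_offset=0 the counts line up, matching the proposed change above.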