shidingz opened 1 month ago
I found some problems with the llama3 template: multi-turn conversations trigger `WARNING: tokenization mismatch`. In `def preprocess_llama_3(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:`, isn't this part of the code wrong?
```python
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
    if rou == "":
        break

    parts = rou.split(sep)
    if len(parts) != 2:
        print(f"WARNING: parts!=: {parts}")
        break
    parts[0] += sep

    # include <bos> for all rounds
    if has_image:
        round_len = len(tokenizer_image_token(rou, tokenizer)) - 1
        instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
    else:
        round_len = len(tokenizer(rou).input_ids) - 1
        instruction_len = len(tokenizer(parts[0]).input_ids) - 2

    # include <|eot_id|> for all rounds
    round_len += 1
    instruction_len += 1

    target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
    cur_len += round_len

target[cur_len:] = IGNORE_INDEX
```
The template does not add a bos token, so why set `cur_len = 1` and `target[:cur_len] = IGNORE_INDEX`? And in `round_len = len(tokenizer_image_token(rou, tokenizer)) - 1` and `instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2`, what are the `-1` and `-2` for, respectively? Based on my understanding of this template, it should be changed to something like this:
```python
cur_len = 0
# target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
    if rou == "":
        break

    parts = rou.split(sep)
    if len(parts) != 2:
        print(f"WARNING: parts!=: {parts}")
        break
    parts[0] += sep

    # not include <bos> for all rounds
    if has_image:
        round_len = len(tokenizer_image_token(rou, tokenizer))
        instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
    else:
        round_len = len(tokenizer(rou).input_ids)
        instruction_len = len(tokenizer(parts[0]).input_ids)

    # include <|eot_id|> for all rounds
    round_len += 1
    instruction_len += 1

    target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
    cur_len += round_len

target[cur_len:] = IGNORE_INDEX
```
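One way to settle which offsets are right is to check whether the tokenizer actually prepends `<|begin_of_text|>` when a single round string is encoded. Below is a minimal sketch, assuming a Hugging Face Llama-3 tokenizer is available and using a made-up round string; the exact separator and round format used in this repo may differ:

```python
from transformers import AutoTokenizer

# Assumption: path to a Llama-3 tokenizer; replace with whatever checkpoint you train from.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Hypothetical single round in the llama3 chat format (illustrative only).
sep = "<|start_header_id|>assistant<|end_header_id|>\n\n"
rou = (
    "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
    + sep
    + "Hi there!<|eot_id|>"
)

ids = tokenizer(rou).input_ids
print(tokenizer.convert_ids_to_tokens(ids[:3]))
# If the first token printed is <|begin_of_text|>, every per-round encoding carries
# an extra bos and the original "- 1" / "- 2" offsets are needed; if it is not,
# the unshifted counts in the proposed version above should line up with the labels.
```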
When I run the script `scripts/llama3/train/stage_2_full_v8b_672_hr_1536.sh`, I encounter this error: `WARNING: tokenization mismatch: 156 vs. 161. (ignored)`
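For reference, the mismatch warning in LLaVA-style preprocessing usually comes from a final length check after the masking loop, and printing the per-round counts next to it makes it easy to see where the bookkeeping drifts. A rough sketch of what that check looks like, using the same variable names as the snippets above (the exact code in this repo may differ):

```python
# Hypothetical final check placed after the masking loop shown above.
total_len = int(target.ne(tokenizer.pad_token_id).sum())  # non-padding tokens in the sample

if cur_len < tokenizer.model_max_length and cur_len != total_len:
    # cur_len is the sum of the per-round lengths; total_len is the real sequence length.
    # "156 vs. 161" means the per-round accounting ended up 5 tokens short.
    target[:] = IGNORE_INDEX  # mask the whole sample instead of training on bad labels
    print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")
```

A gap that scales with the number of rounds (5 here) is consistent with a fixed per-round off-by-one such as the bos offset discussed above.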