QwenLM / CodeQwen1.5

CodeQwen1.5 is the code version of Qwen, the large language model series developed by the Qwen team, Alibaba Cloud.
371 stars, 22 forks

[fim] Extra space at the beginning #52

Closed: ytyt-yt closed this issue 2 months ago

ytyt-yt commented 2 months ago

In FIM (fill-in-the-middle) mode, the completion often seems to start with one extra leading space. For example:

from transformers import AutoTokenizer, AutoModelForCausalLM
# load model
device = "cuda" # the device to load the model onto

tokenizer = AutoTokenizer.from_pretrained("Qwen/CodeQwen1.5-7B")
model = AutoModelForCausalLM.from_pretrained("Qwen/CodeQwen1.5-7B", device_map="auto").eval()

input_text = """<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
<fim_suffix>
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>"""

model_inputs = tokenizer([input_text], return_tensors="pt").to(device)

# Use `max_new_tokens` to control the maximum output length.
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=False)[0]
# The generated_ids include prompt_ids, we only need to decode the tokens after prompt_ids.
output_text = tokenizer.decode(generated_ids[len(model_inputs.input_ids[0]):], skip_special_tokens=True)

print(f"Prompt: {input_text}\n\nGenerated text: `{output_text}`")
print(f"Code:\n{input_text.replace('<fim_suffix>', output_text)}")

Output:

Generated text: `     left = [x for x in arr if x < pivot]`
Code:
<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
     left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>
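The `print` call in the script only replaces `<fim_suffix>`, so the printed code still carries the `<fim_prefix>` and `<fim_middle>` markers. The following is a minimal sketch, assuming the `<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>` layout used in this issue. It builds the prompt from separate prefix/suffix strings and splices a completion back into clean source, which makes the extra leading space easy to spot. `build_fim_prompt` and `splice` are hypothetical helpers, not part of transformers or CodeQwen.

```python
FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE = "<fim_prefix>", "<fim_suffix>", "<fim_middle>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    # Same layout as the prompts in this issue: prefix, suffix, then the middle marker.
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


def splice(prefix: str, middle: str, suffix: str) -> str:
    # Reassemble the source without any FIM marker tokens.
    return prefix + middle + suffix


prefix = "    pivot = [len(arr) // 2]\n"
suffix = "\n    middle = [x for x in arr if x == pivot]"
prompt = build_fim_prompt(prefix, suffix)

# The completion reported above, including its extra leading space.
middle = "     left = [x for x in arr if x < pivot]"
print(splice(prefix, middle, suffix))
```

In the spliced output, the inserted `left` line is indented by five spaces while its neighbours use four.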

Input:

input_text = """<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    left =<fim_suffix>
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>"""

Output:

Generated text: `  [x for x in arr if x < pivot]`
Code:
<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    left =  [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>

Input:

input_text = """<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    left = <fim_suffix>
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>"""

Output:

Generated text: ` [x for x in arr if x < pivot]`
Code:
<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    left =  [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>
cyente commented 2 months ago

We suggest formatting the indentation spaces of line-level autocomplete suggestions.
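The following is a minimal sketch of that suggestion. It assumes the cursor sits at the start of an empty line, as in the first example where the prefix ends with a newline. It also assumes the next non-empty line of the suffix shows the expected indentation. `normalize_indent` is a hypothetical client-side helper, not part of CodeQwen or transformers.

```python
def leading_spaces(line: str) -> int:
    # Number of space characters before the first non-space character.
    return len(line) - len(line.lstrip(" "))


def normalize_indent(prefix: str, middle: str, suffix: str) -> str:
    """Re-indent the first line of a line-level suggestion so it matches the
    next non-empty suffix line. Only applies when the cursor is at column 0."""
    if not prefix.endswith("\n"):
        return middle
    next_lines = [line for line in suffix.splitlines() if line.strip()]
    if not next_lines:
        return middle
    target = leading_spaces(next_lines[0])
    first, sep, rest = middle.partition("\n")
    return " " * target + first.lstrip(" ") + sep + rest


fixed = normalize_indent(
    "    pivot = [len(arr) // 2]\n",
    "     left = [x for x in arr if x < pivot]",
    "\n    middle = [x for x in arr if x == pivot]",
)
print(repr(fixed))
```

This only touches the first line of the suggestion. As the following comment points out, the extra space also appears mid-line, where indentation is not the issue.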

yuhon0528 commented 2 months ago

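I ran into the same problem, and it is not only about indentation spaces; please look closely at @ytyt-yt's examples. Focus on the left =<fim_suffix> and left = <fim_suffix> examples: in both cases the model's output leaves two spaces, as in left =  [x for x in arr if x < pivot]. This makes autocomplete annoying, because the user has to go back and delete the extra space by hand every time. Is this behavior a result of the model's pretraining? Is there a way to fix it?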
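Until the cause is known, one client-side stopgap is to trim the spaces where the prefix and the completion meet. This is a minimal sketch. It assumes that a suggestion should never produce more than one space between the prefix and the completion's first non-space character. That heuristic fits both `left =` cases in this thread. It does not help when the space lands inside an identifier, and it would need extra care inside string literals. `trim_join_spaces` is a hypothetical helper.

```python
def trim_join_spaces(prefix: str, middle: str) -> str:
    """Collapse the whitespace at the prefix/completion boundary to at most one space."""
    if not middle.startswith(" ") or prefix == "" or prefix.endswith("\n"):
        # Nothing to trim, or the cursor is at the start of a line
        # (leading spaces are indentation there and are handled separately).
        return middle
    body = middle.lstrip(" ")
    if prefix.endswith((" ", "\t")):
        # The prefix already supplies the separating whitespace.
        return body
    return " " + body


print("    left =" + trim_join_spaces("    left =", "  [x for x in arr if x < pivot]"))
print("    left = " + trim_join_spaces("    left = ", " [x for x in arr if x < pivot]"))
```

Both calls yield `    left = [x for x in arr if x < pivot]`.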

ytyt-yt commented 2 months ago

Another example:

Prompt: <fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    le<fim_suffix>
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>

Generated text: ` ft = [x for x in arr if x < pivot]`
Code:
<fim_prefix>def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = [len(arr) // 2]
    le ft = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)<fim_middle>

The completion of le|ft is split into "le ft": the model inserts a space inside the identifier.
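One possible workaround for a partially typed identifier is a simplified variant of "token healing." The idea is to drop the trailing partial word from the prefix before building the prompt and let the model regenerate the whole word. The regenerated fragment is then stripped from the completion. The sketch below covers only the string bookkeeping; the model call is omitted. The helper names are hypothetical. This thread has not verified whether the approach removes the extra space for CodeQwen1.5.

```python
import re
from typing import Optional, Tuple

# Trailing identifier fragment at the end of the prefix, e.g. "le" in "    le".
_PARTIAL_WORD = re.compile(r"[A-Za-z_][A-Za-z0-9_]*$")


def split_partial_word(prefix: str) -> Tuple[str, str]:
    """Split the prefix into (shortened_prefix, fragment)."""
    m = _PARTIAL_WORD.search(prefix)
    if m is None:
        return prefix, ""
    return prefix[: m.start()], m.group(0)


def heal_completion(fragment: str, completion: str) -> Optional[str]:
    """Map a completion generated for the shortened prefix back onto the
    ORIGINAL prefix. Returns None if the model did not regenerate the
    fragment, so the caller can fall back to another strategy."""
    if not fragment:
        return completion
    body = completion.lstrip(" ")  # tolerate the spurious leading space
    if body.startswith(fragment):
        return body[len(fragment):]
    return None


short_prefix, fragment = split_partial_word("    le")
# Suppose the model, prompted with short_prefix, returns " left = [...]".
healed = heal_completion(fragment, " left = [x for x in arr if x < pivot]")
print("    le" + healed)
```

The prompt is then built from `short_prefix` instead of the original prefix, and `healed` is inserted after the original prefix, giving `    left = [x for x in arr if x < pivot]`.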