WisdomShell / codeshell

A series of code large language models developed by PKU-KCL
http://se.pku.edu.cn/kcl

CodeShell gives wrong results with batch size > 1; how can batched inference be supported? #63

Open MeJerry215 opened 9 months ago

MeJerry215 commented 9 months ago

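The examples only show inference with a single prompt (batch size 1). How can I run inference on a batch of several prompts? Right now the results with batch size > 1 do not look correct.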

from transformers import AutoModelForCausalLM, AutoTokenizer
import pdb
import torch
tokenizer = AutoTokenizer.from_pretrained("CodeShell-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("CodeShell-7B", torch_dtype=torch.float16, trust_remote_code=True).cuda()
examples = [
        "import math\ndef print_hello():",
        "import math\ndef quick_sort():",
        "import math\ndef test_quick_sort():",
        "import math\ndef test_print_hello():",
        "import math\ndef test_merge_sort():",
        "import math\ndef two_sum():",
        "import math\ndef preoder_transverse():",
        "import math\ndef merge_sort():",
    ]

inputs = tokenizer(examples, return_tensors='pt', padding=True)['input_ids'].cuda()
outputs = model.generate(inputs, max_new_tokens=128)
for output in outputs:
    print("=====================> ",tokenizer.decode(output))

The test code is shown above.
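One thing that may be worth checking in this snippet: the prompts are padded to a common length, but only `input_ids` is passed to `generate`, so the attention mask returned by the tokenizer is dropped. If the tokenizer pads on the right (an assumption, not confirmed here), every shorter prompt ends in pad tokens and generation continues from those pads rather than from the prompt text. The pure-Python sketch below only illustrates the difference between right and left padding; `pad_batch`, `PAD_ID`, and the token ids are hypothetical and are not part of CodeShell or transformers.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    sequences: Sequence[Sequence[int]], pad_id: int, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id lists to a common length and build the matching attention mask.

    Hypothetical helper for illustration only; a real tokenizer does this internally.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    if not sequences:
        return [], []

    max_len = max(len(seq) for seq in sequences)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pads, zeros = [pad_id] * n_pad, [0] * n_pad
        ones = [1] * len(seq)
        if side == "left":
            # Real tokens stay at the end, so generation continues from the prompt.
            input_ids.append(pads + list(seq))
            attention_mask.append(zeros + ones)
        else:
            # Pad tokens sit at the end, so generation would continue from padding.
            input_ids.append(list(seq) + pads)
            attention_mask.append(ones + zeros)
    return input_ids, attention_mask


# Made-up token ids standing in for two prompts of different length.
PAD_ID = 0
batch = [[11, 12, 13, 14, 15], [21, 22]]

right_ids, right_mask = pad_batch(batch, PAD_ID, side="right")
left_ids, left_mask = pad_batch(batch, PAD_ID, side="left")

print("right:", right_ids, right_mask)
print("left: ", left_ids, left_mask)
```

With Hugging Face transformers, the analogous settings would be `tokenizer.padding_side = "left"` before tokenizing, and passing the tokenizer's `attention_mask` to `model.generate` together with `input_ids`. Whether that fully resolves the CodeShell output shown in this issue is not confirmed in this thread.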

The results look a bit strange:

=====================>  import math
def print_hello():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix>  }
    }
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace _04.Longest_Increasing_Subsequence
{
    class Program
    {
        static void Main(string[] args)
        {
            int[] nums = Console.ReadLine().Split(' ').Select(int.Parse).ToArray();
            int[] len = new int[nums.Length];
            int[] prev = new int[nums.Length];
            int maxLen = 0;
=====================>  import math
def quick_sort():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix>  }
    }
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;