FreedomIntelligence / HuatuoGPT-II

HuatuoGPT2, One-stage Training for Medical Adaption of LLMs. (An Open Medical GPT)

Does this not carry conversation context? #20

Open fei1025 opened 10 months ago

fei1025 commented 10 months ago

The code I'm currently using:

import json
from threading import Thread

from fastapi import Request
from sse_starlette.sse import EventSourceResponse
from transformers import TextIteratorStreamer

# `Msg`, `tokenizer`, `model`, and `sep` (the model's stop string) are
# defined elsewhere in this application.
async def flush_stream(msg: Msg, request: Request):
    print(msg.question)
    print(msg.history)
    if msg.question:
        async def event_generator(request: Request):
            query = msg.question
            history = msg.history
            prompt = generate_prompt(query, history)
            print(f"prompt: {prompt}")
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = inputs.to(model.device)

            # Run generation in a background thread and stream tokens back
            # through the TextIteratorStreamer as they are produced.
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
            gen_kwargs = {'max_new_tokens': 1024, 'do_sample': True, 'top_p': 0.7,
                          'temperature': 0.3, 'repetition_penalty': 1.1}
            generation_kwargs = dict(input_ids=inputs['input_ids'], streamer=streamer, **gen_kwargs)

            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ''
            i = 0
            for new_text in streamer:
                print(new_text)
                if await request.is_disconnected():
                    print("client disconnected")
                    break
                # Stop string reached: emit only the text before it,
                # de-duplicating any overlap with what was already sent.
                if sep in new_text:
                    new_text = remove_overlap(generated_text, new_text[:-len(sep)])
                    for char in new_text:
                        generated_text += char
                        i += 1
                        data = json.dumps({"id": i, "message": char}, ensure_ascii=False)
                        yield data
                    break
                for char in new_text:
                    generated_text += char
                    i += 1
                    data = json.dumps({"id": i, "message": char}, ensure_ascii=False)
                    yield data
        return EventSourceResponse(event_generator(request))

def generate_prompt(query, history):
    # Render the conversation in HuatuoGPT's "<问>:...\n<答>:..." transcript
    # format; prior turns from `history` are what give the model its context.
    if not history:
        return f"<问>:{query}\n<答>:"
    prompt = ''
    for old_query, response in history:
        prompt += "<问>:{}\n<答>:{}\n".format(old_query, response)
    prompt += "<问>:{}\n<答>:".format(query)
    return prompt

def remove_overlap(str1, str2):
    # Strip from str2 the longest prefix that is already a suffix of str1,
    # so text the streamer re-emits around the stop string is not duplicated.
    for i in range(len(str1), -1, -1):
        if str1.endswith(str2[:i]):
            return str2[i:]
    return str2
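
For reference, a minimal illustration of what generate_prompt produces when history is non-empty (the example turns are made up):

history = [["What causes a cold?", "Usually a viral infection; rest and fluids help."]]
print(generate_prompt("Do I need medication?", history))
# Output:
# <问>:What causes a cold?
# <答>:Usually a viral infection; rest and fluids help.
# <问>:Do I need medication?
# <答>: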

But written this way, each reply has no connection to the earlier conversation. Is there an API for this, or is this feature currently unsupported?

jymChen commented 8 months ago

@fei1025 Hi, it looks like your code never updates history. Consider adding msg.history.append([query, generated_text])
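
A minimal sketch of where that line would go, at the end of event_generator once the streamer loop has drained (so generated_text holds the complete reply):

            generated_text = ''
            for new_text in streamer:
                # ... yield SSE chunks char by char, as in the original code ...
                generated_text += new_text
            # The missing step: persist the finished turn so the next call to
            # generate_prompt(query, msg.history) includes this exchange.
            msg.history.append([query, generated_text])

Note this only restores context if msg.history actually persists between requests, e.g. the client echoes the updated history back with each new question, or the server keeps it per session.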