mymusise / ChatGLM-Tuning

基于ChatGLM-6B + LoRA的Fintune方案
MIT License
3.73k stars 440 forks source link

如何像调用原本的ChatGLM一样生成对话 #133

Closed Ling-yunchi closed 1 year ago

Ling-yunchi commented 1 year ago

我看好像使用微调模型之后model变成了PeftModel,和原来的model类型不同。 我目前是使用这样的方法来将微调后的模型对外进行服务,但这样好像并不能加载历史数据。 还有一个疑问是为什么output需要替换掉原本的问题,而且当问题中包含中文标点时替换会无效,因为回答中的提问会变成英文标点符号。

@app.post("/lora")
async def lora(req: Request):
    global model, tokenizer
    json_post_raw = await req.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    input_text = json_post_list.get('prompt')
    ids = tokenizer.encode(input_text)
    input_ids = torch.cuda.LongTensor([ids])
    out = model.generate(
        input_ids=input_ids,
        max_length=2048,
        do_sample=False,
        temperature=0
    )
    out_text = tokenizer.decode(out[0])
    answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
    # torch_gc()
    return {
        "response": answer
    }
cywjava commented 1 year ago

model.generate 方法,换成chat

cywjava commented 1 year ago

不要直接在generate上包装api服务 ,你会后悔的,如果同一时间高并发进来,你的显存会爆。。

suc16 commented 1 year ago

我看好像使用微调模型之后model变成了PeftModel,和原来的model类型不同。 我目前是使用这样的方法来将微调后的模型对外进行服务,但这样好像并不能加载历史数据。 还有一个疑问是为什么output需要替换掉原本的问题,而且当问题中包含中文标点时替换会无效,因为回答中的提问会变成英文标点符号。

@app.post("/lora")
async def lora(req: Request):
    global model, tokenizer
    json_post_raw = await req.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    input_text = json_post_list.get('prompt')
    ids = tokenizer.encode(input_text)
    input_ids = torch.cuda.LongTensor([ids])
    out = model.generate(
        input_ids=input_ids,
        max_length=2048,
        do_sample=False,
        temperature=0
    )
    out_text = tokenizer.decode(out[0])
    answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
    # torch_gc()
    return {
        "response": answer
    }

问题一,你的input_text里应该包含history,需要自己去加一下,可以参考下官方库里的cli_demo.py。 问题二,是因为huggingface的generate方法,会在output里包含input,所以需要replace。也可以用out_text[len(input_text):]来实现只取answer。

FrankWhh commented 1 year ago

不要直接在generate上包装api服务 ,你会后悔的,如果同一时间高并发进来,你的显存会爆。。

所以要怎么封装api

suc16 commented 1 year ago

不要直接在generate上包装api服务 ,你会后悔的,如果同一时间高并发进来,你的显存会爆。。

所以要怎么封装api

api官方已经封好了都,不要调generate,直接调chat或者stream_chat

Ling-yunchi commented 1 year ago

不要直接在generate上包装api服务 ,你会后悔的,如果同一时间高并发进来,你的显存会爆。。

所以要怎么封装api

api官方已经封好了都,不要调generate,直接调chat或者stream_chat

我看了PeftModel的源码,并没有找到chat方法,请问有文档或者具体的调用方法或代码吗? 还是说是在这样之后

model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)

peft_model = get_peft_model(model, peft_config)
peft_model.load_state_dict(torch.load(peft_path), strict=False)

调用原来model的chat方法?

suc16 commented 1 year ago

不要直接在generate上包装api服务 ,你会后悔的,如果同一时间高并发进来,你的显存会爆。。

所以要怎么封装api

api官方已经封好了都,不要调generate,直接调chat或者stream_chat

我看了PeftModel的源码,并没有找到chat方法,请问有文档或者具体的调用方法或代码吗? 还是说是在这样之后

model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)

peft_model = get_peft_model(model, peft_config)
peft_model.load_state_dict(torch.load(peft_path), strict=False)

调用原来model的chat方法?

对的,直接调原模型的chat和stream_chat就行,参考 ChatGLM官方实现 1237行