mymusise / ChatGLM-Tuning

A finetuning solution based on ChatGLM-6B + LoRA
MIT License

Could anyone provide an api.py, similar to https://github.com/THUDM/ChatGLM-6B/blob/main/api.py? #168

Open cristianohello opened 1 year ago

cristianohello commented 1 year ago

Is there a serving script for the LoRA-finetuned model???

xx-zhang commented 1 year ago

The model and tokenizer are loaded the same way; just tweak it by hand. The only difference is that an extra .pt (adapter) file gets loaded. You can also copy the contents of that script and adjust it.

Ling-yunchi commented 1 year ago
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn
import json
import datetime
import torch
from peft import get_peft_model, LoraConfig, TaskType

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()

@app.post("/")
async def chat(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
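    # NOTE: with do_sample=False this is greedy decoding, so top_p and temperature are effectively ignored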
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95,
                                   do_sample=False)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + \
        prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer

if __name__ == '__main__':
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

    peft_path = "output/you/train/model/adapter_model.bin"
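    # NOTE: r and lora_alpha in the LoraConfig below should match the values used when the
    # adapter was trained, otherwise the weights in adapter_model.bin will not load correctly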

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=True,
        r=8,
        lora_alpha=32, lora_dropout=0.1
    )

    model = get_peft_model(model, peft_config)
    model.load_state_dict(torch.load(peft_path), strict=False)

    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
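
To call this endpoint once the server is up, here is a minimal client sketch (the port and the example prompt just match the defaults above):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/",
    json={"prompt": "你好", "history": []},  # max_length / top_p / temperature are optional
)
print(resp.json()["response"])
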
suc16 commented 1 year ago
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn
import json
import datetime
import torch
from peft import get_peft_model, LoraConfig, TaskType

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()

@app.post("/")
async def chat(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95,
                                   do_sample=False)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
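    # NOTE: `engine` is not defined in this snippet; it is assumed to come from suc16's own
    # post-processing setup, and this line raises NameError without it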
    response = engine.process(response)
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + \
        prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer

if __name__ == '__main__':
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

    peft_path = "output/you/train/model/adapter_model.bin"

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=True,
        r=8,
        lora_alpha=32, lora_dropout=0.1
    )

    model = get_peft_model(model, peft_config)
    model.load_state_dict(torch.load(peft_path), strict=False)

    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

Amazing!

cristianohello commented 1 year ago

@Ling-yunchi

You're amazing! It would be even better with a Gradio version.

suc16 commented 1 year ago

@Ling-yunchi

You're amazing! It would be even better with a Gradio version.

For a frontend/backend split like this, take a look at FastChat.

cristianohello commented 1 year ago

@suc16 fastchat???

suc16 commented 1 year ago

@suc16 fastchat???

I mean take FastChat as a reference. Its web server is a pretty good example of a frontend/backend split: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/gradio_web_server.py

mymusise commented 1 year ago

This repo seems to have integrated ChatGLM as well, worth a look: https://github.com/oobabooga/text-generation-webui

suc16 commented 1 year ago

This repo seems to have integrated ChatGLM as well, worth a look: https://github.com/oobabooga/text-generation-webui

This repo is indeed easier to use as a reference, and it also implements a stream_generate API; adapting FastChat would take quite a bit more work.
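
If all you need from that repo is streaming, ChatGLM-6B itself also exposes model.stream_chat, which can be bolted onto the FastAPI script above. A rough sketch, assuming the same model / tokenizer / app / torch_gc globals as in that script (the /stream route name is just an example):

from fastapi.responses import StreamingResponse

@app.post("/stream")
async def stream(request: Request):
    data = await request.json()
    prompt = data.get("prompt")
    history = data.get("history") or []

    def generate():
        # stream_chat is a generator yielding (partial_response, history) pairs;
        # each yielded response is the full text so far, not a delta
        for response, _ in model.stream_chat(tokenizer, prompt, history=history):
            yield response
        torch_gc()

    return StreamingResponse(generate(), media_type="text/plain")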

cristianohello commented 1 year ago

The browser can't open http://127.0.0.1:800, so how do I call it?

INFO: Started server process [882]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8004 (Press CTRL+C to quit)
The dtype of attention mask (torch.int64) is not bool
INFO: 127.0.0.1:42284 - "POST / HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
    ...
    response = engine.process(response)
NameError: name 'engine' is not defined

cristianohello commented 1 year ago

@suc16 It throws an error. Also, the browser can't open http://127.0.0.1:800, so how do I call it?

INFO: Started server process [882]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8004/ (Press CTRL+C to quit)
The dtype of attention mask (torch.int64) is not bool
INFO: 127.0.0.1:42284 - "POST / HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
    ...
    response = engine.process(response)
NameError: name 'engine' is not defined

cristianohello commented 1 year ago

@mymusise Your version uses model.generate, his uses response, history = model.chat. I call it with:

curl -X POST "http://127.0.0.1:8000" \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "你好", "history": []}'

When deployed, it doesn't actually use the LoRA-finetuned model; it still serves the original official model.

cristianohello commented 1 year ago

@Ling-yunchi Your version uses model.generate, his uses response, history = model.chat. I call it with:

curl -X POST "http://127.0.0.1:8000/" \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "你好", "history": []}'

When deployed, it doesn't actually use the LoRA-finetuned model; it still serves the original official model.

Also:

INFO: Started server process [882]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8004/ (Press CTRL+C to quit)
The dtype of attention mask (torch.int64) is not bool
INFO: 127.0.0.1:42284 - "POST / HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
    ...
    response = engine.process(response)
NameError: name 'engine' is not defined

Ling-yunchi commented 1 year ago

@Ling-yunchi Your version uses model.generate, his uses response, history = model.chat. I call it with:

curl -X POST "http://127.0.0.1:8000/" \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "你好", "history": []}'

When deployed, it doesn't actually use the LoRA-finetuned model; it still serves the original official model.

Also:

INFO: Started server process [882]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8004/ (Press CTRL+C to quit)
The dtype of attention mask (torch.int64) is not bool
INFO: 127.0.0.1:42284 - "POST / HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
    ...
    response = engine.process(response)
NameError: name 'engine' is not defined

Just delete the line response = engine.process(response).

cristianohello commented 1 year ago

@Ling-yunchi

Is this how it should be called? Called this way, it doesn't use my own LoRA-finetuned model:

curl -X POST "http://127.0.0.1:8000" \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "你好", "history": []}'

Is that right?

cristianohello commented 1 year ago

@Ling-yunchi

My inference code is as follows:

from transformers import AutoModel, AutoTokenizer
import torch
from peft import PeftModel

model = AutoModel.from_pretrained("../chatglm_models", trust_remote_code=True, load_in_8bit=True, device_map='auto')

tokenizer = AutoTokenizer.from_pretrained("../chatglm_models", trust_remote_code=True)

model = PeftModel.from_pretrained(model, "./output/")

import json

instructions = json.load(open("data/alpaca_data.json", encoding="utf-8"), strict=False)

answers = []
from cover_alpaca2jsonl import format_example

with torch.no_grad():
    for idx, item in enumerate(instructions[50:60]):
        feature = format_example(item)
        input_text = feature['context']
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids])
        out = model.generate(
            input_ids=input_ids,
            max_length=768,
            do_sample=False,
            temperature=0
        )
        out_text = tokenizer.decode(out[0])
        answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
        item['infer_answer'] = answer
        print(out_text)
        print(f"### {idx+1}.Answer:\n", item.get('output'), '\n\n')
        answers.append({'index': idx, **item})

Compared with this, the API deployment never loads the LoRA model; it always loads the official one.
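
If the goal is to have the API script load the same adapter as this inference code, one option is to replace its get_peft_model + load_state_dict block with PeftModel.from_pretrained pointed at the same ./output/ directory, then check that LoRA weights are actually attached. A sketch, assuming ./output/ is the directory the adapter was saved to:

from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# attach the finetuned adapter the same way the inference script does
model = PeftModel.from_pretrained(model, "./output/")
model.eval()

# sanity check: the wrapped model should now contain LoRA parameter tensors
lora_params = [name for name, _ in model.named_parameters() if "lora_" in name]
print(f"LoRA tensors attached: {len(lora_params)}")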