TigerResearch / TigerBot

TigerBot: A multi-language multi-task LLM
https://www.tigerbot.com
Apache License 2.0

Local API access: could embedding support and multi-GPU inference be added? #122

Closed · OpenHuShen closed this issue 1 year ago

OpenHuShen commented 1 year ago

1. For local API access, would you consider adding an embedding feature? It is needed for text-classification tasks (see the sketch after the code below for what such a route might look like).

2. Here is the change I made in ./apps/api.py to support multi-GPU inference; I am not sure whether it is correct:

import torch
import uvicorn
from accelerate.utils import get_balanced_memory
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# `app` is the FastAPI application defined earlier in apps/api.py (omitted here).

if __name__ == "__main__":

    model_path = "./hf_weights/TigerResearch---tigerbot-70b-chat"
    # model_path = "TigerResearch/tigerbot-13b-chat"
    model_max_length = 4096
    print(f"loading model: {model_path}...")
    # --- modification start ---
    device = torch.cuda.current_device()  # unused once device_map="auto" handles placement
    # device_map="auto" lets accelerate shard the fp16 weights across all visible GPUs
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
    # --- modification end ---
    # original single-device load:
    # model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
    max_memory = get_balanced_memory(model)  # leftover from the manual-dispatch path; unused with device_map="auto"
    # manual dispatch alternative, superseded by device_map="auto":
    # device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=[""])
    # print("Using the following device map for the model:", device_map)
    # model = dispatch_model(model, device_map=device_map, offload_buffers=True)
    generation_config = GenerationConfig.from_pretrained(model_path)
    generation_kwargs = generation_config.to_dict()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        model_max_length=model_max_length,
        padding_side="left",
        truncation_side="left",
        padding=True,
        truncation=True,
    )
    generation_kwargs["eos_token_id"] = tokenizer.eos_token_id
    generation_kwargs["pad_token_id"] = tokenizer.pad_token_id
    model.eval()

    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
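
For point 1, here is a minimal sketch of what a local embedding route could look like in apps/api.py, assuming the FastAPI `app` above. It mean-pools the model's last hidden states over non-padding tokens, which is one common way to derive sentence embeddings for classification. The route name and request schema are illustrative only, not an existing TigerBot endpoint.

from typing import List

from pydantic import BaseModel


class EmbedRequest(BaseModel):  # hypothetical request schema, for illustration only
    texts: List[str]


@app.post("/embedding")  # hypothetical route name; not an existing TigerBot endpoint
def embed(req: EmbedRequest):
    batch = tokenizer(
        req.texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model_max_length,
    ).to(model.device)
    with torch.no_grad():
        out = model(**batch, output_hidden_states=True)
    hidden = out.hidden_states[-1]                # (batch, seq_len, dim)
    mask = batch["attention_mask"].unsqueeze(-1)  # zero out padding positions
    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # mean-pool over real tokens
    return {"embeddings": emb.float().cpu().tolist()}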
chentigerye commented 1 year ago
  1. The API already provides an embedding feature; see: https://www.tigerbot.com/api-reference/embedding
  2. For multi-GPU inference, you can follow the infer code in the README: CUDA_VISIBLE_DEVICES=0,1,2,3,4 python infer.py --model_path ${MODEL_DIR}
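
For anyone wiring point 1 into a classifier, a minimal sketch of calling the hosted embedding API from Python follows. The endpoint URL, auth header, and payload field names here are assumptions based on common REST conventions; treat the API reference linked above as authoritative.

import requests

API_KEY = "your-tigerbot-api-key"  # placeholder; issued via the TigerBot console

# The endpoint URL, auth scheme, and field names below are assumptions; see
# https://www.tigerbot.com/api-reference/embedding for the actual schema.
resp = requests.post(
    "https://api.tigerbot.com/embedding",            # assumed URL, for illustration
    headers={"Authorization": f"Bearer {API_KEY}"},  # assumed auth scheme
    json={"input": ["a sentence to classify"]},      # assumed payload field
    timeout=30,
)
resp.raise_for_status()
print(resp.json())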