baichuan-inc / Baichuan-7B

A large-scale 7B pretrained language model developed by BaiChuan-Inc.
https://huggingface.co/baichuan-inc/baichuan-7B
Apache License 2.0

A fix for running out of 24 GB of GPU memory #44

Open cywjava opened 1 year ago

cywjava commented 1 year ago

Testing with the official code:

(python3.8) [baichuan@localhost baichuan-7B]$ python3 generate.py
The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function.
登鹳雀楼->王之涣
夜雨寄北->李商隐
过零丁洋->文天祥
己亥杂诗(其五)->龚自珍
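For a rough sense of why a 24 GB card is tight, the memory for the weights alone can be estimated as the parameter count times the bytes per parameter. This is a minimal sketch under stated assumptions: it assumes about 7e9 parameters (the exact Baichuan-7B count differs slightly), it ignores activations and the KV cache, and the helper `weight_memory_gib` is hypothetical. The thread does not confirm that the official script loads the weights in float32; that is an assumption here.

```python
def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed just to store the weights, in GiB (hypothetical helper)."""
    return num_params * bytes_per_param / 1024**3


# Assumption: roughly 7e9 parameters; the real model's count differs slightly.
NUM_PARAMS = 7e9

fp32 = weight_memory_gib(NUM_PARAMS, 4)  # torch.float32 uses 4 bytes per parameter
fp16 = weight_memory_gib(NUM_PARAMS, 2)  # torch.float16 uses 2 bytes per parameter

print(f"float32 weights: {fp32:.1f} GiB")
print(f"float16 weights: {fp16:.1f} GiB")
```

Under these assumptions it prints about 26.1 GiB for float32 and 13.0 GiB for float16, so loading with `torch_dtype=torch.float16` plausibly leaves headroom on a 24 GB GPU where float32 weights would not fit.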

cywjava commented 1 year ago

Sample code that works around the out-of-memory problem on a 24 GB GPU:

import os

# Restrict this process to GPU 2; this must be set before CUDA is initialized
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

PRE_TRAINED_MODEL_PATH = "../model/"


# Program entry point
def main():
    tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_PATH, trust_remote_code=True)
    # Fall back to id 0 (the <unk> token) when the tokenizer defines no pad token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    # Workaround for the Baichuan tokenizer: replace pad id 64000 with 0 (needs a proper fix)
    if tokenizer.pad_token_id == 64000:
        tokenizer.pad_token_id = 0

    config = AutoConfig.from_pretrained(PRE_TRAINED_MODEL_PATH, trust_remote_code=True)
    # Load the weights in float16 (half the memory of float32) and let
    # device_map="auto" place them on the available GPU(s)
    model = AutoModelForCausalLM.from_pretrained(
        PRE_TRAINED_MODEL_PATH,
        config=config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    with torch.autocast("cuda"):
        while True:
            try:
                input_txt = input("user: ")
                inputs = tokenizer(input_txt, return_tensors="pt").to("cuda:0")
                response = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
                response = tokenizer.decode(response.cpu()[0], skip_special_tokens=True)
                print("bot:", response)
                # Release cached memory blocks between turns
                torch.cuda.empty_cache()
            except Exception as e:
                print(e)
                break


if __name__ == "__main__":
    main()
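The pad-token handling at the start of the sample can be isolated as a small pure function so it can be checked without loading the model. This is a sketch under assumptions: `resolve_pad_token_id` is a hypothetical helper, it assumes a vocabulary of 64000 ids (valid ids 0 to 63999), and it generalizes the sample's `== 64000` check to any id at or beyond the vocabulary size.

```python
from typing import Optional

# Assumption: Baichuan-7B has a 64000-entry vocabulary, so valid ids are 0..63999.
VOCAB_SIZE = 64000


def resolve_pad_token_id(pad_token_id: Optional[int], vocab_size: int = VOCAB_SIZE) -> int:
    """Hypothetical helper mirroring the sample: use id 0 when the pad id is missing or out of range."""
    if pad_token_id is None or pad_token_id >= vocab_size:
        return 0  # the sample's comment treats id 0 as the <unk> token
    return pad_token_id
```

In the sample this would correspond to `tokenizer.pad_token_id = resolve_pad_token_id(tokenizer.pad_token_id)`; the range check avoids hard-coding 64000 in a second place.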
honjiaxuan commented 1 year ago

Hi, the workaround above no longer works for me. Is there a new approach?