bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation
Other
3.17k stars 328 forks source link

llama inference test #515

Open HandH1998 opened 1 year ago

HandH1998 commented 1 year ago

I built LightSeq with CUDA 11.4 successfully. Then I ran a LLaMA-13B inference test on an A100-80G with max_step=1024. When max_batch_size < 11, it works fine. The problem is that when I set max_batch_size >= 11, I get `lightseq/csrc/ops_new/sampling.cc.cu(73): an illegal memory access was encountered`. I also used CUDA_LAUNCH_BLOCKING=1 to locate the problem, which reports `lightseq/csrc/ops_new/sampling.cc.cu(57): an illegal memory access was encountered`. Memory usage is only about 40 GB, so it is not an OOM problem. My inference test script is below. Please help me with this problem.

import time
import argparse
import numpy as np
import torch
import lightseq.inference as lsi
from transformers import LlamaTokenizer, LlamaForCausalLM

def ls_llama(model, inputs):
    """Run ``model.infer(inputs)`` and measure its wall-clock duration.

    When CUDA is available the device is synchronized right before the clock
    starts and right after ``infer`` returns, so queued GPU work is included
    in the measured interval. Without CUDA the synchronization is skipped
    instead of raising, which lets the helper time CPU-only callables too.

    Args:
        model: any object exposing an ``infer(inputs)`` method.
        inputs: passed through unchanged to ``model.infer``.

    Returns:
        tuple: ``(results, elapsed_seconds)``, where ``results`` is whatever
        ``model.infer`` returned and ``elapsed_seconds`` is a float.
    """
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        # No device to wait on; keep the timing path identical otherwise.
        def sync():
            return None

    sync()
    start_time = time.perf_counter()
    results = model.infer(inputs)
    sync()
    return results, time.perf_counter() - start_time

def ls_generate(model, tokenizer, inputs):
    """Generate with LightSeq, print the decoded outputs, then print token/speed stats.

    ``inputs`` is read through ``.shape`` as ``[batch, seq_len]``. The raw
    output of ``model.infer`` has its axis 1 squeezed away before decoding
    (presumably a beam/sample dimension of size 1 -- TODO confirm).
    """
    print("=========lightseq=========")
    print("lightseq generating...")
    raw_ids, elapsed = ls_llama(model, inputs)

    out_ids = np.squeeze(raw_ids, axis=1)
    # Special tokens are kept in the decoded text; pass
    # skip_special_tokens=True to batch_decode to strip them instead.
    decoded = tokenizer.batch_decode(out_ids)
    print("lightseq results:")
    for text in decoded:
        print(text)

    batch_size = inputs.shape[0]
    prompt_len = inputs.shape[1]
    prompt_tokens = prompt_len * batch_size

    print(f"input_seq_len: {prompt_len}")
    print(f"input_bsz: {batch_size}")
    print(f"input_total_tokens: {prompt_tokens}")

    # Output sequences include the prompt, so the prompt tokens are
    # subtracted to get the number of newly generated tokens.
    output_tokens = out_ids.size
    generated_tokens = output_tokens - prompt_tokens
    per_seq_len = [row.size for row in out_ids]

    print(f"output_total_tokens: {output_tokens}")
    print(f"output_seq_len: {per_seq_len}")
    print(f"gen_total_tokens: {generated_tokens}")
    print(f"lightseq time: {elapsed}s")
    print(f"gen_speed: {generated_tokens / elapsed} tokens/s")

def warmup(ls_tokenizer, ls_model, sentences):
    """Tokenize ``sentences`` as one padded PyTorch batch and run one generation pass."""
    encoded = ls_tokenizer(sentences, return_tensors="pt", padding=True)
    token_ids = encoded["input_ids"]
    ls_generate(ls_model, ls_tokenizer, token_ids)

def main():
    """Benchmark LightSeq LLaMA generation from the command line.

    Loads the tokenizer and LightSeq model and runs one warmup batch built
    from the hard-coded ``sentences``. It then generates again: once with the
    same sentences, or, with ``--user_input``, repeatedly with one prompt
    read from stdin per iteration. That loop only ends when ``input`` raises
    (e.g. EOF or Ctrl-C).

    Command-line options:
        --user_input      prompt interactively instead of reusing ``sentences``.
        --weight_path     weight file passed to ``lsi.Llama``.
        --tokenizer_path  directory passed to ``LlamaTokenizer.from_pretrained``.
        --max_batch_size  ``max_batch_size`` passed to ``lsi.Llama``; the
                          built-in warmup batch contains 11 sentences.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_input", action="store_true")
    # Previously the weight path was only a commented-out local, so the
    # ``lsi.Llama`` call raised NameError; it is now a required-by-default option.
    parser.add_argument(
        "--weight_path",
        default="/home/zy/lightseq/llama_13b.hdf5",
        help="LightSeq LLaMA weight file (HDF5)",
    )
    parser.add_argument(
        "--tokenizer_path",
        default="/home/zy/lightseq/llama/13b",
        help="directory containing the LLaMA tokenizer files",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=11,
        help="max_batch_size passed to lsi.Llama",
    )
    args = parser.parse_args()
    print("initializing gpt tokenizer...")
    ls_tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path)
    # NOTE(review): "[PAD]" is added as a new token, so its id presumably lies
    # outside the checkpoint's original vocabulary, and padded batches will
    # pass that id to the model. This may be related to the illegal memory
    # access reported for larger batches -- confirm.
    ls_tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    print("creating lightseq model...")
    ls_model = lsi.Llama(args.weight_path, max_batch_size=args.max_batch_size)

    # LightSeq GPT perplexity supports batched inference over inputs of
    # different lengths, but sampling does not.
    # NOTE(review): the sentences below do have different lengths, so the
    # tokenizer pads them into one batch for sampling -- confirm this is intended.
    sentences = [
        "Are you a pig?",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "Are you a pig?",
        "I love you, but you say that",
        "I love you, but you say that",
    ]
    print("====================START warmup====================")
    warmup(
        ls_tokenizer,
        ls_model,
        sentences,
    )
    print("====================END warmup====================")

    while True:
        if args.user_input:
            sentences = [input("input the masked sentence:\n")]

        print("tokenizing the sentences...")

        ls_inputs = ls_tokenizer(sentences, return_tensors="pt", padding=True)[
            "input_ids"
        ]
        ls_generate(ls_model, ls_tokenizer, ls_inputs)

        # Without --user_input, generate once and stop.
        if not args.user_input:
            break


if __name__ == "__main__":
    main()
ChristineSeven commented 1 year ago

Using your code, I got this error: `module 'lightseq.inference' has no attribute 'Llama'`. Could you tell me how you got around this? @HandH1998

HandH1998 commented 1 year ago

Using your code, I got this error: `module 'lightseq.inference' has no attribute 'Llama'`. Could you tell me how you got around this? @HandH1998

It seems that you didn't compile it correctly. (See the attached image.) Set the `use_new_arch` build option to `ON`.

ChristineSeven commented 1 year ago

@HandH1998 Thanks.