pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License
5.34k stars 484 forks source link

It doesn't accelerate very well on L4 #185

Open songh11 opened 1 week ago

songh11 commented 1 week ago

I'm glad torch.compile speeds things up so much: on an A5000 it gives about a 60% speed-up, but on an L4 there is no acceleration at all. Why does this happen? Here is my code; pass --compile when running it:

import time
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
from transformers import set_seed
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def print_separater():
    """Print a horizontal rule of twenty '=' characters followed by a blank line.

    The name keeps its historical spelling because callers use it as-is.
    """
    rule = "=" * 20
    # print() joins its arguments with a space, so the output is
    # "<rule> \n" followed by print's own trailing newline.
    print(rule, "\n")

def get_model_and_tokenizer(model_path, device, dtype):
    """Load a causal-LM and its tokenizer from ``model_path``.

    Args:
        model_path: Hugging Face model id or local path, passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Forwarded to the model loader as ``device_map``.
        dtype: Forwarded to the model loader as ``torch_dtype``.

    Returns:
        ``(model, tokenizer)``. The tokenizer is also attached to the model as
        ``model.tokenizer`` so helpers that receive only the model can decode.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    load_options = {"torch_dtype": dtype, "device_map": device}
    lm = AutoModelForCausalLM.from_pretrained(model_path, **load_options)
    # Stash the tokenizer on the model for benchmark_throughput's decode step.
    lm.tokenizer = tok
    return lm, tok

def benchmark_throughput(model, model_inputs, args):
    """Time one ``model.generate`` call and return the generation throughput.

    Args:
        model: A loaded causal-LM with a ``tokenizer`` attribute attached
            (see ``get_model_and_tokenizer``).
        model_inputs: Tokenizer output (must contain ``input_ids``) already on
            the model's device.
        args: Parsed CLI namespace. Reads ``seed``, ``max_new_tokens``,
            ``do_sample``, ``top_k`` and ``temperature``.

    Returns:
        Newly generated tokens per second, as a float.

    Side effects:
        Re-seeds RNGs and prints the decoded output, token count and elapsed time.
    """
    # ``model.device`` is a ``torch.device`` (e.g. ``cuda:0``). A torch.device
    # never compares equal to the plain string "cuda", so the old
    # ``device == "cuda"`` check skipped synchronization entirely. Compare the
    # device *type* instead.
    is_cuda = torch.device(model.device).type == "cuda"
    set_seed(args.seed)

    # Flush queued GPU work so it is not billed to this measurement.
    if is_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, which suits interval timing.
    start = time.perf_counter()
    # NOTE(review): output_scores=True keeps per-step scores that are never read
    # here. It may add overhead to the timed region; consider dropping it.
    generated = model.generate(
        **model_inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        top_k=args.top_k,
        temperature=args.temperature,
        output_scores=True,
        return_dict_in_generate=True,
        use_cache=True,
    ).sequences
    # Wait for all kernels launched by generate() before stopping the clock.
    if is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # The returned sequences include the prompt, so subtract it out.
    num_tokens = generated.numel() - model_inputs['input_ids'].numel()

    print("Output:\n" + 100 * '-')
    print(model.tokenizer.decode(generated[0], skip_special_tokens=False))

    print("Generated Tokens:", num_tokens)
    print("Time Elasped (s):", elapsed)
    return num_tokens / elapsed

def main(args):
    """Benchmark eager generation and, with ``--compile``, compiled generation.

    Steps:
    1. Load the model and tokenizer.
    2. Do a short warm-up ``generate``.
    3. Measure eager throughput.
    4. If ``args.compile`` is set, switch ``generate`` to a static KV cache,
       wrap the inner model's ``forward`` with ``torch.compile``, warm it up,
       and report the compiled throughput and the speed-up ratio.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` argument parser).
    """
    print("torch and transformer version:", torch.__version__, transformers.__version__)
    print(torch.__config__.parallel_info())
    print(f"device: {args.device}, dtype: {args.dtype}")
    print(f"model: {args.model_path}")
    print_separater()

    model, tokenizer = get_model_and_tokenizer(args.model_path, args.device, args.dtype)
    model_inputs = tokenizer(args.prompt, return_tensors='pt').to(args.device)

    warm_up_tokens = 20
    set_seed(args.seed)
    warm_up_output = model.generate(**model_inputs, max_new_tokens=warm_up_tokens)

    throughput = benchmark_throughput(model, model_inputs, args)
    print("throughput eager (token/s):", throughput)

    if not args.compile:
        return

    is_cuda = torch.device(model.device).type == "cuda"
    if is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    # torch.compile only pays off with fixed tensor shapes. The default
    # (dynamic) KV cache grows every decoding step. generate() consults the
    # static cache only when cache_implementation == "static"; without this
    # line the StaticCache below is never used.
    model.generation_config.cache_implementation = "static"
    # Pre-build the static cache so repeated generate() calls can reuse one
    # fixed-size buffer.
    # NOTE(review): the `_static_cache` attribute name is what some
    # transformers releases look up; newer releases may ignore it and allocate
    # their own static cache. Confirm against the installed version.
    model._static_cache = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=4096,
        device=model.device,
        # Match the model's weights rather than hard-coding float16, so
        # --dtype bfloat16/float32 does not get a mismatched cache.
        dtype=model.dtype,
    )
    model.model.forward = torch.compile(
        model.model.forward,
        backend=args.dynamo_backend,
        mode=args.dynamo_mode,
        dynamic=None,
        fullgraph=True,
        disable=False
        )

    # torch.compile is lazy: the real compilation happens on the first call.
    # Run the warm-up inside the timed region so "Compile time" is meaningful.
    set_seed(args.seed)
    warm_up_output_compiled = model.generate(
        **model_inputs, max_new_tokens=warm_up_tokens)
    if is_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    print("Compile time (s):", t1 - t0)

    print("Warm-up result agree:", torch.equal(warm_up_output, warm_up_output_compiled))
    print_separater()

    throughput_compiled = benchmark_throughput(model, model_inputs, args)
    print_separater()
    print("compile speed-up:", throughput_compiled / throughput)

if __name__ == '__main__':
    import argparse

    def _parse_dtype(name):
        """Convert a CLI dtype name such as ``float16`` into a ``torch.dtype``.

        Without a ``type=`` converter, argparse would hand ``main`` the raw
        string instead of a ``torch.dtype``.
        """
        dtype = getattr(torch, name, None)
        if not isinstance(dtype, torch.dtype):
            raise argparse.ArgumentTypeError(f"unknown torch dtype: {name!r}")
        return dtype

    parser = argparse.ArgumentParser(description='Your CLI description.')

    parser.add_argument('--device', type=str,
                        default="cuda")
    parser.add_argument('--dtype', type=_parse_dtype, default=torch.float16,
                        help='torch dtype name, e.g. float16, bfloat16, float32.')
    parser.add_argument('--model_path', type=str,
                        default="meta-llama/Meta-Llama-3-8B", help='HF model name or path.')
    parser.add_argument('--prompt', type=str,
                        default="Q: What is the largest animal?\nA:", help='Input prompt.')
    parser.add_argument('--max_new_tokens', type=int,
                        default=256, help='Maximum number of new tokens.')
    parser.add_argument('--do_sample', action='store_true',
                        help='Whether to use sampling. Default is greedy search.')
    parser.add_argument('--top_k', type=int,
                        default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float,
                        default=0.8, help='Temperature for sampling.')
    parser.add_argument('--compile', action='store_true',
                        help='Whether to compile the model.')
    parser.add_argument('--dynamo_backend', type=str,
                        default="inductor", help='torch._dynamo.list_backends()')
    parser.add_argument('--dynamo_mode', type=str,
                        default="default", help='["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')

    args = parser.parse_args()
    main(args)