songh11 opened 1 week ago
I'm glad to see how much torch.compile can speed things up. On an A5000 it gives about a 60% speed-up, but on an L4 there is no acceleration at all. I'd like to know why this happens. Here is my code; pass `--compile` when running it:
```python
import os
import time

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
from transformers import set_seed

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def print_separator():
    print("=" * 20, "\n")


def get_model_and_tokenizer(model_path, device, dtype):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=dtype, device_map=device
    )
    model.tokenizer = tokenizer
    return model, tokenizer


def benchmark_throughput(model, model_inputs, args):
    device = model.device
    set_seed(args.seed)
    # model.device is a torch.device (e.g. cuda:0), so compare its .type,
    # not the device object itself, against the string "cuda".
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    greedy_output = model.generate(
        **model_inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        top_k=args.top_k,
        temperature=args.temperature,
        output_scores=True,
        return_dict_in_generate=True,
        use_cache=True,
    ).sequences
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()
    time_elapsed = t1 - t0
    num_tokens = greedy_output.numel() - model_inputs['input_ids'].numel()
    print("Output:\n" + 100 * '-')
    print(model.tokenizer.decode(greedy_output[0], skip_special_tokens=False))
    print("Generated Tokens:", num_tokens)
    print("Time Elapsed (s):", time_elapsed)
    throughput = num_tokens / time_elapsed
    return throughput


def main(args):
    print("torch and transformers version:", torch.__version__, transformers.__version__)
    print(torch.__config__.parallel_info())
    print(f"device: {args.device}, dtype: {args.dtype}")
    print(f"model: {args.model_path}")
    print_separator()

    model, tokenizer = get_model_and_tokenizer(args.model_path, args.device, args.dtype)
    model_inputs = tokenizer(args.prompt, return_tensors='pt').to(args.device)

    # Warm up the eager model once before benchmarking.
    warm_up_tokens = 20
    set_seed(args.seed)
    warm_up_output = model.generate(**model_inputs, max_new_tokens=warm_up_tokens)

    throughput = benchmark_throughput(model, model_inputs, args)
    print("throughput eager (token/s):", throughput)

    if args.compile:
        t0 = time.time()
        model._static_cache = StaticCache(
            config=model.config,
            max_batch_size=1,
            max_cache_len=4096,
            device=model.device,
            dtype=torch.float16,
        )
        model.model.forward = torch.compile(
            model.model.forward,
            backend=args.dynamo_backend,
            mode=args.dynamo_mode,
            dynamic=None,
            fullgraph=True,
            disable=False,
        )
        t1 = time.time()
        print("Compile time (s):", t1 - t0)

        # Warm up again so the compiled graph is built before timing,
        # and check that compiled generation matches the eager output.
        set_seed(args.seed)
        warm_up_output_compiled = model.generate(
            **model_inputs, max_new_tokens=warm_up_tokens
        )
        print("Warm-up result agree:",
              torch.equal(warm_up_output, warm_up_output_compiled))
        print_separator()

        throughput_compiled = benchmark_throughput(model, model_inputs, args)
        print_separator()
        print("compile speed-up:", throughput_compiled / throughput)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Your CLI description.')
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--dtype', default=torch.float16)
    parser.add_argument('--model_path', type=str, default="meta-llama/Meta-Llama-3-8B",
                        help='HF model name or path.')
    parser.add_argument('--prompt', type=str, default="Q: What is the largest animal?\nA:",
                        help='Input prompt.')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='Maximum number of new tokens.')
    parser.add_argument('--do_sample', action='store_true',
                        help='Whether to use sampling. Default is greedy search.')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Temperature for sampling.')
    parser.add_argument('--compile', action='store_true',
                        help='Whether to compile the model.')
    parser.add_argument('--dynamo_backend', type=str, default="inductor",
                        help='torch._dynamo.list_backends()')
    parser.add_argument('--dynamo_mode', type=str, default="default",
                        help='["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    args = parser.parse_args()
    main(args)
```
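In case it helps with diagnosing the L4 result: here is how I check whether Dynamo actually compiles cleanly on a given GPU, or whether it hits graph breaks or keeps recompiling (which would erase any speed-up). This is just a minimal sketch using the standard PyTorch 2.x logging hooks; `model` and `model_inputs` are the objects from the script above.

```python
import torch

# Surface graph breaks and recompilations while the benchmark runs.
# Equivalent to launching the script with TORCH_LOGS="graph_breaks,recompiles".
torch._logging.set_logs(graph_breaks=True, recompiles=True)

# One-off report of how many graphs / graph breaks a sample forward produces.
# Run this on the *uncompiled* forward, before torch.compile is applied.
explanation = torch._dynamo.explain(model.model.forward)(
    input_ids=model_inputs["input_ids"]
)
print(explanation)
```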