pytorch / ao

PyTorch native quantization and sparsity for training and inference

Why is the inference speed of the quantized model using QAT so slow? #1050

Open elfisworking opened 1 day ago

elfisworking commented 1 day ago

I got a quantized model using the torchtune package. The test log shows:

INFO:torchtune.utils._logging:Time for inference: 66.56 sec total, 4.51 tokens/sec

4.51 tokens/sec is even lower than the unquantized model. Is this normal for QAT? Thank you very much if someone is willing to answer.

andrewor14 commented 17 hours ago

Hi @elfisworking, which quantization scheme are you using: int8 dynamic activations + int4 weights, or int4 weight only? Currently we only have fast CUDA kernels for the latter. The former still uses int8 to represent the weights (because PyTorch core doesn't have torch.int4 yet) and was designed primarily for ExecuTorch. However, I'm surprised to see it's slower than the unquantized model. Can you share the inference throughput for the unquantized model for comparison?

One other thing to try is to skip QAT and quantize the model directly, to see if you still observe the same slowdown, i.e.

from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)

model = ...
group_size = 32  # example value; match the group size used during QAT
quantize_(model, int8_dynamic_activation_int4_weight(group_size))
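For comparison, the int4 weight-only path mentioned above (the one with fast CUDA kernels) can be applied the same way. A minimal sketch, assuming the same torchao quant_api; group_size=128 is just an illustrative value, and this scheme is typically used with bfloat16 weights on CUDA:

from torchao.quantization.quant_api import (
    int4_weight_only,
    quantize_,
)

model = ...  # float model loaded on CUDA, typically in bfloat16
# int4 weight-only quantization; group_size=128 is an illustrative value
quantize_(model, int4_weight_only(group_size=128))

Since small-batch decoding tends to be memory-bandwidth bound, this is the scheme that would be expected to show an actual inference speedup here.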
elfisworking commented 12 hours ago

Hello, currently I am using int8 dynamic activations + int4 weights. The model is Llama-3-8B.

elfisworking commented 9 hours ago

Thanks for your reply @andrewor14. The inference code I use for the original (unquantized) model is:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

model_name = "/QAT/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

input_text = "What are the key features of LLaMA 3?"

inputs = tokenizer(input_text, return_tensors="pt").to(device)
t0 = time.perf_counter()

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        max_length=100,
        num_return_sequences=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
    )

# count only the newly generated tokens (exclude the prompt)
tokens = outputs.shape[-1] - inputs["input_ids"].shape[-1]
t = time.perf_counter() - t0
tokens_sec = tokens / t
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec, the number of tokens is {tokens}")
print("Generated Response:", output_text)

The output is

Time for inference: 6.08 sec total, 16.28 tokens/sec, the number of tokens is 99
Generated Response: What are the key features of LLaMA 3? In the end, GPT-4 is more than a chatbot. 25 min read. 1. A comparison of AI models: GPT-3 vs. GPT-4; GPT-3 vs. GPT-4: The differences; GPT-3 vs GPT-4: Similarities; A comparison of AI models: GPT-3 vs. GPT-4. We've also
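A side note on the measurement itself: when comparing the quantized and unquantized models, it helps to run a warmup generation first and to synchronize CUDA around the timed region, since one-time setup cost and asynchronous kernel launches can otherwise skew the tokens/sec number. Below is a minimal sketch of such a helper; the function measure_tokens_per_sec is hypothetical, not part of torchao or transformers:

import time
import torch

def measure_tokens_per_sec(model, tokenizer, prompt, max_new_tokens=100, warmup=1):
    # hypothetical helper for an apples-to-apples throughput comparison
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # warmup generations so one-time setup cost is not included in the timing
        for _ in range(warmup):
            model.generate(inputs["input_ids"], max_new_tokens=max_new_tokens, do_sample=False)

        if device.type == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        outputs = model.generate(inputs["input_ids"], max_new_tokens=max_new_tokens, do_sample=False)
        if device.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

    # count only the newly generated tokens (exclude the prompt)
    new_tokens = outputs.shape[-1] - prompt_len
    return new_tokens / elapsed

Usage would be something like: tps = measure_tokens_per_sec(model, tokenizer, input_text), run once for the unquantized model and once for the quantized one.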