pytorch / ao

PyTorch native quantization and sparsity for training and inference

Why is the inference speed of the quantized model using QAT so slow? #1050

Open elfisworking opened 1 month ago

elfisworking commented 1 month ago

I got a quantized model using the torchtune package. The test log shows me:

INFO:torchtune.utils._logging:Time for inference: 66.56 sec total, 4.51 tokens/sec

4.51 tokens/sec is even lower than the unquantized model. Is this normal for QAT? Thank you very much if someone is willing to answer.

andrewor14 commented 1 month ago

Hi @elfisworking, which quantization scheme are you using, is it int8 dynamic activations + int4 weights, or int4 weight only? Currently we only have fast cuda kernels for the latter. The former still uses int8 to represent the weights (because PyTorch core doesn't have torch.int4 yet) and was designed primarily for executorch. However, I'm surprised to see it's slower than the unquantized model. Can you share the inference throughput for the unquantized model for comparison?

One other thing to try is to skip QAT and quantize the model directly to see if you still observe the same slowdown, i.e.:

from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)

model = ...  # your model in its original precision (e.g. bf16)
group_size = 256  # match the group size used during QAT

# Replace linear weights with int8-dynamic-activation / int4-weight quantized versions
quantize_(model, int8_dynamic_activation_int4_weight(group_size))
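For comparison, here is a minimal sketch of the int4 weight-only path mentioned above, which does have a fast CUDA kernel. This is only a sketch: the group size and the bf16/CUDA assumptions are examples, not something from this thread.

from torchao.quantization.quant_api import int4_weight_only, quantize_

# Sketch: int4 weight-only quantization, which uses the tinygemm CUDA kernel.
# group_size=128 and the bf16 model on CUDA are assumptions for illustration.
model = ...  # bf16 model on a CUDA device
quantize_(model, int4_weight_only(group_size=128))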
elfisworking commented 1 month ago

Hello, currently I use int8 dynamic activations + int4 weights. The model is Llama3-8B.

elfisworking commented 1 month ago

Thanks for your reply @andrewor14. The inference code I use for the original model is:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

model_name = "/QAT/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

input_text = "What are the key features of LLaMA 3?"

inputs = tokenizer(input_text, return_tensors="pt").to(device)
t0 = time.perf_counter()

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        max_length=100,
        num_return_sequences=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
    )

# Count only the newly generated tokens (exclude the prompt)
tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
t = time.perf_counter() - t0
tokens_sec = tokens / t
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec, the number of tokens is {tokens}")
print("Generated Response:", output_text)

The output is

Time for inference: 6.08 sec total, 16.28 tokens/sec, the number of tokens is 99
Generated Response: What are the key features of LLaMA 3? In the end, GPT-4 is more than a chatbot. 25 min read. 1. A comparison of AI models: GPT-3 vs. GPT-4; GPT-3 vs. GPT-4: The differences; GPT-3 vs GPT-4: Similarities; A comparison of AI models: GPT-3 vs. GPT-4. We've also
andrewor14 commented 1 month ago

Hi @elfisworking, by the way, did you explicitly quantize the model after QAT? The model produced at the end of training is still in high precision (e.g. bf16). To get an actual quantized model you need to run tune run quantize (or call quantizer.convert if you're using torchao APIs directly). If not, then what you're seeing is actually the unquantized model with fake quantize ops inserted, and these fake quantize ops add some overhead to inference.
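A quick way to sanity-check which model you are benchmarking is to print the module types. This is only a rough sketch (the exact class names vary across torchao versions), but after a real convert the transformer's nn.Linear layers should have been replaced by quantized linear modules, while a QAT-prepared model still contains plain bf16 linears wrapped with fake quantize ops:

import torch.nn as nn

# Sketch: list the linear-like modules in the checkpoint you are benchmarking.
# After `convert` / `tune run quantize`, the 8da4w flow replaces nn.Linear with
# quantized linear modules; a QAT-prepared (but not converted) model still shows
# plain nn.Linear plus fake-quantize wrappers. Class names depend on the torchao
# version, so treat this only as a rough check. Assumes `model` is already loaded.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) or "Linear" in type(module).__name__:
        print(name, type(module).__name__)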

elfisworking commented 1 month ago

@andrewor14 thanks for your reply. I am sure that I used the tune run quantize command, but the model speed is still slow. What can I provide to help you validate this issue? Below are my torchtune inference config and log. Or is this speed normal? Thank you.

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /QAT/output/llama3-8B/
  checkpoint_files:
  - meta_model_2-8da4w.pt
  model_type: LLAMA3
  output_dir: /QAT/output/llama3-8B/
device: cuda
dtype: bf16
enable_kv_cache: true
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3.llama3_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 42
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
top_k: 1

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 42. Local seed is seed + rank = 42 + 0
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Starting compilation to improve generation performance ...
INFO:torchtune.utils._logging:Warmup run for quantized model takes: 208.78 sec
INFO:torchtune.utils._logging:Tell me a joke? I'll tell you one. What did the volcano say to the other volcano? "Let's lava some fun!" (Sorry, couldn't resist that one!) What did the tree say to the other tree? "I'm feeling a bit nuts!" (Sorry, couldn't resist that one either!) What did the fish say when he swam into a wall? "Dam!" (Sorry, couldn't resist that one either!) What did the snowman say to the other snowman? "I'm feeling a bit cold!" (Sorry, couldn't resist that one either!) What did the flower say to the other flower? "I'm feeling a bit petal!" (Sorry, couldn't resist that one either!) What did the bee say to the other bee? "I'm feeling a bit buzzed!" (Sorry, couldn't resist that one either!) What did the spider say to the other spider? "I'm feeling a bit webby!" (Sorry, couldn't resist that one either!) What did the frog say to the other frog? "I'm feeling a bit croaky!" (Sorry, couldn't resist that one either!) What did the duck say to the other duck? "I'm feeling a bit quacky!" (Sorry, couldn't resist that one either!) What did the mouse say to the other mouse? "I'm feeling a bit mousy!" (Sorry, couldn't resist that one either!) What did the squirrel say to the other squirrel?
INFO:torchtune.utils._logging:Time for inference: 68.28 sec total, 4.39 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 71.56 GB/s
INFO:torchtune.utils._logging:Memory used: 17.41 GB
elfisworking commented 1 month ago

Should I use PTQ first and then QAT, rather than directly applying QAT to the original model?

andrewor14 commented 1 month ago

Sorry, let me clarify. The right flow for QAT is just two steps: prepare and convert. "Prepare" inserts fake quantize ops (the model is still in its original high precision) and "convert" swaps these fake quantize ops for real quantize ops (the model is now in low precision, usually int). PTQ is a separate flow that just does the "convert" part.

In torchao, you should use the prepare and convert flow like this (full README here):

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer  # torchao.quantization.prototype.qat in older releases

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)   # insert fake quantize ops; model stays in bf16
train(model)                           # your usual fine-tuning loop
model = qat_quantizer.convert(model)   # swap fake quantize ops for real quantized layers

In torchtune, this can be done with the following commands (for example):

# This fine-tunes the model using QAT in bf16 (prepare step)
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full

# This quantizes the QAT model to int8 dynamic activations + int4 weights (convert step)
tune run quantize --config recipes/configs/quantization.yaml

# This evaluates the quantized model (basically inference)
CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    checkpointer.checkpoint_dir=<my_checkpoint_dir> \
    checkpointer.checkpoint_files=[meta_model_0-8da4w.pt] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer

Are you running inference on the quantized model or the QAT model directly (in bf16)? In other words, did you run the tune run quantize step before running inference? That step is needed to get an actual quantized model. If not, the prepared model is very expensive because it contains a lot of extra fake quantize ops.

Also, just to add some context: the int8 dynamic activations + int4 weight flow currently doesn't have great cuda performance because there's no dedicated kernel. It was originally designed for lowering to edge devices through executorch, so you may not see significant speed/memory improvements. However, I think @alexsamardzic is working on such a cuda kernel, so an optimized version of this will be coming soon.
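In the meantime, if GPU inference speed is the main goal, one option is the int4 weight-only QAT flow, which converts into the form served by the existing fast CUDA kernel. Below is only a rough sketch, assuming your torchao version exposes Int4WeightOnlyQATQuantizer (under torchao.quantization.qat in recent releases, torchao.quantization.prototype.qat in older ones); the default group size and the train helper are stand-ins:

from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

# Sketch only: prepare/convert with the int4 weight-only QAT quantizer, which
# targets the int4 weight-only (tinygemm) CUDA path rather than 8da4w.
# `train` stands in for your usual fine-tuning loop.
qat_quantizer = Int4WeightOnlyQATQuantizer()
model = qat_quantizer.prepare(model)   # insert fake quantize ops; model stays bf16
train(model)
model = qat_quantizer.convert(model)   # produce the int4 weight-only quantized model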

elfisworking commented 1 month ago

I have implemented the quantization process you mentioned, and now I am sure that the current speed is normal on the A100. I will patiently wait for the optimized version. Thank you very, very much for your reply.