mit-han-lab / smoothquant

[ICML 2023] SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models
https://arxiv.org/abs/2211.10438
MIT License

How to reproduce the performance described in the paper #48

Open rolex-cjj opened 1 year ago

rolex-cjj commented 1 year ago

I tested the latency of OPT-13B on a single NVIDIA A100-80GB GPU using the PyTorch implementation. With batch size 1, input sequence length 512, and output sequence length 1, the FP16 latency is about 74.60ms and the SmoothQuant-O3 latency is about 78.192ms. This does not show the speedup reported in the paper. Is there something wrong?
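
For reference, a minimal sketch of the kind of measurement I mean (warm-up iterations plus CUDA events; the model and inputs are placeholders here, already on the GPU):

import torch

@torch.no_grad()
def prefill_latency_ms(model, input_ids, n_warmup=3, n_iters=10):
    # warm up so one-time CUDA kernel / allocator setup is not timed
    for _ in range(n_warmup):
        model(input_ids)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        model(input_ids)
    end.record()
    torch.cuda.synchronize()
    # average milliseconds per forward pass over the 512-token prompt
    return start.elapsed_time(end) / n_iters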

rolex-cjj commented 1 year ago
import torch
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import W8A8Linear
import os
import gc
from torch.nn.functional import pad

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

class Evaluator:
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

        # tokenize the dataset
        def tokenize_function(examples):
            example = self.tokenizer(examples['text'])
            return example

        self.dataset = self.dataset.map(tokenize_function, batched=True)
        self.dataset.set_format(type='torch', columns=['input_ids'])

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        # The task is to predict the last word of the input.
        total, hit = 0, 0
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        latency = 0
        for batch in self.dataset:
            input_ids = batch['input_ids'].cuda().unsqueeze(0)
            label = input_ids[:, -1]
            # right-pad to a fixed length of 512 (OPT's pad token id is 1)
            pad_len = 512 - input_ids.shape[1]
            input_ids = pad(input_ids, (0, pad_len), value=1)
            torch.cuda.synchronize()
            start.record()
            outputs = model(input_ids)
            end.record()
            torch.cuda.synchronize()
            latency += start.elapsed_time(end)
            # logits at the position just before the last real (non-padded) token
            last_token_logits = outputs.logits[:, -2 - pad_len, :]
            pred = last_token_logits.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label).sum().item()

        acc = hit / total
        latency = latency / len(self.dataset)
        return acc, latency

def print_model_size(model):
    # https://discuss.pytorch.org/t/finding-model-size/130275
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('Model size: {:.3f}MB'.format(size_all_mb))

from datasets import load_dataset
dataset = load_dataset('dataset/lambada', split='validation[:1000]')

tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b')
evaluator = Evaluator(dataset, tokenizer)

# fp16 opt-13b
model_fp16 = OPTForCausalLM.from_pretrained(
    'facebook/opt-13b', torch_dtype=torch.float16, device_map='auto')
acc_fp16, latency_fp16 = evaluator.evaluate(model_fp16)
print(f'OPT-13B FP16 accuracy: {acc_fp16}, per-sample latency: {latency_fp16:.3f}ms')

from smoothquant.opt import Int8OPTForCausalLM
model_smoothquant = Int8OPTForCausalLM.from_pretrained(
    'mit-han-lab/opt-13b-smoothquant', torch_dtype=torch.float16, device_map='auto')
acc_smoothquant, latency_smoothquant = evaluator.evaluate(model_smoothquant)
print(
    f'OPT-13B SmoothQuant INT8 accuracy: {acc_smoothquant}, per-sample latency: {latency_smoothquant:.3f}ms')

OPT-13B FP16 accuracy: 0.786, per-sample latency: 74.602ms
OPT-13B SmoothQuant INT8 accuracy: 0.786, per-sample latency: 78.192ms

WelY1 commented 1 year ago

I tested the latency and model size of OPT-13B on a single V100 GPU. I used get_act_scales to generate opt-13b.pt; a rough sketch of that step is below.
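
A minimal sketch of that calibration step, assuming the repo's smoothquant.calibration.get_act_scales and a local calibration set (the dataset and output paths here are placeholders):

import torch
from transformers import GPT2Tokenizer
from transformers.models.opt.modeling_opt import OPTForCausalLM
from smoothquant.calibration import get_act_scales

model = OPTForCausalLM.from_pretrained(
    'facebook/opt-13b', torch_dtype=torch.float16, device_map='auto')
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b')
# 512 calibration samples of length 512, following the repo's generate_act_scales example
act_scales = get_act_scales(model, tokenizer, 'dataset/val.jsonl.zst',
                            num_samples=512, seq_len=512)
torch.save(act_scales, 'act_scales/opt-13b.pt')

Here are my results.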