GeneZC / MiniMA

Code for paper titled "Towards the Law of Capacity Gap in Distilling Language Models"
Apache License 2.0

The inference latency of model MiniMA-3B #6

Closed: qxpBlog closed this issue 7 months ago

qxpBlog commented 7 months ago

@GeneZC Why is the inference latency of MiniMA-3B longer than that of Llama-7B? [screenshots: latency tables for MiniMA-3B and Llama-7B]

GeneZC commented 7 months ago

Where did you obtain these results? We have not conducted such experiments in our paper.

qxpBlog commented 7 months ago

Where did you obtain these results? We have not conducted such experiments in our paper.

I used MiniMA-3B for inference on the PTB and WikiText-2 datasets; the latency in the table above is the time spent on inference.

GeneZC commented 7 months ago

That is somewhat odd. Could you please check that the settings are the same for both models? For example, that both use flash attention and both use the KV cache.
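
For reference, a minimal sketch of aligning these settings when loading the two models (the model identifiers, dtype, and the FlashAttention-2 backend here are illustrative; it assumes a transformers version that accepts the attn_implementation argument and that flash-attn is installed):

import torch
from transformers import AutoModelForCausalLM

# Use identical loading settings so the latency comparison is apples-to-apples.
common_kwargs = dict(
    torch_dtype=torch.float16,                # same precision for both models
    attn_implementation="flash_attention_2",  # or "eager" / "sdpa", as long as it matches
)

minima = AutoModelForCausalLM.from_pretrained("GeneZC/MiniMA-3B", **common_kwargs).cuda().eval()
llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", **common_kwargs).cuda().eval()

# Keep the KV-cache setting identical as well.
minima.config.use_cache = llama.config.use_cache = True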

GeneZC commented 7 months ago

Please provide more information so that I can help you identify the problem.

GeneZC commented 7 months ago

One point to note: if more tokens are generated, the latency will naturally be larger. In that case, latency is better normalized by the number of generated tokens.
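
For example, a tiny helper along these lines (all names here are illustrative) keeps runs with different token counts comparable:

def per_token_latency_ms(total_seconds: float, total_tokens: int) -> float:
    # Normalize wall-clock latency by the number of tokens processed, so runs that
    # consume or generate different numbers of tokens can be compared directly.
    return 1000.0 * total_seconds / total_tokens

# e.g. per_token_latency_ms(end_time - start_time, total_generated_tokens)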

qxpBlog commented 7 months ago

Please provide more information so that I can help you identify the problem.

Here is the code for measuring latency:

import os
import sys
import argparse
import accelerate
from accelerate.utils import BnbQuantizationConfig
import torch
import numpy as np
import time
import transformers
from tqdm import tqdm
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModel, AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from codecarbon import track_emissions, EmissionsTracker
from LLMPruner.utils.logger import LoggerWithDepth
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import W8A8Linear
from ptflops import get_model_complexity_info
from ptflops.pytorch_ops import bn_flops_counter_hook, pool_flops_counter_hook
from LLMPruner.evaluator.ppl import PPLMetric, test_latency_energy
# get_loaders (used below) is assumed to come from LLM-Pruner's dataset utilities, e.g.:
# from LLMPruner.datasets.ppl_dataset import get_loaders
from LLMPruner.models.hf_llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm, LlamaAttention, LlamaMLP  # overrides the transformers LlamaForCausalLM imported above
from LLMPruner.peft import PeftModel
from transformers import DistilBertTokenizer, DistilBertModel, BitsAndBytesConfig
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch_version = int(torch.__version__.split('.')[1])

def LlamaAttention_counter_hook(module, input, output):
    # (1) Ignore past-key values
    # (2) Assume there is no attention mask
    # Input can be empty in some PyTorch versions; use output instead, since input.shape == output.shape
    flops = 0
    q_len = output[0].shape[1]
    linear_dim = output[0].shape[-1]
    num_heads = module.num_heads
    head_dim = module.head_dim

    rotary_flops = 2 * (q_len * num_heads * head_dim) * 2
    attention_flops = num_heads * (q_len * q_len * head_dim + q_len * q_len + q_len * q_len * head_dim) #QK^T + softmax + AttentionV
    linear_flops = 4 * (q_len * linear_dim * num_heads * head_dim) # 4 for q, k, v, o. 
    flops += rotary_flops + attention_flops + linear_flops
    module.__flops__ += int(flops)

def rmsnorm_flops_counter_hook(module, input, output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    batch_flops *= 2
    module.__flops__ += int(batch_flops)

# Model quantization
def quantize_model(model, weight_quant='per_tensor', act_quant='per_tensor', quantize_bmm_input=True):
    for name, m in model.model.named_modules():
        if isinstance(m, OPTDecoderLayer):
            m.fc1 = W8A8Linear.from_float(m.fc1, weight_quant=weight_quant, act_quant=act_quant)
            m.fc2 = W8A8Linear.from_float(m.fc2, weight_quant=weight_quant, act_quant=act_quant)
        elif isinstance(m, OPTAttention):
            # Here we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            m.q_proj = W8A8Linear.from_float(
                m.q_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.k_proj = W8A8Linear.from_float(
                m.k_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.v_proj = W8A8Linear.from_float(
                m.v_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)
            m.out_proj = W8A8Linear.from_float(m.out_proj, weight_quant=weight_quant, act_quant=act_quant)
    return model
def test_latency_energy(model, tokenizer, datasets, seq_len=128, batch_size=4, device="cuda"):  # local definition; shadows the version imported from LLMPruner above
    metric = {}
    _, test_loader_wikitext2 = get_loaders(datasets[0], tokenizer, seq_len=seq_len, batch_size = batch_size)
    _, test_loader_ptb = get_loaders(datasets[1], tokenizer, seq_len=seq_len, batch_size = batch_size)
    tracker = EmissionsTracker()
    tracker.start()
    start_time = time.time()
    ppl1 = llama_eval(model, test_loader_wikitext2, device)
    ppl2 = llama_eval(model, test_loader_ptb, device)
    end_time = time.time()
    tracker.stop()
    latency = end_time - start_time
    print(f"Model latency: {latency:.3f} seconds")
    metric[datasets[0]] = ppl1
    metric[datasets[1]] = ppl2
    return metric

@torch.no_grad()
def llama_eval(model, test_loader, device):
    nlls = []
    n_samples = 0
    for batch in tqdm(test_loader):
        batch = batch.to(device)
        output = model(batch)
        lm_logits = output.logits

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
        nlls.append(loss)
    #print(torch.cat(nlls, dim=-1).mean())
    ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
    return ppl.item()
# @track_emissions()
def main(args):

    if args.test_mod == 'base':
        tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
        model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            low_cpu_mem_usage=True
        )
    elif args.test_mod == 'distil':
        tokenizer = AutoTokenizer.from_pretrained("GeneZC/MiniMA-3B")
        model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniMA-3B", low_cpu_mem_usage=True)

    model.to(device)         
    model.config.pad_token_id = tokenizer.pad_token_id = 0 
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    model.eval()
    ppl = test_latency_energy(model, tokenizer, ['wikitext2', 'ptb'], args.max_seq_len, device=device)
    print("PPL after pruning: {}".format(ppl))
    print("Memory Requirement: {} MiB\n".format(torch.cuda.memory_allocated() / 1024 / 1024))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tuning Pruned LLaMA (huggingface version)')

    parser.add_argument('--base_model', type=str, default="llama2-7b", help='base model name')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lora_ckpt', type=str, default=None)
    parser.add_argument('--max_seq_len', type=int, default=128, help='max sequence length')
    parser.add_argument('--test_mod', type=str, default="tuned", help='choose from [pruned, tuned, base]')
    args = parser.parse_args()

    main(args)

I set max_seq_len to 128. The only difference in measuring latency between MiniMA-3B and Llama-2-7B is how the models are loaded.

GeneZC commented 7 months ago

I do not see a clear reason why the latency would be that large. How about loading MiniMA with LlamaForCausalLM instead, since MiniMA follows the LLaMA architecture? And please check whether flash attention is turned on for LLaMA but off for MiniMA.

qxpBlog commented 7 months ago

I do not see a clear reason why the latency would be that large. How about loading MiniMA with LlamaForCausalLM instead, since MiniMA follows the LLaMA architecture? And please check whether flash attention is turned on for LLaMA but off for MiniMA.

I have tried loading MiniMA-3B with LlamaForCausalLM.from_pretrained, but it has no effect on the inference latency. How can I check whether flash attention is turned on for LLaMA or MiniMA?

GeneZC commented 7 months ago

By printing model.config._attn_implementation after the model is loaded, you can see which attention implementation is used.
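
For example (a quick check; the Llama-2 identifier is only an example path):

from transformers import AutoModelForCausalLM

# The attribute is typically "eager", "sdpa", or "flash_attention_2".
for name in ("GeneZC/MiniMA-3B", "meta-llama/Llama-2-7b-hf"):
    model = AutoModelForCausalLM.from_pretrained(name, low_cpu_mem_usage=True)
    print(name, model.config._attn_implementation)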

qxpBlog commented 7 months ago

print(model.config._attn_implementation)

Thanks, the result of print(model.config._attn_implementation) is eager for both MiniMA-3B and Llama-2-7B.

qxpBlog commented 7 months ago

By printing model.config._attn_implementation after the model is loaded, you can see which attention implementation is used.

But I find that in terms of FLOPs, MiniMA-3B is lower than Llama-2-7B: [screenshot: FLOPs comparison table]

So does this mean that MiniMA-3B needs to spend more time on inference compared to Llama-2-7B?

GeneZC commented 7 months ago

The FLOPs in that table are training FLOPs. However, MiniMA should also have smaller inference FLOPs than LLaMA-7B due to its smaller model scale, and therefore a smaller latency in expectation (if both are tested under exactly the same setting).

So I suspect there is still a difference somewhere that we have not uncovered. Perhaps the vocabulary size? MiniMA indeed has a slightly larger vocabulary than LLaMA-7B (~50000 vs ~30000), but I would not have expected the impact to be that large.
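
As a rough back-of-envelope check (the hidden sizes below are assumptions for illustration, not figures from this thread), the extra cost of the larger vocabulary sits mainly in the LM head, roughly 2 * seq_len * hidden_size * vocab_size FLOPs per forward pass, which under these assumptions is comparable for the two models:

def lm_head_gflops(seq_len: int, hidden_size: int, vocab_size: int) -> float:
    # Rough cost of the final projection onto the vocabulary (a multiply-add counted as 2 FLOPs).
    return 2.0 * seq_len * hidden_size * vocab_size / 1e9

# Assumed hidden sizes: 3072 for MiniMA-3B, 4096 for LLaMA-2-7B; vocab sizes are from this thread.
print(lm_head_gflops(128, 3072, 49216))  # ~38.7 GFLOPs
print(lm_head_gflops(128, 4096, 32000))  # ~33.6 GFLOPs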

qxpBlog commented 7 months ago

The FLOPs in that table are training FLOPs. However, MiniMA should also have smaller inference FLOPs than LLaMA-7B due to its smaller model scale, and therefore a smaller latency in expectation (if both are tested under exactly the same setting).

So I suspect there is still a difference somewhere that we have not uncovered. Perhaps the vocabulary size? MiniMA indeed has a slightly larger vocabulary than LLaMA-7B (~50000 vs ~30000), but I would not have expected the impact to be that large.

I'm not sure. The vocab_size of MiniMA-3B is 49216 and that of Llama-2-7b is 32000, but the impact of vocab_size seems too large to explain this. Can I directly change the vocab_size of MiniMA-3B to 32000 and then measure the inference latency?

GeneZC commented 7 months ago

Rather than modifying the vocabulary size, you can directly use the LLaMA-7B tokenizer for MiniMA-3B and run the test, since these two models share the very first 32000 tokens. Good luck!
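
A minimal sketch of that test (the Llama-2 identifier is only an example; any local LLaMA-2-7B checkpoint works):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Tokenize with the LLaMA-2-7B tokenizer but run MiniMA-3B; the two vocabularies
# share their first 32000 token ids, so the inputs remain valid for MiniMA.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniMA-3B", low_cpu_mem_usage=True)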

qxpBlog commented 7 months ago

Rather than modifying the vocabulary size, you can directly use the LLaMA-7B tokenizer for MiniMA-3B and run the test, since these two models share the very first 32000 tokens. Good luck!

😂 Directly using the tokenizer of Llama-2-7b also did not work. It seems that vocabulary size is not the decisive factor affecting inference time. Have you ever conducted experiments on MiniMA-3B inference speed before?

GeneZC commented 7 months ago

Let me have a try ; )

GeneZC commented 7 months ago

Here are the results I obtained in a similar way to yours:
MiniMA-3B: 197.8612543 ms per batch
LLaMA-2-7B: 392.4757415 ms per batch

Below is the code snippet:

import os
import argparse

import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    kwargs = {"torch_dtype": torch.bfloat16, "attn_implementation": "eager", "trust_remote_code": True}
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **kwargs)
    model = model.cuda()
    model.eval()

    with torch.no_grad():
        text = open(args.data_dir, "r").read()

        # Split the tokenized text into chunks of max_length tokens, each prefixed
        # with the BOS token; the trailing (shorter) chunk is dropped below.
        max_length = 512
        tmp_input_ids = tokenizer([text]).input_ids[0][1:]
        input_ids = []
        for i in range(0, len(tmp_input_ids), max_length):
            input_ids.append([tokenizer.bos_token_id] + tmp_input_ids[i: i + max_length])

        input_ids = input_ids[:-1]

        # Time each forward pass (with labels, so the loss is computed) using CUDA events.
        batch_size = 8
        beginner, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        latencies = []
        for i in range(0, len(input_ids), batch_size):
            beginner.record()
            _ = model(input_ids=torch.as_tensor(input_ids[i: i + batch_size]).cuda(), labels=torch.as_tensor(input_ids[i: i + batch_size]).cuda(), return_dict=False)[0].item()
            ender.record()
            torch.cuda.synchronize()
            latency = beginner.elapsed_time(ender)  # milliseconds for this batch
            latencies.append(latency)
        print(latencies) 
        print("Average latency: {:.7f}".format(np.mean(latencies)))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/flan-t5-small",
    )
    parser.add_argument("--data_dir", type=str, default="wikitext-2-raw/wiki.test.raw")
    parser.add_argument("--output_dir", type=str, default="outputs") 
    args = parser.parse_args()

    main(args)
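
For reference, the script above is presumably run once per model with the same --data_dir (the raw WikiText-2 test file) and the corresponding --model_name_or_path, e.g. GeneZC/MiniMA-3B and a local LLaMA-2-7B checkpoint, so that the model itself is the only variable.
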
qxpBlog commented 7 months ago

Here are the results I obtained in a similar way to yours:
MiniMA-3B: 197.8612543 ms per batch
LLaMA-2-7B: 392.4757415 ms per batch

Thank you for your answer, I know what the reason is now. It's due to the accuracy of the models: the two models are running at different accuracies.

GeneZC commented 7 months ago

You mean precision, right? I.e., FP16 or BF16?

qxpBlog commented 7 months ago

You mean precision, right? I.e., FP16 or BF16?

Yes, when I use torch_dtype=torch.float16, the latency of the two models is normal: [screenshots: latency results for both models under float16]
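
For completeness, a minimal sketch of pinning the precision explicitly (when torch_dtype is omitted, transformers loads the weights in float32 by default, which is much slower on most GPUs than fp16/bf16):

import torch
from transformers import AutoModelForCausalLM

# Load with an explicit dtype so both models run at the same precision;
# leaving torch_dtype unset falls back to float32.
model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniMA-3B", torch_dtype=torch.float16).cuda().eval()
print(model.dtype)  # torch.float16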

GeneZC commented 7 months ago

That's also interesting. I did not expect that precision could impact the latency that much ; )

qxpBlog commented 7 months ago

That's also interesting. I did not expect that precision could impact the latency that much ; )

I noticed that your code sets a specific precision, so I decided to try it out. I didn't expect that to actually be the reason.

GeneZC commented 7 months ago

I see, good luck with your work!