OpenNMT / CTranslate2

Fast inference engine for Transformer models
https://opennmt.net/CTranslate2
MIT License

Weird speed behavior with int8* quantization #1482

Open b-joris opened 9 months ago

b-joris commented 9 months ago

Model: Llama-2-7b-chat
CT2 version: 3.19.0

I found that when I use an int8* quantization, the inference speed depends drastically on num_hypotheses. I benchmarked my model with a batch size (passed as num_hypotheses) from 1 to 50, and here are my results (int8_bfloat16 quantization):

int8_bfloat16:
```json
{
  "1": 44.00054508884497, "2": 43.835354729781585, "3": 43.35097676681637, "4": 42.951930925828, "5": 42.47537162085483,
  "6": 41.98050812826082, "7": 41.23707873776788, "8": 40.545429908337866, "9": 27.63858542971084, "10": 25.42197518642908,
  "11": 25.202539121999276, "12": 26.089053511836536, "13": 26.09313176242893, "14": 25.921680235987616, "15": 25.685515962201304,
  "16": 25.380245130497627, "17": 37.63648096142843, "18": 39.836009864633866, "19": 39.708816021166086, "20": 39.47367309657473,
  "21": 39.230580761787884, "22": 39.090078929628305, "23": 38.89884011212886, "24": 38.64880032519217, "25": 38.46047213226564,
  "26": 38.22022915531032, "27": 38.114739792440346, "28": 37.93036231586874, "29": 37.71376468599755, "30": 37.50896798657394,
  "31": 37.33254116467707, "32": 37.18685640533582, "33": 36.95990834536584, "34": 36.82252135352815, "35": 36.64739759644922,
  "36": 36.48969796991286, "37": 36.41474859724064, "38": 36.254421685401354, "39": 36.083237280900576, "40": 35.956776897190636,
  "41": 35.792959478761205, "42": 35.644916529250914, "43": 35.444568931816654, "44": 35.330089066182616, "45": 35.16173908808134,
  "46": 35.053941098624676, "47": 34.77292105398752, "48": 34.6240710017728, "49": 34.53410552029019, "50": 34.38873257035413
}
```

There's a big drop at num_hypotheses=9, and the speed goes back up at 17. This happens with every int8* quantization type.
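To locate the anomaly programmatically rather than by eye, one option is to flag consecutive num_hypotheses values whose speed changes by more than a relative threshold. A minimal sketch using an excerpt of the int8_bfloat16 numbers above; the function name and the 15% threshold are arbitrary illustrative choices, not anything from CTranslate2:

```python
def find_speed_cliffs(speeds: dict[str, float], threshold: float = 0.15) -> list[tuple[int, int, float]]:
    """Return (previous n, n, relative change) wherever speed changes by more than `threshold`."""
    # Keys are strings in the benchmark output, so sort them numerically
    points = sorted((int(n), speed) for n, speed in speeds.items())
    cliffs = []
    for (prev_n, prev_speed), (n, speed) in zip(points, points[1:]):
        change = (speed - prev_speed) / prev_speed
        if abs(change) > threshold:
            cliffs.append((prev_n, n, round(change, 3)))
    return cliffs


# Excerpt of the int8_bfloat16 results (num_hypotheses 7 to 18)
int8_bfloat16 = {
    "7": 41.23707873776788, "8": 40.545429908337866, "9": 27.63858542971084,
    "10": 25.42197518642908, "11": 25.202539121999276, "12": 26.089053511836536,
    "13": 26.09313176242893, "14": 25.921680235987616, "15": 25.685515962201304,
    "16": 25.380245130497627, "17": 37.63648096142843, "18": 39.836009864633866,
}

print(find_speed_cliffs(int8_bfloat16))
# [(8, 9, -0.318), (16, 17, 0.483)]
```

That is a roughly 32% drop going from 8 to 9 and a roughly 48% recovery going from 16 to 17. On the unquantized results, the same 15% check flags nothing: the largest step there is about -4.5%, between 16 and 17.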

Here are the results of my benchmark without quantization:

no quant:
```json
{
  "1": 30.930648781040798, "2": 29.929994382199716, "3": 29.902976810412603, "4": 29.78251596391624, "5": 29.688504757894574,
  "6": 29.619557073712652, "7": 29.910063893809284, "8": 29.795712678226863, "9": 29.69425672203113, "10": 29.606443544096074,
  "11": 29.472021254856028, "12": 29.373632894893838, "13": 29.23260503579981, "14": 29.146748733070016, "15": 27.925904276768797,
  "16": 27.81529887893288, "17": 26.558346354306416, "18": 26.097439228939407, "19": 25.985363181819594, "20": 25.556150043667827,
  "21": 25.195596565648774, "22": 25.25230550754091, "23": 24.967318153686055, "24": 25.087388468109868, "25": 24.891404356674084,
  "26": 24.396467669043645, "27": 24.348243698908185, "28": 24.34250739412098, "29": 23.249711253168524, "30": 23.453669645427496,
  "31": 23.425460399770433, "32": 23.310252816770962, "33": 23.619335656770097, "34": 23.268565762351656, "35": 24.04286048660994,
  "36": 23.4572124882483, "37": 23.25701410288176, "38": 23.345406567224476, "39": 23.3453323392156, "40": 23.24814228228635,
  "41": 23.17141933032643, "42": 23.18227425739279, "43": 23.464761977795924, "44": 23.402247207530408, "45": 23.173429748587893,
  "46": 23.254895517322797, "47": 23.138839982700087, "48": 23.163136816614028, "49": 23.153875672133612, "50": 22.677721002052017
}
```
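Comparing the two runs point by point makes the problem sharper: for most sizes int8_bfloat16 is faster than no quantization, but inside the anomalous range it is slower. A small sketch of that comparison over an excerpt of both result sets (the helper name `speedup` is illustrative):

```python
def speedup(quantized: dict[str, float], baseline: dict[str, float]) -> dict[int, float]:
    """Ratio of quantized to baseline speed for every size present in both runs."""
    common = sorted(set(quantized) & set(baseline), key=int)
    return {int(n): quantized[n] / baseline[n] for n in common}


# Excerpts (num_hypotheses 7 to 18) of the two runs reported in this issue
int8_bfloat16 = {
    "7": 41.23707873776788, "8": 40.545429908337866, "9": 27.63858542971084,
    "10": 25.42197518642908, "11": 25.202539121999276, "12": 26.089053511836536,
    "13": 26.09313176242893, "14": 25.921680235987616, "15": 25.685515962201304,
    "16": 25.380245130497627, "17": 37.63648096142843, "18": 39.836009864633866,
}
no_quant = {
    "7": 29.910063893809284, "8": 29.795712678226863, "9": 29.69425672203113,
    "10": 29.606443544096074, "11": 29.472021254856028, "12": 29.373632894893838,
    "13": 29.23260503579981, "14": 29.146748733070016, "15": 27.925904276768797,
    "16": 27.81529887893288, "17": 26.558346354306416, "18": 26.097439228939407,
}

ratios = speedup(int8_bfloat16, no_quant)
# Sizes where the int8_bfloat16 model is slower than the unquantized one
slower = [n for n, r in ratios.items() if r < 1.0]
print(slower)  # [9, 10, 11, 12, 13, 14, 15, 16]
```

Over the full 1 to 50 range, the quantized model is slower than the unquantized one only for num_hypotheses 9 through 16; everywhere else it is faster.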
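Here is my code:

```python
from typing import Iterable

# Method of my benchmark class; DEVICE, self.tokenizer, self.generator,
# self.max_tokens and self.batch_size are defined elsewhere in the script.
def generate_generator(self, prompt: str) -> tuple[int, Iterable[Iterable[int]]]:
    # CTranslate2 takes token strings, so convert the prompt's IDs back to tokens
    input_tokens = self.tokenizer.convert_ids_to_tokens(
        self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)[0]
    )
    # Sample self.batch_size hypotheses for the single prompt
    outputs = self.generator.generate_batch(
        [input_tokens],
        max_length=self.max_tokens,
        sampling_topk=50,
        sampling_topp=1,
        sampling_temperature=1,
        num_hypotheses=self.batch_size,
        beam_size=1,
        include_prompt_in_result=False,
    )
    sequences = outputs[0].sequences
    # Number of tokens in the longest generated hypothesis
    numbers_of_tokens = len(max(sequences, key=len))
    outputs_tokens = [self.tokenizer.convert_tokens_to_ids(o) for o in sequences]
    return numbers_of_tokens, outputs_tokens
```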
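For reproducing the sweep, a minimal timing harness might look like the sketch below. It assumes a hypothetical callable `generate_fn(n)` that runs one generation with n hypotheses (or a batch of n) and returns `(numbers_of_tokens, outputs_tokens)` like the method above. It also assumes speed means the longest hypothesis length divided by wall-clock time; the issue does not say how its numbers were computed, so that metric is an assumption too.

```python
import time
from statistics import median
from typing import Callable, Iterable

# Hypothetical signature: generate_fn(n) -> (numbers_of_tokens, outputs_tokens)
GenerateFn = Callable[[int], tuple[int, Iterable[Iterable[int]]]]


def sweep(generate_fn: GenerateFn, sizes: Iterable[int], warmup: int = 1, repeats: int = 3) -> dict[str, float]:
    """Measure tokens/second for each size, keyed by str(n) like the results in this issue."""
    results: dict[str, float] = {}
    for n in sizes:
        for _ in range(warmup):
            generate_fn(n)  # untimed warm-up run
        rates = []
        for _ in range(repeats):
            start = time.perf_counter()
            num_tokens, _ = generate_fn(n)
            elapsed = time.perf_counter() - start
            rates.append(num_tokens / elapsed)
        # The median is less sensitive to a single slow run than the mean
        results[str(n)] = median(rates)
    return results
```

Calling `json.dumps(sweep(fn, range(1, 51)))`, with `fn` a wrapper that sets num_hypotheses to n and calls the model, yields a dict in the same shape as the results posted above. Warm-up runs and the median over several repeats help separate a real cliff from one-off noise.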

Without quantization, the speed decreases smoothly as num_hypotheses increases. Is the int8* behavior normal?

jgcb00 commented 9 months ago

I'm also getting unexpected results. Here I'm testing with num_hypotheses=1 and increasing the batch size, on several different GPUs and with several different Llama 2 models, and I can reproduce the issue.

I can't find a logical explanation: it happens between a batch size of 16 and 17, and only with int8-quantized models.
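The two reports sweep different parameters: the original post asks for n hypotheses of a single prompt, while this comment sends a batch of n prompts with num_hypotheses=1. A hypothetical helper (`build_request` is not part of CTranslate2) that builds both request shapes with the sampling settings from the original code, so the two sweeps can be run side by side:

```python
from typing import Any

# Options copied from the generate_batch call in the original post
COMMON_OPTIONS: dict[str, Any] = dict(
    sampling_topk=50,
    sampling_topp=1,
    sampling_temperature=1,
    beam_size=1,
    include_prompt_in_result=False,
)


def build_request(input_tokens: list[str], n: int, mode: str, max_length: int) -> tuple[list[list[str]], dict[str, Any]]:
    """Return (batch, kwargs) for a generate_batch(batch, **kwargs) call producing n sequences."""
    options = {**COMMON_OPTIONS, "max_length": max_length}
    if mode == "num_hypotheses":
        # Original post: one prompt, n sampled hypotheses
        return [input_tokens], {**options, "num_hypotheses": n}
    if mode == "batch":
        # This comment: n copies of the prompt, one hypothesis each
        return [list(input_tokens) for _ in range(n)], {**options, "num_hypotheses": 1}
    raise ValueError(f"unknown mode: {mode!r}")
```

The returned pair would be passed as `generator.generate_batch(batch, **kwargs)`; in both modes each call generates n sequences, which makes the two sets of measurements directly comparable.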

I attached all my results: Benchmark - LLM.xlsx