huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Add support for AWQ quantized models #781

Closed 0x1997 closed 12 months ago

0x1997 commented 1 year ago

Compared to GPTQ, AWQ is more accurate and has much better inference performance.

Benchmark: https://github.com/lm-sys/FastChat/blob/main/docs/awq.md#benchmark

~Note: Multi-Query Attention is not yet supported.~

abhinavkulkarni commented 1 year ago

I have released a bunch of AWQ quantized models here: https://huggingface.co/abhinavkulkarni?sort_models=downloads#models

Instructions on how to run these with HuggingFace API are in the model cards.

Narsil commented 1 year ago

Can anyone run benchmarks against TGI + exllama kernels ?

Those are supposed to provide a similar speedup.

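We don't want to support every quantization scheme in TGI, just the best possible subset:

  • No quantization: best PPL
  • bitsandbytes: Low VRAM - no quantization steps - works on every model
  • GPTQ: Low VRAM - fastest inference (should be ~2x if I'm not mistaken) with exllama.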

abhinavkulkarni commented 1 year ago

@Narsil: Any kernel optimizations done for GPTQ should translate to AWQ since they both are based on similar zero-point quantization schemes - they simply differ on how those exact zero-point weights are found and admittedly AWQ is superior to GPTQ.

So, someone needs to simply write a "translator" from AWQ to GPTQ state dicts and everything else should work as is.
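To make "zero-point quantization" concrete, here is a minimal sketch in plain Python of 4-bit asymmetric quantization for a single weight group. It only illustrates the storage format the two methods share (integer codes plus a per-group scale and zero point). It does not show how GPTQ or AWQ pick those values, and it is not the packing layout either library uses. The function names are made up for illustration.

```python
def quantize_group(weights, bits=4):
    """Asymmetric (zero-point) quantization of one group of float weights."""
    qmax = (1 << bits) - 1                      # 15 for 4-bit codes
    # Stretch the range to include 0.0 so the zero point is always a valid code.
    w_min = min(min(weights), 0.0)
    w_max = max(max(weights), 0.0)
    scale = (w_max - w_min) / qmax or 1.0       # fall back to 1.0 for an all-zero group
    zero = max(0, min(qmax, round(-w_min / scale)))   # integer code representing 0.0
    codes = [max(0, min(qmax, round(w / scale) + zero)) for w in weights]
    return codes, scale, zero


def dequantize_group(codes, scale, zero):
    """Map integer codes back to approximate float weights."""
    return [(q - zero) * scale for q in codes]


group = [-0.52, -0.11, 0.0, 0.27, 0.98]
codes, scale, zero = quantize_group(group)
restored = dequantize_group(codes, scale, zero)
print(codes, round(scale, 4), zero)
print([round(w, 3) for w in restored])
```

A "translator" between formats would have to carry over exactly these three pieces per group (codes, scale, zero point), plus whatever bit-packing order each kernel expects.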

edwardzjl commented 1 year ago

@Narsil I agree that we should evaluate before adding support for another quantization scheme, but it's difficult to perform a fair comparison.

A fair comparison should be between TGI + gptq + exllama and TGI + awq, not between TGI + gptq + exllama and {some_other_inference_framework} + awq.

Besides, awq can be applied to LLMs other than Llama, where exllama cannot.

Narsil commented 1 year ago

> Besides, awq can be applied to LLMs other than Llama, where exllama cannot.

This statement is wrong: exllama is only named that because the kernels were created in this repository (https://github.com/turboderp/exllama); it has nothing to do with Llama. (In general, current quantization techniques basically just replace Linear with QuantLinear.)

edwardzjl commented 1 year ago

> > Besides, awq can be applied to LLMs other than Llama, where exllama cannot.
>
> This statement is wrong: exllama is only named that because the kernels were created in this repository (https://github.com/turboderp/exllama); it has nothing to do with Llama. (In general, current quantization techniques basically just replace Linear with QuantLinear.)

My apologies, I made an assumption about exllama based solely on its name. I mistakenly thought it was specifically for Llama models. :sweat_smile:

Narsil commented 1 year ago

No worries.

casper-hansen commented 1 year ago

My own benchmark of AWQ gives 134 tokens/s (7.46 ms/token) on a 4090 + i9-13900k for MPT 7B models.

As Narsil mentions, quantization methods mostly replace Linear with QuantLinear layers. AWQ does this with their optimized GEMM kernel. Additionally, AWQ Tinychat runs the following optimizations for LLaMa models specifically:

LLaMa models are 100+ tokens/s.

Why AWQ is faster than GPTQ

AWQ is faster than GPTQ. It is not faster than exllama because exllama runs a lot of kernel optimizations on top to make it faster. But the problem is that exllama is written explicitly to optimize LLaMa models, so the full performance boost will not be seen in other models.

From the AWQ paper:

Different weight channels have different importance; updating the salient channels to compensate for the non-salient ones will likely destroy the performance. Reordering prevents it by quantizing important channels first. However, it will lead to bad hardware efficiency due to irregular memory access (Figure 2), while our scaling method does not suffer from the issue.


abhinavkulkarni commented 1 year ago

Hi,

I have added rudimentary support for AWQ models at https://github.com/abhinavkulkarni/text-generation-inference/tree/abhinavkulkarni/add-awq-support

You can view the side-by-side changes here.

This requires installing AWQ library and CUDA kernels for 4-bit matrix multiplication:

git clone https://github.com/mit-han-lab/llm-awq \
&& cd llm-awq \
&& git checkout ce4a6bb1c238c014a06672cb74f6865573494d66 \
&& pip install -e . \
&& cd awq/kernels \
&& python setup.py install

After that

git clone https://github.com/abhinavkulkarni/text-generation-inference.git \
&& cd text-generation-inference \
&& git checkout abhinavkulkarni/add-awq-support \
&& make install

I did upgrade to the latest versions: pip install --upgrade transformers accelerate bitsandbytes

I was able to run TGI as follows:

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq \
--trust-remote-code \
--port 8080 \
--max-input-length 4000 --max-total-tokens 4096 \
--quantize awq

This change of course borrows from the AWQ library. For zero-point quantization, I use their WQLinear layer, which is very similar to the QuantLinear layer in TGI. For now, I have hardcoded bits to 4 and groupsize to 128, but it should be possible to read them from quantize_config.json. None of my models have a quantize_config.json yet, but I'll update the model repos with one.
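Reading the two values from a config file instead of hardcoding them could look roughly like the sketch below. It is not the code in the branch. The helper name `load_quant_params` is hypothetical, and the key names `bits` and `group_size` are an assumption borrowed from GPTQ-style configs, so they may not match what the AWQ model repos end up shipping.

```python
import json
from pathlib import Path


def load_quant_params(model_dir, default_bits=4, default_groupsize=128):
    """Return (bits, groupsize), preferring values from quantize_config.json."""
    config_path = Path(model_dir) / "quantize_config.json"
    if not config_path.is_file():
        # No config shipped with the model: keep the hardcoded defaults.
        return default_bits, default_groupsize
    with config_path.open() as f:
        config = json.load(f)
    # Assumed key names; missing keys also fall back to the defaults.
    bits = int(config.get("bits", default_bits))
    groupsize = int(config.get("group_size", default_groupsize))
    return bits, groupsize
```

Falling back to 4 and 128 keeps models without a config file loading the same way they do today.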

I don't think this change is comprehensive and I would welcome any pull requests.

The ideal scenario would be to subsume the logic of WQLinear from AWQ into QuantLinear of TGI, so that we can benefit from flash attention goodness.

Thanks!

CC: @casperbh96, @Narsil, @TheBloke

abhinavkulkarni commented 1 year ago

I benchmarked Llama 2 7B AWQ vs GPTQ with FlashAttention v1 and vLLM on RTX 3060 (12GB of VRAM). Note, I do not have exllama installed. Following are the results:

AWQ model_id: abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq

GPTQ model_id: TheBloke/Llama-2-7b-Chat-GPTQ

Both models were run with --max-input-length 4000 --max-total-tokens 4096

GPTQ benchmarks:

| Parameter          | Value                         |
|--------------------|-------------------------------|
| Model              | TheBloke/Llama-2-7b-Chat-GPTQ |
| Sequence Length    | 10                            |
| Decode Length      | 8                             |
| N Runs             | 10                            |
| Warmups            | 1                             |
| Temperature        | None                          |
| Top K              | None                          |
| Top P              | None                          |
| Typical P          | None                          |
| Repetition Penalty | None                          |
| Watermark          | false                         |
| Do Sample          | false                         |

| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 41.23 ms  | 41.16 ms  | 41.39 ms  | 41.22 ms  | 41.39 ms  | 41.39 ms  |
|                | 2          | 47.80 ms  | 47.72 ms  | 47.88 ms  | 47.81 ms  | 47.88 ms  | 47.88 ms  |
|                | 4          | 57.94 ms  | 57.83 ms  | 58.02 ms  | 57.95 ms  | 58.02 ms  | 58.02 ms  |
|                | 8          | 108.53 ms | 108.39 ms | 108.81 ms | 108.56 ms | 108.81 ms | 108.81 ms |
|                | 16         | 153.65 ms | 153.22 ms | 156.46 ms | 153.35 ms | 156.46 ms | 156.46 ms |
|                | 32         | 251.93 ms | 251.04 ms | 252.23 ms | 252.05 ms | 252.23 ms | 252.23 ms |
| Decode (token) | 1          | 40.33 ms  | 40.27 ms  | 40.45 ms  | 40.32 ms  | 40.32 ms  | 40.32 ms  |
|                | 2          | 40.83 ms  | 40.80 ms  | 40.90 ms  | 40.84 ms  | 40.82 ms  | 40.82 ms  |
|                | 4          | 41.07 ms  | 40.81 ms  | 41.15 ms  | 41.10 ms  | 40.81 ms  | 40.81 ms  |
|                | 8          | 41.28 ms  | 41.25 ms  | 41.34 ms  | 41.28 ms  | 41.29 ms  | 41.29 ms  |
|                | 16         | 48.03 ms  | 47.92 ms  | 48.22 ms  | 48.04 ms  | 47.95 ms  | 47.95 ms  |
|                | 32         | 59.45 ms  | 59.35 ms  | 59.65 ms  | 59.42 ms  | 59.65 ms  | 59.65 ms  |
| Decode (total) | 1          | 282.34 ms | 281.92 ms | 283.14 ms | 282.27 ms | 282.25 ms | 282.25 ms |
|                | 2          | 285.83 ms | 285.61 ms | 286.33 ms | 285.86 ms | 285.76 ms | 285.76 ms |
|                | 4          | 287.48 ms | 285.70 ms | 288.08 ms | 287.68 ms | 285.70 ms | 285.70 ms |
|                | 8          | 288.99 ms | 288.73 ms | 289.37 ms | 288.97 ms | 289.00 ms | 289.00 ms |
|                | 16         | 336.21 ms | 335.45 ms | 337.57 ms | 336.28 ms | 335.63 ms | 335.63 ms |
|                | 32         | 416.15 ms | 415.43 ms | 417.57 ms | 415.96 ms | 417.57 ms | 417.57 ms |

| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 24.25 tokens/secs  | 24.16 tokens/secs  | 24.30 tokens/secs  |
|         | 2          | 41.84 tokens/secs  | 41.77 tokens/secs  | 41.92 tokens/secs  |
|         | 4          | 69.04 tokens/secs  | 68.94 tokens/secs  | 69.17 tokens/secs  |
|         | 8          | 73.71 tokens/secs  | 73.52 tokens/secs  | 73.81 tokens/secs  |
|         | 16         | 104.14 tokens/secs | 102.26 tokens/secs | 104.43 tokens/secs |
|         | 32         | 127.02 tokens/secs | 126.87 tokens/secs | 127.47 tokens/secs |
| Decode  | 1          | 24.79 tokens/secs  | 24.72 tokens/secs  | 24.83 tokens/secs  |
|         | 2          | 48.98 tokens/secs  | 48.89 tokens/secs  | 49.02 tokens/secs  |
|         | 4          | 97.40 tokens/secs  | 97.20 tokens/secs  | 98.00 tokens/secs  |
|         | 8          | 193.78 tokens/secs | 193.53 tokens/secs | 193.95 tokens/secs |
|         | 16         | 333.13 tokens/secs | 331.78 tokens/secs | 333.88 tokens/secs |
|         | 32         | 538.27 tokens/secs | 536.43 tokens/secs | 539.21 tokens/secs |

AWQ benchmarks:

| Parameter          | Value                                                     |
|--------------------|-----------------------------------------------------------|
| Model              | abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq |
| Sequence Length    | 10                                                        |
| Decode Length      | 8                                                         |
| N Runs             | 10                                                        |
| Warmups            | 1                                                         |
| Temperature        | None                                                      |
| Top K              | None                                                      |
| Top P              | None                                                      |
| Typical P          | None                                                      |
| Repetition Penalty | None                                                      |
| Watermark          | false                                                     |
| Do Sample          | false                                                     |

| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 18.84 ms  | 18.70 ms  | 19.25 ms  | 18.75 ms  | 19.25 ms  | 19.25 ms  |
|                | 2          | 31.18 ms  | 31.05 ms  | 31.37 ms  | 31.19 ms  | 31.37 ms  | 31.37 ms  |
|                | 4          | 46.88 ms  | 46.63 ms  | 47.25 ms  | 46.90 ms  | 47.25 ms  | 47.25 ms  |
|                | 8          | 78.74 ms  | 78.44 ms  | 79.09 ms  | 78.81 ms  | 79.09 ms  | 79.09 ms  |
|                | 16         | 154.59 ms | 154.09 ms | 154.96 ms | 154.75 ms | 154.96 ms | 154.96 ms |
|                | 32         | 308.17 ms | 307.61 ms | 308.79 ms | 308.21 ms | 308.79 ms | 308.79 ms |
| Decode (token) | 1          | 16.21 ms  | 16.11 ms  | 16.69 ms  | 16.14 ms  | 16.69 ms  | 16.69 ms  |
|                | 2          | 16.62 ms  | 16.54 ms  | 16.80 ms  | 16.63 ms  | 16.80 ms  | 16.80 ms  |
|                | 4          | 17.28 ms  | 17.18 ms  | 17.42 ms  | 17.31 ms  | 17.42 ms  | 17.42 ms  |
|                | 8          | 18.56 ms  | 18.52 ms  | 18.61 ms  | 18.56 ms  | 18.61 ms  | 18.61 ms  |
|                | 16         | 22.51 ms  | 21.77 ms  | 28.57 ms  | 21.86 ms  | 28.57 ms  | 28.57 ms  |
|                | 32         | 37.61 ms  | 37.58 ms  | 37.67 ms  | 37.61 ms  | 37.67 ms  | 37.67 ms  |
| Decode (total) | 1          | 113.47 ms | 112.78 ms | 116.80 ms | 113.01 ms | 116.80 ms | 116.80 ms |
|                | 2          | 116.37 ms | 115.81 ms | 117.60 ms | 116.43 ms | 117.60 ms | 117.60 ms |
|                | 4          | 120.99 ms | 120.27 ms | 121.94 ms | 121.15 ms | 121.94 ms | 121.94 ms |
|                | 8          | 129.91 ms | 129.65 ms | 130.25 ms | 129.91 ms | 130.25 ms | 130.25 ms |
|                | 16         | 157.60 ms | 152.36 ms | 199.98 ms | 153.04 ms | 199.98 ms | 199.98 ms |
|                | 32         | 263.28 ms | 263.03 ms | 263.70 ms | 263.27 ms | 263.70 ms | 263.70 ms |

| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 53.09 tokens/secs  | 51.94 tokens/secs  | 53.49 tokens/secs  |
|         | 2          | 64.14 tokens/secs  | 63.75 tokens/secs  | 64.41 tokens/secs  |
|         | 4          | 85.32 tokens/secs  | 84.66 tokens/secs  | 85.78 tokens/secs  |
|         | 8          | 101.60 tokens/secs | 101.16 tokens/secs | 101.99 tokens/secs |
|         | 16         | 103.50 tokens/secs | 103.25 tokens/secs | 103.83 tokens/secs |
|         | 32         | 103.84 tokens/secs | 103.63 tokens/secs | 104.03 tokens/secs |
| Decode  | 1          | 61.70 tokens/secs  | 59.93 tokens/secs  | 62.07 tokens/secs  |
|         | 2          | 120.31 tokens/secs | 119.05 tokens/secs | 120.89 tokens/secs |
|         | 4          | 231.43 tokens/secs | 229.62 tokens/secs | 232.81 tokens/secs |
|         | 8          | 431.08 tokens/secs | 429.94 tokens/secs | 431.94 tokens/secs |
|         | 16         | 715.32 tokens/secs | 560.06 tokens/secs | 735.11 tokens/secs |
|         | 32         | 850.79 tokens/secs | 849.43 tokens/secs | 851.62 tokens/secs |

Thanks!
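For anyone reading the tables: the decode throughput is just batch size divided by per-token latency, so the speedup can be recomputed from the latency rows. A small sketch using the average "Decode (token)" latencies from the GPTQ and AWQ tables above (the helper name is only for illustration):

```python
def decode_throughput(batch_size, latency_ms):
    """Tokens per second when each decode step emits one token per sequence."""
    return batch_size * 1000.0 / latency_ms


# Average per-token decode latency in ms, copied from the tables above.
gptq_ms = {1: 40.33, 8: 41.28, 32: 59.45}
awq_ms = {1: 16.21, 8: 18.56, 32: 37.61}

for batch_size in gptq_ms:
    speedup = gptq_ms[batch_size] / awq_ms[batch_size]
    print(
        f"batch {batch_size:>2}: "
        f"GPTQ {decode_throughput(batch_size, gptq_ms[batch_size]):6.1f} tok/s, "
        f"AWQ {decode_throughput(batch_size, awq_ms[batch_size]):6.1f} tok/s, "
        f"speedup {speedup:.2f}x"
    )
```

On these numbers the AWQ decode advantage shrinks from roughly 2.5x at batch size 1 to roughly 1.6x at batch size 32, and at batch size 32 AWQ prefill is actually slower than GPTQ (308.17 ms vs 251.93 ms).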

Narsil commented 1 year ago

@abhinavkulkarni Can you try with exllama please?

It looks very promising !

abhinavkulkarni commented 1 year ago

Hey @Narsil,

I am unable to install Exllama GPTQ kernels even when I run BUILD_EXTENSIONS=True make install.

Do I need to install them separately?

Edit:

I installed exllama kernels by cd server/exllama_kernels and python setup.py install.

I do see log lines while loading the server:

2023-08-12T18:03:38.108460Z INFO text_generation_launcher: Using exllama kernels

I get worse results than before for GPTQ:

| Parameter          | Value                         |
|--------------------|-------------------------------|
| Model              | TheBloke/Llama-2-7b-Chat-GPTQ |
| Sequence Length    | 10                            |
| Decode Length      | 8                             |
| N Runs             | 10                            |
| Warmups            | 1                             |
| Temperature        | None                          |
| Top K              | None                          |
| Top P              | None                          |
| Typical P          | None                          |
| Repetition Penalty | None                          |
| Watermark          | false                         |
| Do Sample          | false                         |

| Step           | Batch Size | Average   | Lowest    | Highest   | p50       | p90       | p99       |
|----------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|
| Prefill        | 1          | 54.29 ms  | 54.10 ms  | 54.52 ms  | 54.24 ms  | 54.52 ms  | 54.52 ms  |
|                | 2          | 58.94 ms  | 58.84 ms  | 59.08 ms  | 58.92 ms  | 59.08 ms  | 59.08 ms  |
|                | 4          | 68.53 ms  | 68.19 ms  | 68.76 ms  | 68.61 ms  | 68.76 ms  | 68.76 ms  |
|                | 8          | 102.44 ms | 102.32 ms | 102.63 ms | 102.43 ms | 102.63 ms | 102.63 ms |
|                | 16         | 143.92 ms | 143.65 ms | 144.09 ms | 143.99 ms | 144.09 ms | 144.09 ms |
|                | 32         | 227.84 ms | 227.70 ms | 228.07 ms | 227.82 ms | 228.07 ms | 228.07 ms |
| Decode (token) | 1          | 31.17 ms  | 30.53 ms  | 36.57 ms  | 30.58 ms  | 30.56 ms  | 30.56 ms  |
|                | 2          | 33.92 ms  | 33.88 ms  | 33.97 ms  | 33.93 ms  | 33.91 ms  | 33.91 ms  |
|                | 4          | 41.06 ms  | 40.83 ms  | 41.31 ms  | 41.23 ms  | 41.07 ms  | 41.07 ms  |
|                | 8          | 54.19 ms  | 54.14 ms  | 54.25 ms  | 54.20 ms  | 54.18 ms  | 54.18 ms  |
|                | 16         | 59.27 ms  | 59.18 ms  | 59.45 ms  | 59.25 ms  | 59.41 ms  | 59.41 ms  |
|                | 32         | 70.56 ms  | 70.50 ms  | 70.62 ms  | 70.56 ms  | 70.62 ms  | 70.62 ms  |
| Decode (total) | 1          | 218.16 ms | 213.71 ms | 256.01 ms | 214.03 ms | 213.94 ms | 213.94 ms |
|                | 2          | 237.45 ms | 237.13 ms | 237.81 ms | 237.50 ms | 237.37 ms | 237.37 ms |
|                | 4          | 287.43 ms | 285.84 ms | 289.14 ms | 288.59 ms | 287.47 ms | 287.47 ms |
|                | 8          | 379.34 ms | 379.00 ms | 379.73 ms | 379.44 ms | 379.29 ms | 379.29 ms |
|                | 16         | 414.88 ms | 414.25 ms | 416.15 ms | 414.76 ms | 415.86 ms | 415.86 ms |
|                | 32         | 493.91 ms | 493.49 ms | 494.36 ms | 493.90 ms | 494.36 ms | 494.36 ms |

| Step    | Batch Size | Average            | Lowest             | Highest            |
|---------|------------|--------------------|--------------------|--------------------|
| Prefill | 1          | 18.42 tokens/secs  | 18.34 tokens/secs  | 18.48 tokens/secs  |
|         | 2          | 33.93 tokens/secs  | 33.85 tokens/secs  | 33.99 tokens/secs  |
|         | 4          | 58.37 tokens/secs  | 58.18 tokens/secs  | 58.65 tokens/secs  |
|         | 8          | 78.10 tokens/secs  | 77.95 tokens/secs  | 78.19 tokens/secs  |
|         | 16         | 111.18 tokens/secs | 111.04 tokens/secs | 111.38 tokens/secs |
|         | 32         | 140.45 tokens/secs | 140.31 tokens/secs | 140.54 tokens/secs |
| Decode  | 1          | 32.18 tokens/secs  | 27.34 tokens/secs  | 32.75 tokens/secs  |
|         | 2          | 58.96 tokens/secs  | 58.87 tokens/secs  | 59.04 tokens/secs  |
|         | 4          | 97.42 tokens/secs  | 96.84 tokens/secs  | 97.96 tokens/secs  |
|         | 8          | 147.62 tokens/secs | 147.47 tokens/secs | 147.76 tokens/secs |
|         | 16         | 269.96 tokens/secs | 269.13 tokens/secs | 270.37 tokens/secs |
|         | 32         | 453.53 tokens/secs | 453.11 tokens/secs | 453.91 tokens/secs |

0x1997 commented 1 year ago

Thanks to @abhinavkulkarni's code, I did some simple evaluation. The output quality of the AWQ model is slightly worse than GPTQ, but the 60% speedup at inference is quite nice.

My branch is at https://github.com/0x1997/text-generation-inference/tree/awq.

Currently multi-GPU support is broken: the model generates garbled output like this. Do you have any idea how to fix this? @abhinavkulkarni

re,:,\\\\.,,,,,,,,,,,,,,,,,,,,,,, a rad, k,,,,,,,,,,,,,,,,,,,e\\\\ the the they, I have,,,,,,,,,, the\\\\ and, I\\\\,,,,,,,,,,,.,\\\\anded the the\\\\\\\\\\\\,,,. the thesers,\\\\\\\\ and, ap\\\\\\\\\\\\\\\\ the, a\\\\ and.\\\\\\\\\\\\\\\\.,,,,,,,,,,,, a I,,,:,\\\\, avision, aon,,,,,,,,,,,, in a a the the ,,,,,,,,,,,, a, a a the the ,,,,,,- the the the the\\\\ made\\,, the k,,,,,,.,,,,,a a,,,,,,, the, a a,her it,,,,,,, a a,,,,,, the it,,, theo and., you., the,

sjzhou4 commented 1 year ago

@abhinavkulkarni https://github.com/huggingface/text-generation-inference/issues/948 This issue came up when trying your AWQ method. Can you help solve this problem? Thank you.

MichaelHauser0971 commented 1 year ago

> Thanks to @abhinavkulkarni's code, I did some simple evaluation. The output quality of the AWQ model is slightly worse than GPTQ, but the 60% speedup at inference is quite nice.
>
> My branch is at https://github.com/0x1997/text-generation-inference/tree/awq.
>
> Currently multi-GPU support is broken: the model generates garbled output like this. Do you have any idea how to fix this? @abhinavkulkarni
>
> re,:,\\\\.,,,,,,,,,,,,,,,,,,,,,,, a rad, k,,,,,,,,,,,,,,,,,,,e\\\\ the the they, I have,,,,,,,,,, the\\\\ and, I\\\\,,,,,,,,,,,.,\\\\anded the the\\\\\\\\\\\\,,,. the thesers,\\\\\\\\ and, ap\\\\\\\\\\\\\\\\ the, a\\\\ and.\\\\\\\\\\\\\\\\.,,,,,,,,,,,, a I,,,:,\\\\, avision, aon,,,,,,,,,,,, in a a the the ,,,,,,,,,,,, a, a a the the ,,,,,,- the the the the\\\\ made\\,, the k,,,,,,.,,,,,a a,,,,,,, the, a a,her it,,,,,,, a a,,,,,, the it,,, theo and., you., the,

Have you solved this problem? I also encountered the same problem.

dingjingzhen commented 1 year ago

> Can anyone run benchmarks against TGI + exllama kernels ?
>
> Those are supposed to provide a similar speedup.
>
> We don't want to support every quantization scheme in TGI, just the best possible subset:
>
>   • No quantization: best PPL
>   • bitsandbytes: Low VRAM - no quantization steps - works on every model
>   • GPTQ: Low VRAM - fastest inference (should be ~2x if I'm not mistaken) with exllama.

This works very well: I measured almost no PPL loss, and performance is faster than GPTQ. See https://github.com/huggingface/text-generation-inference/pull/1018

dingjingzhen commented 1 year ago

Try our solution, which gives the best PPL with faster performance than GPTQ: https://github.com/huggingface/text-generation-inference/pull/1018

Below is our test on a 3090.

environment: torch=2.01, cuda=11.8, nvidia driver: 525.78.01

prompt=1024, max_new_tokens=50

abhinavkulkarni commented 1 year ago

@Narsil: I have opened a PR https://github.com/huggingface/text-generation-inference/pull/1019 for adding AWQ support for FlashLlama models. Please take a look. Please refer to earlier replies from me for benchmarking results against GPTQ.

@MichaelHauser0971, @sjzhou4, @0x1997, @casper-hansen: I have not yet tested a multi-GPU setup. Let's first try to get approval for the single-GPU PR.

ryanshrott commented 1 year ago

@abhinavkulkarni I'm trying to catch up on this thread. How can I run Llama 2 AWQ or GPTQ with vLLM? Is it possible yet?

casper-hansen commented 1 year ago

@ryanshrott vLLM support for AWQ is close to being merged, check their branch out: https://github.com/vllm-project/vllm/tree/add_awq_quant_support

abhinavkulkarni commented 1 year ago

@ryanshrott: Please check the PR I have raised. It runs a Llama 2 model with FlashAttention v2 and vLLM.

ryanshrott commented 1 year ago

What's the timeline on merging to main branch?

Narsil commented 12 months ago

I ran some tests and the PR is very close to ready. If OP doesn't want to make the changes, I'll make them in a few days. We are a bit slower to respond at the moment because we have some nice things cooking; please bear with us.

TheBloke commented 12 months ago

Great to hear! I've uploaded plenty of models to be used with it :)

My READMEs linked to this PR and mentioned support was coming 'soon'. Once this is merged I can update them all to include TGI details.

ryanshrott commented 12 months ago

Will this PR have comparable speeds to regular non-quantized models?

I currently find AWQ quantization with vLLM to run very slowly.

casper-hansen commented 12 months ago

INT4 throughput will not be higher than FP16 at very high data parallelism. For that, you must use INT8 or FP16. High batch sizes mean that you are compute-bound, and INT4 is not made for this scenario.
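A rough way to see why: in a simplified roofline model, a weight matmul is memory-bound while the FLOPs per byte of weights read stay below what the GPU can sustain. The sketch below estimates the batch size where that flips. The hardware numbers are hypothetical round figures rather than measurements of any real card, and the model ignores activation traffic and dequantization cost.

```python
def min_compute_bound_batch(peak_flops, mem_bandwidth, bytes_per_weight):
    """Batch size at which a weight matmul stops being memory-bound.

    Simplified model: per decode step each weight is read once
    (bytes_per_weight bytes) and used for 2 FLOPs per sequence in the batch.
    """
    ridge = peak_flops / mem_bandwidth          # FLOPs per byte the GPU can sustain
    # intensity(batch) = 2 * batch / bytes_per_weight, solved for intensity == ridge
    return ridge * bytes_per_weight / 2


# Hypothetical GPU: 100 TFLOP/s of FP16 compute, 1 TB/s of memory bandwidth.
PEAK_FLOPS = 100e12
MEM_BANDWIDTH = 1e12

fp16_batch = min_compute_bound_batch(PEAK_FLOPS, MEM_BANDWIDTH, bytes_per_weight=2.0)
int4_batch = min_compute_bound_batch(PEAK_FLOPS, MEM_BANDWIDTH, bytes_per_weight=0.5)
print(f"FP16 becomes compute-bound near batch {fp16_batch:.0f}")
print(f"INT4 becomes compute-bound near batch {int4_batch:.0f}")
```

Under these assumptions INT4 reaches the compute-bound regime at a batch size four times smaller than FP16. Past that point, reading fewer weight bytes no longer buys any speed.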

ryanshrott commented 12 months ago

@casper-hansen I'm not quite following all your technical notes. Are you saying that 4-bit awq will be fast or not?

casper-hansen commented 12 months ago

Yes, it can be much faster. But like I just explained, there are cases where it will not be faster. It depends on your use-case.

RonanKMcGovern commented 12 months ago

Will this approach default to using GEMM? Or is there a parameter to configure GEMV or fp16?

I may be missing something in the code updates, I just didn't find any reference to GEMM.

Also, great work on this.

casper-hansen commented 12 months ago

GEMV is only faster at batch size 1 with a small context (20% faster). For deployment purposes with many concurrent requests, GEMM will overall be much faster as it scales better. @RonanKMcGovern

RonanKMcGovern commented 12 months ago

I pulled this docker image and it's recognising awq.

1.0.3 from the readme won't work for awq though. Might be worth putting a note there in the readme if not ready for a release?

Also, after trying the following flags on the latest image:

--model-id TheBloke/Llama-2-70B-chat-AWQ --trust-remote-code --port 8080 --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 --quantize awq

I'm hitting:

2023-09-25T07:54:05.251889450-07:00     return callback(**use_params)  # type: ignore
2023-09-25T07:54:05.251891440-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 82, in serve
2023-09-25T07:54:05.251893660-07:00     server.serve(
2023-09-25T07:54:05.251895730-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 195, in serve
2023-09-25T07:54:05.251898130-07:00     asyncio.run(
2023-09-25T07:54:05.251900180-07:00   File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
2023-09-25T07:54:05.251902330-07:00     return loop.run_until_complete(main)
2023-09-25T07:54:05.251905440-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
2023-09-25T07:54:05.251907560-07:00     self.run_forever()
2023-09-25T07:54:05.251909610-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
2023-09-25T07:54:05.251911670-07:00     self._run_once()
2023-09-25T07:54:05.251913740-07:00   File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
2023-09-25T07:54:05.251915790-07:00     handle._run()
2023-09-25T07:54:05.251917850-07:00   File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
2023-09-25T07:54:05.251920010-07:00     self._context.run(self._callback, *self._args)
2023-09-25T07:54:05.251924120-07:00 > File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 147, in serve_inner
2023-09-25T07:54:05.251926260-07:00     model = get_model(
2023-09-25T07:54:05.251928480-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 187, in get_model
2023-09-25T07:54:05.251930600-07:00     return FlashLlama(
2023-09-25T07:54:05.251932640-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_llama.py", line 68, in __init__
2023-09-25T07:54:05.251934730-07:00     model = FlashLlamaForCausalLM(config, weights)
2023-09-25T07:54:05.251936780-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 474, in __init__
2023-09-25T07:54:05.251939230-07:00     self.model = FlashLlamaModel(config, weights)
2023-09-25T07:54:05.251941310-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 412, in __init__
2023-09-25T07:54:05.251943390-07:00     [
2023-09-25T07:54:05.251945510-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 413, in <listcomp>
2023-09-25T07:54:05.251947630-07:00     FlashLlamaLayer(
2023-09-25T07:54:05.251949740-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 349, in __init__
2023-09-25T07:54:05.251951740-07:00     self.self_attn = FlashLlamaAttention(
2023-09-25T07:54:05.251953700-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 229, in __init__
2023-09-25T07:54:05.251955840-07:00     self.query_key_value = load_attention(config, prefix, weights)
2023-09-25T07:54:05.251957850-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 154, in load_attention
2023-09-25T07:54:05.251959930-07:00     return _load_gqa(config, prefix, weights)
2023-09-25T07:54:05.251962000-07:00   File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 183, in _load_gqa
2023-09-25T07:54:05.251966830-07:00     weight = weight.to(dtype=weights.dtype).to(device=weights.device)
2023-09-25T07:54:05.251968970-07:00 AttributeError: 'tuple' object has no attribute 'to'

abhinavkulkarni commented 12 months ago

@RonanKMcGovern: I built the latest commit on main and was able to run the command you posted, but with model TheBloke/Llama-2-7B-chat-AWQ instead of TheBloke/Llama-2-70B-chat-AWQ.

text-generation-launcher --model-id TheBloke/Llama-2-7B-chat-AWQ --trust-remote-code --port 8080 --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 --quantize awq

I am able to send inputs to the server using cURL and obtain a legible response.
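For reference, the same kind of request can be sent from Python instead of cURL. This is a minimal sketch, assuming the server from the command above is listening on localhost:8080 and exposes TGI's /generate route. The helper names are mine, not part of TGI.

```python
import json
import urllib.request


def build_generate_request(prompt, max_new_tokens=64, base_url="http://localhost:8080"):
    """Build a POST request for the /generate route (nothing is sent yet)."""
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    return urllib.request.Request(
        url=f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt, **kwargs):
    """Send the request and return the generated text; needs a running server."""
    with urllib.request.urlopen(build_generate_request(prompt, **kwargs)) as resp:
        return json.load(resp)["generated_text"]
```

With the server up, `generate("What is AWQ?", max_new_tokens=32)` should return the completion as a string.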

RonanKMcGovern commented 12 months ago

Thanks @abhinavkulkarni !

RonanKMcGovern commented 12 months ago

Confirming 70B is running now as expected. Thanks all.

jjmlovesgit commented 12 months ago

I am also running on 3090 "TheBloke/Llama-2-7B-chat-AWQ" with success on Langchain POCs using commands from above:

Demo:

mccorji@llama:~/tgi/dev/text-generation-inference$ ./start_7b_awq_simple.sh
Starting the Docker container with local files (No Internet): Llama-2-7b-Chat-AWQ and volume: /home/mccorji/tgi/dev/text-generation-inference/data ...
cb808d64f57643e7a38e51d04e6ec48d6f9d8f27f8c814039944127fdb8fef20
Container started successfully!
Running the text-generation-launcher command from /data directory inside the container...Local Files only will be used
2023-09-27T04:20:17.283981Z INFO text_generation_launcher: Args { model_id: "TheBloke/Llama-2-7B-chat-AWQ", revision: None, validation_workers: 2, sharded: None, num_shard: None, quantize: Some(Awq), dtype: None, trust_remote_code: true, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 2048, max_total_tokens: 4096, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: "cb808d64f576", port: 8080, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }
2023-09-27T04:20:17.284020Z WARN text_generation_launcher: trust_remote_code is set. Trusting that model TheBloke/Llama-2-7B-chat-AWQ do not contain malicious code.
2023-09-27T04:20:17.284088Z INFO download: text_generation_launcher: Starting download process.
2023-09-27T04:20:19.235606Z INFO text_generation_launcher: Files are already present on the host. Skipping download.
2023-09-27T04:20:19.487792Z INFO download: text_generation_launcher: Successfully downloaded weights.
2023-09-27T04:20:19.488056Z INFO shard-manager: text_generation_launcher: Starting shard rank=0
2023-09-27T04:20:28.267572Z INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
2023-09-27T04:20:28.299314Z INFO shard-manager: text_generation_launcher: Shard ready in 8.810838819s rank=0
2023-09-27T04:20:28.399514Z INFO text_generation_launcher: Starting Webserver
2023-09-27T04:20:28.719311Z WARN text_generation_router: router/src/main.rs:349: --revision is not set
2023-09-27T04:20:28.719339Z WARN text_generation_router: router/src/main.rs:350: We strongly advise to set it to a known supported commit.
2023-09-27T04:20:28.993330Z INFO text_generation_router: router/src/main.rs:371: Serving revision 47c8d2736daf1e3b57d9689129c3ddfc596299e1 of model TheBloke/Llama-2-7b-Chat-AWQ
2023-09-27T04:20:28.998950Z INFO text_generation_router: router/src/main.rs:213: Warming up model
2023-09-27T04:20:31.358658Z INFO text_generation_router: router/src/main.rs:246: Setting max batch total tokens to 31984
2023-09-27T04:20:31.358683Z INFO text_generation_router: router/src/main.rs:247: Connected
2023-09-27T04:20:31.358687Z WARN text_generation_router: router/src/main.rs:252: Invalid hostname, defaulting to 0.0.0.0
2023-09-27T04:20:58.063416Z INFO HTTP request{otel.name=POST / http.client_ip= http.flavor=1.1 http.host=localhost:8080 http.method=POST http.route=/ http.scheme=HTTP http.target=/ http.user_agent=python-requests/2.29.0 otel.kind=server trace_id=97f2b23363738b637d455608234b8cc9 http.status_code=200 otel.status_code="OK"}:compat_generate{default_return_full_text=true}:generate_stream{parameters=GenerateParameters { best_of: None, temperature: Some(0.01), repetition_penalty: Some(1.03), top_k: Some(10), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 512, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None } total_time="1.92377586s" validation_time="1.458384ms" queue_time="75.402µs" inference_time="1.922242217s" time_per_token="14.030965ms" seed="Some(3854729438860470207)"}: text_generation_router::server: router/src/server.rs:457: Success
2023-09-27T04:21:24.539278Z INFO HTTP request{otel.name=POST / http.client_ip= http.flavor=1.1 http.host=localhost:8080 http.method=POST http.route=/ http.scheme=HTTP http.target=/ http.user_agent=python-requests/2.29.0 otel.kind=server trace_id=f9923de7bbb6fe7f662d34a2d3566b00 http.status_code=200 otel.status_code="OK"}:compat_generate{default_return_full_text=true}:generate_stream{parameters=GenerateParameters { best_of: None, temperature: Some(0.01), repetition_penalty: Some(1.03), top_k: Some(10), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 512, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None } total_time="361.679818ms" validation_time="1.190807ms" queue_time="58.746µs" inference_time="360.430359ms" time_per_token="20.023908ms" seed="Some(6292320286675141931)"}: text_generation_router::server: router/src/server.rs:457: Success

naticio commented 11 months ago

So, can we run any AWQ model using TGI, or just some of them (as of now)?

I'm trying to launch text-generation-launcher --model-id TheBloke/Wizard-Vicuna-30B-Uncensored-AWQ

but it doesn't work:

RuntimeError: weight model.layers.0.self_attn.q_proj.weight does not exist rank=0
2023-10-14T00:15:44.698002Z ERROR text_generation_launcher: Shard 0 failed to start
2023-10-14T00:15:44.698035Z INFO text_generation_launcher: Shutting down shards

abhinavkulkarni commented 11 months ago

@naticio: Currently only FlashLlama models are supported for AWQ quantization. So, the underlying model has to be a Llama 1 or 2 architecture.

However, it should be easy to add support for other types of AWQ quantized models such as MPT, Falcon, etc.