casper-hansen / AutoAWQ

AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference. Documentation:
https://casper-hansen.github.io/AutoAWQ/
MIT License

Long delay after first token generation in 0.1.8 that's not in 0.1.7. Also 0.1.8 is much slower than 0.1.7 #314

Closed: pseudotensor closed this issue 9 months ago

pseudotensor commented 9 months ago

Reported by user of h2oGPT: https://github.com/h2oai/h2ogpt/issues/1309

I used an edited version of the text streamer, changed only to print every token instead of waiting for a space. You'll see it print "Wh", then wait nearly 0.5 seconds, then continue.

This only occurs with 0.1.8:

https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp310-cp310-linux_x86_64.whl

and not in 0.1.7:

https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp310-cp310-linux_x86_64.whl

Also, the script for the 2+ generations runs in 3.6 seconds with 0.1.8 and in 2.5 seconds in 0.1.7.

So there are 2 problems, but perhaps the delay accounts for all of the difference.
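The two symptoms (a pause right after the first token, and a longer total run) can be measured separately instead of eyeballed. Below is a minimal sketch of a streamer that only records timestamps. It assumes, as with the transformers streamers used in the script below, that `generate()` calls `put()` once with the prompt ids, then once per generation step, and calls `end()` when done. `TimingStreamer` and `summary()` are hypothetical names, not part of AutoAWQ or transformers.

```python
import time


class TimingStreamer:
    """Hypothetical streamer that records when each generated chunk arrives."""

    def __init__(self, clock=time.perf_counter):
        self.clock = clock
        self.start = clock()      # reference point: streamer creation time
        self.arrivals = []        # timestamps of generated-token chunks
        self.end_time = None
        self._saw_prompt = False

    def put(self, value):
        # Assumption: the first put() carries the prompt ids, not new tokens.
        if not self._saw_prompt:
            self._saw_prompt = True
            return
        self.arrivals.append(self.clock())

    def end(self):
        self.end_time = self.clock()

    def summary(self):
        """Seconds to the first token, the first-to-second gap, and the largest later gap."""
        if not self.arrivals:
            return {}
        gaps = [b - a for a, b in zip(self.arrivals, self.arrivals[1:])]
        return {
            "time_to_first_token": self.arrivals[0] - self.start,
            "first_gap": gaps[0] if gaps else None,
            "max_later_gap": max(gaps[1:]) if len(gaps) > 1 else None,
        }
```

Usage: create a fresh `TimingStreamer()` right before each call, pass it as `streamer=` to `model.generate(...)`, then `print(streamer.summary())`. A `first_gap` much larger than `max_later_gap` would correspond to the pause after "Wh" described above.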

Script, run on 4×A10G or 4×A6000:

prompt = """Pay attention and remember the information below, which will help to answer the question or imperative after the context ends.
\"\"\"
103.9
60.6
74.6
126.0
109.6
64.3
Whisper small
16.4
87.3
103.6
14.7
92.9
7.3
29.8
11.4
131.7
33.3
49.3
140.0
105.3
42.2
Whisper medium
9.9
79.5
102.0
8.0
119.4
5.0
20.0
7.2
147.0
17.3
31.9
143.9
104.0
44.9
Whisper large
8.3
75.9
102.8
7.2
92.7
4.8
15.4
6.4
177.9
15.7
27.8
130.0
103.5
29.2
Model
Swedish
Swahili
Tamil
Telugu
Tajik
Thai
Turkish
Ukrainian
Urdu
Uzbek
Vietnamese
Yoruba
Whisper tiny
52.7
100.9
99.9
105.1
101.7
58.8
42.5
51.2
65.2
105.2
60.0
106.4
Whisper base
37.4
92.5
58.7
105.2
109.3
38.2
27.5

Table 18. Whisper model learning rates.

70.3
104.4
100.4
19.6
100.1
Whisper medium
19.3
24.3
60.1
10.2
49.9
5.2
7.1
67.9
117.3
48.8
98.9
77.7
16.4
90.0
Whisper large
16.7
21.0
53.7
8.5
43.0
4.2
6.4
87.0
100.5
43.8
96.0
69.8
15.2
86.5
Model
Lingala
Lao
Lithuanian
Latvian
Maori
Macedonian
Malayalam
Mongolian
Marathi
Malay
Maltese
Myanmar
Norwegian
Nepali
Whisper tiny
105.4
115.1
98.5
91.6
94.5
73.3
101.5
113.7
100.3
51.2
100.8
124.8
62.0
101.8
Whisper base
96.7
105.1
87.3
79.8
77.5
59.9
107.4
125.7
100.3
35.1
97.6
122.6
44.0
102.4
Whisper small

We?d like to thank the millions of people who were involved
in creating the data used by Whisper. We?d also like to
thank Nick Ryder, Will Zhuk, and Andrew Carr for the
conversation on the waterfall hike that inspired this project.
We are also grateful to the Acceleration and Supercomputing
teams at OpenAI for their critical work on software and
hardware infrastructure this project used. We?d also like to
thank Pamela Mishkin for advising the project from a policy

the Whisper model under additive pub noise of SNR below
10 dB. This showcases Whisper?s robustness to noise, es-
pecially under more natural distribution shifts like the pub
noise.
3.8. Long-form Transcription
Whisper models are trained on 30-second audio chunks and
cannot consume longer audio inputs at once. This is not a
problem with most academic datasets comprised of short
utterances but presents challenges in real-world applications
which often require transcribing minutes- or hours-long au-

Whisper large
34.3
21.7
77.8
22.8
15.9
17.6
20.6
22.7
31.6
26.0
14.8
0.5
19.6
20.7
Model
Croatian
Hungarian
Armenian
Indonesian
Icelandic
Italian
Japanese
Javanese
Georgian
Kazakh
Khmer
Kannada
Korean
Luxembourgish
Whisper tiny
0.6
0.1
0.1
0.3
0.4
5.3
0.2
0.2
0.1
0.1
0.1
0.8
0.5
0.8
Whisper base
3.7
0.2
0.1
2.6
0.4
11.3
1.5
0.2
0.2
0.2
0.1
0.9
3.7
1.7
Whisper small
14.6
4.8
0.7
16.4
1.8
17.8
9.6
1.4
0.2
0.8
0.5
2.3
12.2
5.7
Whisper medium
23.0
15.5
10.4
24.1
6.8
21.6
14.9
5.0
1.3
4.3
3.3
8.5
19.2
13.6

Robust Speech Recognition via Large-Scale Weak Supervision
5
Model
Layers
Width
Heads
Parameters
Tiny
4
384
6
39M
Base
6
512
8
74M
Small
12
768
12
244M
Medium
24
1024
16
769M
Large
32
1280
20
1550M
Table 1. Architecture details of the Whisper model family.
3. Experiments
3.1. Zero-shot Evaluation
The goal of Whisper is to develop a single robust speech
processing system that works reliably without the need for
dataset specific fine-tuning to achieve high-quality results

Whisper tiny.en
5.5
12.8
13.8
15.1
17.0
22.0
30.3
Whisper tiny
6.8
15.5
16.7
17.0
18.7
24.4
33.1
Whisper base.en
4.6
9.4
11.2
13.2
12.5
16.6
25.2
Whisper base
4.8
12.2
12.2
14.5
13.5
18.4
26.9
Whisper small.en
4.6
6.0
9.4
12.0
10.8
14.0
21.9
Whisper small
4.2
6.9
10.1
12.1
11.1
14.3
22.3
Whisper medium.en
3.6
5.2
8.9
11.9
10.2
13.3
20.6
Whisper medium
3.8
5.4
8.6
11.4
10.3
13.2
20.3
Whisper large
3.8
5.3
8.8
11.0
10.3
13.4
20.4
wav2vec2-base-100h
17.6
27.7
39.3
35.2
45.7
57.1
55.4
wav2vec2-base-960h
12.8

Whisper large
25.4
18.3
13.2
27.2
6.6
23.5
17.0
5.1
2.7
6.3
5.2
9.9
20.0
15.4
Model
Lingala
Lao
Lithuanian
Latvian
Maori
Macedonian
Malayalam
Mongolian
Marathi
Malay
Maltese
Myanmar
Norwegian
Nepali
Whisper tiny
0.1
0.2
0.1
0.2
0.3
1.0
0.8
0.1
0.2
0.3
0.6
0.1
1.4
0.1
Whisper base
0.1
0.3
0.3
0.4
1.0
5.4
1.4
0.1
0.9
2.1
1.4
0.1
8.4
0.3
Whisper small
0.5
2.0
1.9
1.5
3.9
15.3
5.7
0.1
3.8
14.1
4.9
0.0
22.0
2.9
Whisper medium
0.9
8.1
9.6
10.0
8.5
23.5
13.8
0.5
10.9
23.2
11.2
0.2
29.1
12.7
Whisper large
1.2
9.3

the following URL: https://github.com/openai/
whisper.
2. Approach
2.1. Data Processing
Following the trend of recent work leveraging web-scale
text from the internet for training machine learning systems,
we take a minimalist approach to data pre-processing. In
contrast to a lot of work on speech recognition, we train
Whisper models to predict the raw text of transcripts without
any significant standardization, relying on the expressive-
ness of sequence-to-sequence models to learn to map be-
\"\"\"
According to only the information in the document sources provided within the context above: 
What is Whisper?
"""

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from transformers.generation.streamers import BaseStreamer

class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints the newly decoded text to stdout immediately.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token(s) to the cache and decode the entire cache.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print everything decoded so far. This is the edit: the stock
        # streamer prints only up to the last space (commented-out line below) to
        # avoid printing incomplete words that may change with the next token.
        else:
            # printable_text = text[self.print_len : text.rfind(" ") + 1]
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
            # Debug output: one line per streamed chunk.
            print("Token: %s" % printable_text, flush=True)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

quant_path = "TheBloke/openchat_3.5-16k-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt_template = """GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:"""

tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors='pt'
).input_ids.cuda()

# Generate output
import time

# First generation
t0 = time.time()
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=256,
)
print("duration: %s" % (time.time() - t0), flush=True)

# Second generation (the "2+ generations" timing in the report)
t0 = time.time()
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=256,
)
print("duration: %s" % (time.time() - t0), flush=True)

time.sleep(100)
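The edit in `put()` swaps the stock heuristic (print only up to the last space) for printing everything decoded so far. The sketch below replays both rules over hypothetical decoded prefixes, without a tokenizer, to show why the edited streamer surfaces "Wh" on its own: the stock rule emits nothing until a space appears, so with it the same pause would show up as a delay before the first whole word rather than as a visible break after "Wh". `printable_chunks` is a hypothetical helper, and it omits the newline-flush and CJK branches of the real `put()`.

```python
def printable_chunks(decoded_prefixes, per_token):
    """Return what a TextStreamer-style put() would emit at each step.

    decoded_prefixes: the full decoded text after each new token
    (a stand-in for tokenizer.decode(token_cache)).
    per_token=False mimics the stock rule (print up to the last space);
    per_token=True mimics the edited streamer in the script above.
    """
    print_len = 0
    chunks = []
    for text in decoded_prefixes:
        if per_token:
            chunk = text[print_len:]
        else:
            # rfind returns -1 when there is no space, giving an empty slice.
            chunk = text[print_len : text.rfind(" ") + 1]
        print_len += len(chunk)
        chunks.append(chunk)
    return chunks
```

For the prefixes `["Wh", "Whisper", "Whisper is", "Whisper is a"]`, the per-token rule emits `"Wh"` at the first step, while the stock rule emits nothing until `"Whisper "` at the third step.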

casper-hansen commented 9 months ago

I understand why users might report this as an issue. In 0.1.7 we hard-coded rope_theta=10000, but in 0.1.8 we read it from the model config. The linked model happens to have rope_theta=1000000, which might be a bit slower.
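For context, `rope_theta` is the base of the rotary position embedding (RoPE) frequencies. The sketch below uses the standard RoPE formulation; `rope_inv_freq`, `read_rope_theta` and `ExampleConfig` are hypothetical helpers, not AutoAWQ code, and `head_dim=128` is only an example value. It illustrates the 0.1.7 vs 0.1.8 difference described above (a fixed 10000 versus the value read from the config) and shows that the number of frequencies depends only on `head_dim`, not on `rope_theta`.

```python
def rope_inv_freq(head_dim, rope_theta):
    """Standard RoPE inverse frequencies: rope_theta ** (-2i / head_dim), i = 0 .. head_dim/2 - 1."""
    return [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def read_rope_theta(config, default=10000.0):
    """0.1.8-style: take rope_theta from the config if present, else use the default.

    A 0.1.7-style loader would skip the lookup and always use 10000.
    """
    return getattr(config, "rope_theta", default)


class ExampleConfig:
    # Hypothetical config carrying the value mentioned for the linked model.
    rope_theta = 1000000.0


theta_old = 10000.0
theta_new = read_rope_theta(ExampleConfig())
freqs_old = rope_inv_freq(128, theta_old)
freqs_new = rope_inv_freq(128, theta_new)
```

Both lists have `head_dim // 2` entries and start at 1.0; the larger base only makes the later frequencies smaller (slower rotation for distant dimensions).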

pseudotensor commented 9 months ago

So can you explain the ~1 s delay between the first token and the 2+ tokens? It's not just that generation is slower; that's one issue. The other is the large lag between the first token and all subsequent tokens.

casper-hansen commented 9 months ago

I am not able to explain it without checking out the commits between the two releases, which I won't have time to do right now. A potential cause is the change in how we handle position ids. Tough to say without thorough testing.


pseudotensor commented 9 months ago

Understood. I haven't seen such a massive lag before in other tools, and 0.1.7 doesn't do it. I'll recommend users use 0.1.7 if they are concerned. Thanks!

casper-hansen commented 9 months ago

Closing this issue as not planned for now; I have not been able to easily reproduce the reported delay.