microsoft / onnxruntime-genai

Generative AI extensions for onnxruntime

Phi-3 128k `cos_cache dimension 0 should be of max_sequence_length.` when setting larger context window #383

Closed · Ben-Epstein closed this issue 6 months ago

Ben-Epstein commented 7 months ago

I've been testing out Phi-3 128k, but I'm running into issues when using larger context windows (>4000 tokens).

With cuda-fp16, anything larger than 4096 gives me a memory allocation error, which is surprising because I can run 8k models easily.

And when I use cuda-int4-rtn-block-32, I get this error:

OrtException: Non-zero status code returned while running GroupQueryAttention node. Name:'/model/layers.0/attn/GroupQueryAttention' Status Message: cos_cache dimension 0 should be of max_sequence_length.

Here's some code I used

import onnxruntime_genai as og

chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
# model_name = "Phi-3-mini-128k-instruct-onnx/cuda/cuda-fp16"  # GPU VRAM only available for ~4000 tokens
model_name = "Phi-3-mini-128k-instruct-onnx/cuda/cuda-int4-rtn-block-32"
model = og.Model(model_name)
tokenizer = og.Tokenizer(model)

and then

def predict(text: str, max_length: int = 4096) -> str:
    # text = "How are you doing today?"
    tokenizer_stream = tokenizer.create_stream()
    search_options = {"do_sample": True, "temperature": 0.0, "top_p": 1, "max_length": max_length}
    prompt = f'{chat_template.format(input=text)}'

    input_tokens = tokenizer.encode(prompt)

    params = og.GeneratorParams(model)
    params.try_use_cuda_graph_with_max_batch_size(1)
    params.set_search_options(**search_options)
    params.input_ids = input_tokens

    generator = og.Generator(model, params)
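    # Stream tokens one at a time until generation finishes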
    output = []
    while not generator.is_done():
        generator.compute_logits()
        generator.generate_next_token()
        new_token = generator.get_next_tokens()[0]
        decoded = tokenizer_stream.decode(new_token)
        print(decoded, end='', flush=True)
        output.append(decoded)
    return "".join(output)

res = predict("How are you doing today?", 8192)

Can you help me understand what that error means? Here is my NVIDIA config; I have a T4:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P0    28W /  70W |  12891MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

Thanks!

kunal-vaishnavi commented 7 months ago

> With cuda-fp16, anything larger than 4096 gives me a memory allocation error, which is surprising because I can run 8k models easily.

There are a couple of options to avoid the memory allocation error.

> And when I use cuda-int4-rtn-block-32, I get this error
>
> OrtException: Non-zero status code returned while running GroupQueryAttention node. Name:'/model/layers.0/attn/GroupQueryAttention' Status Message: cos_cache dimension 0 should be of max_sequence_length.

This error can happen with the Phi-3 mini 128K model in the following scenario: the prompt length is less than 4K, but the prompt length + generation length is greater than 4K.

There are a couple of options to avoid this error.
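One workaround, which Ben-Epstein reports adopting later in this thread, is to disable KV-cache buffer sharing so the caches are not pre-allocated to max_length. A minimal sketch, assuming past_present_share_buffer is accepted as a boolean search option by set_search_options; this is an illustration, not necessarily the exact options the maintainer had in mind:

import onnxruntime_genai as og

model = og.Model("Phi-3-mini-128k-instruct-onnx/cuda/cuda-int4-rtn-block-32")
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
# Turn off KV-cache buffer sharing so the present KV cache is not pre-allocated
# to max_length, which sidesteps the cos_cache dimension check described below.
params.set_search_options(max_length=8192, past_present_share_buffer=False)
params.input_ids = tokenizer.encode('<|user|>\nHow are you doing today? <|end|>\n<|assistant|>')

generator = og.Generator(model, params)
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()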

Ben-Epstein commented 7 months ago

> Prompt length is less than 4K / Prompt length + generation length is greater than 4K

Why would these cause an issue? Just trying to get a better understanding

kunal-vaishnavi commented 7 months ago

It is caused by a check in ONNX Runtime that previously required the first dimension of the cos/sin caches in the rotary embeddings to be the same size as the third dimension of the KV caches.

When buffer sharing is enabled, the third dimension of the KV caches is set to 131072 by default since that is the default value for max_length. Depending on the prompt length passed to the model, the first dimension of the cos/sin caches can either be 4096 or 131072. When the prompt length is less than 4096, the cos/sin cache that is selected is the one where the first dimension is of size 4096. The check will see that cos_dims[0] = 4096, present_sequence_length = 131072, and cos_dims[0] < present_sequence_length is true so the error will get raised.

Because present_sequence_length only has to exceed 4096 for this error to occur, any max_length greater than 4096 will trigger it whenever the prompt length is under 4096. Since the error isn't really a limitation of max_length but rather of how the KV caches are pre-allocated, I described it as "prompt length + generation length is greater than 4K" instead.
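To make the shape arithmetic concrete, here is a small numeric sketch of the check described above (illustrative only, not the actual ONNX Runtime source; the variable names are hypothetical):

# Illustrative recreation of the old GroupQueryAttention shape check (not ORT source code)
prompt_length = 2000                   # under 4096, so the 4096-entry cos/sin cache is selected
max_length = 8192                      # with buffer sharing on, KV caches are pre-allocated to this size

cos_dims_0 = 4096                      # first dimension of the selected cos cache
present_sequence_length = max_length   # third dimension of the pre-allocated KV cache

if cos_dims_0 < present_sequence_length:
    raise RuntimeError("cos_cache dimension 0 should be of max_sequence_length.")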

The check has been fixed in this PR. If you build ONNX Runtime from source and then build ONNX Runtime GenAI from source, the above errors should go away and you should be able to use past_present_share_buffer = true.
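Once both source builds are in place, re-enabling buffer sharing for a long generation would look roughly like this (a sketch, again assuming past_present_share_buffer is settable through set_search_options):

params = og.GeneratorParams(model)
# With the patched ONNX Runtime build, buffer sharing can stay on even though
# prompt length + generation length exceeds 4096
params.set_search_options(max_length=8192, past_present_share_buffer=True)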

Ben-Epstein commented 6 months ago

@kunal-vaishnavi thanks so much, this worked perfectly. For now I'll run with past_present_share_buffer=False until the next release :)