turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License
3.28k stars 243 forks

Generating a batch of prompts with different sizes, the shorter prompts tend to suffer #200

Closed ziadloo closed 1 month ago

ziadloo commented 7 months ago

When I group more than one prompt into a batch, if the prompts are of different sizes, the generated output for the shorter prompts suffers from repeating tokens. This does not happen when all the prompts are the same.

Here's an example (the prompts are taken from HumanEval):

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

batch_size = 4
model_directory = './models/TheBloke_Phind-CodeLlama-34B-v2-GPTQ/'

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
config.scale_pos_emb = 1
config.scale_alpha_value = 1
config.max_seq_len = 2048

model = ExLlamaV2(config)

cache = ExLlamaV2Cache(model, batch_size = batch_size, lazy = True)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 1.0
settings.top_k = 500
settings.top_p = 0.3
settings.typical = 1
max_new_tokens = 512

prompts = [
    "from typing import List\n\n\ndef string_xor(a: str, b: str) -> str:\n    \"\"\" Input are two strings a and b consisting only of 1s and 0s.\n    Perform binary XOR on these inputs and return result also as a string.\n    >>> string_xor('010', '110')\n    '100'\n    \"\"\"\n",
    "\n\ndef truncate_number(number: float) -> float:\n    \"\"\" Given a positive floating point number, it can be decomposed into\n    and integer part (largest integer smaller than given number) and decimals\n    (leftover part always smaller than 1).\n\n    Return the decimal part of the number.\n    >>> truncate_number(3.5)\n    0.5\n    \"\"\"\n",
    "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n",
    "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n    \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n    separate those group into separate strings and return the list of those.\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\n    Ignore any spaces in the input string.\n    >>> separate_paren_groups('( ) (( )) (( )( ))')\n    ['()', '(())', '(()())']\n    \"\"\"\n",
]
cache.current_seq_len = 0
output = generator.generate_simple(prompts, settings, max_new_tokens)
print("Batch of different prompt sizes:")
for i in range(len(output)):
    print(f"{i}. {output[i]}")

prompts0 = [
    prompts[0]
    for i in range(batch_size)
]
cache.current_seq_len = 0
output = generator.generate_simple(prompts0, settings, max_new_tokens)
print("Batch of same prompt (0) sizes:")
for i in range(len(output)):
    print(f"{i}. {output[i]}")

prompts1 = [
    prompts[1]
    for i in range(batch_size)
]
cache.current_seq_len = 0
output = generator.generate_simple(prompts1, settings, max_new_tokens)
print("Batch of same prompt (1) sizes:")
for i in range(len(output)):
    print(f"{i}. {output[i]}")

prompts2 = [
    prompts[2]
    for i in range(batch_size)
]
cache.current_seq_len = 0
output = generator.generate_simple(prompts2, settings, max_new_tokens)
print("Batch of same prompt (2) sizes:")
for i in range(len(output)):
    print(f"{i}. {output[i]}")

prompts3 = [
    prompts[3]
    for i in range(batch_size)
]
cache.current_seq_len = 0
output = generator.generate_simple(prompts3, settings, max_new_tokens)
print("Batch of same prompt (3) sizes:")
for i in range(len(output)):
    print(f"{i}. {output[i]}")

In this example, we first generate 4 outputs for a batch of prompts of different sizes. Then each of the same prompts is copied into a batch of its own. This behaviour is reproducible, but I could not work out why or how it happens, other than that it has something to do with the lengths of the input prompts.

Batch of different prompt sizes:

0. from typing import List

def string_xor(a: str, b: str) -> str:
    """ Input are two strings a and b consisting only of 1s and 0s.
    Perform binary XOR on these inputs and return result also as a string.
    >>> string_xor('010', '110')
    '100'
    """
=20
1

(a)

        . _1/

-) (
)
.
)
}
1. 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
    (leftover part always smaller than 1).

    Return the decimal part of the number.
    >>> truncate_number(3.5)
    0.5
    """

    return the
    }
   )
}
   = "0";
   >;
   = "C";
   }
)
   = "0;">;
   = "C;"
   = "C>"
   = "C"
   = "C"
   = ");
   = "C"
   [the line `   = "C"` repeats until the 512-token limit is reached]
2. from typing import List

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
3. from typing import List

def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """
    paren_string = paren_string.replace(" ", "")
    result, count, start = [], 0, 0
    for i, char in enumerate(paren_string):
        if char == "(":
            if count == 0:
                start = i
            count += 1
        elif char == ")":
            count -= 1
            if count == 0:
                result.append(paren_string[start : i + 1])
    return result

Batch of same prompt (0) sizes:

0. from typing import List

def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """
    paren_string = paren_string.replace(" ", "")
    result, count, start = [], 0, 0
    for i, char in enumerate(paren_string):
        if char == "(":
            if count == 0:
                start = i
            count += 1
        elif char == ")":
            count -= 1
            if count == 0:
                result.append(paren_string[start : i + 1])
    return result
1. from typing import List

def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """
    paren_string = paren_string.replace(" ", "")
    result, count, start = [], 0, 0
    for i, char in enumerate(paren_string):
        if char == "(":
            if count == 0:
                start = i
            count += 1
        elif char == ")":
            count -= 1
            if count == 0:
                result.append(paren_string[start : i + 1])
    return result
2. from typing import List

def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """
    paren_string = paren_string.replace(" ", "")
    result, count, start = [], 0, 0
    for i, char in enumerate(paren_string):
        if char == "(":
            if count == 0:
                start = i
            count += 1
        elif char == ")":
            count -= 1
            if count == 0:
                result.append(paren_string[start : i + 1])
    return result
3. from typing import List

def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """
    paren_string = paren_string.replace(" ", "")
    result, count, start = [], 0, 0
    for i, char in enumerate(paren_string):
        if char == "(":
            if count == 0:
                start = i
            count += 1
        elif char == ")":
            count -= 1
            if count == 0:
                result.append(paren_string[start : i + 1])
    return result

Batch of same prompt (1) sizes:

0. 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
    (leftover part always smaller than 1).

    Return the decimal part of the number.
    >>> truncate_number(3.5)
    0.5
    """
    return number - int(number)
1. 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
    (leftover part always smaller than 1).

    Return the decimal part of the number.
    >>> truncate_number(3.5)
    0.5
    """
    return number - int(number)
2. 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
    (leftover part always smaller than 1).

    Return the decimal part of the number.
    >>> truncate_number(3.5)
    0.5
    """
    return number - int(number)
3. 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
    (leftover part always smaller than 1).

    Return the decimal part of the number.
    >>> truncate_number(3.5)
    0.5
    """
    return number - int(number)

Batch of same prompt (2) sizes:

0. from typing import List

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
1. from typing import List

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
2. from typing import List

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
3. from typing import List

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False

Batch of same prompt (3) sizes:

0. from typing import List

def string_xor(a: str, b: str) -> str:
    """ Input are two strings a and b consisting only of 1s and 0s.
    Perform binary XOR on these inputs and return result also as a string.
    >>> string_xor('010', '110')
    '100'
    """
    assert len(a) == len(b), "Input strings must be equal length"
    xor = ""
    for i in range(len(a)):
        if a[i] != b[i]:
            xor += "1"
        else:
            xor += "0"
    return xor

def bitwise_xor_list(numbers: List[int]) -> int:
    """ Apply a bitwise XOR operation to all numbers in the list and return the result.
    >>> bitwise_xor_list([5, 7, 8])
    3
    """
    result = numbers[0]
    for num in numbers[1:]:
        result ^= num
    return result
1. from typing import List

def string_xor(a: str, b: str) -> str:
    """ Input are two strings a and b consisting only of 1s and 0s.
    Perform binary XOR on these inputs and return result also as a string.
    >>> string_xor('010', '110')
    '100'
    """
    assert len(a) == len(b), "Input strings must be equal length"
    xor = ""
    for i in range(len(a)):
        if a[i] != b[i]:
            xor += "1"
        else:
            xor += "0"
    return xor

def bitwise_xor_list(a: List[int], b: List[int]) -> List[int]:
    """ Apply the XOR operation to each pair of elements from input lists `a` and `b`.
    The function returns a list with the results.
    >>> bitwise_xor_list([1,2,3],[4,5,6])
    [7, 7, 5]
    """
    return [x ^ y for x, y in zip(a, b)]
2. from typing import List

def string_xor(a: str, b: str) -> str:
    """ Input are two strings a and b consisting only of 1s and 0s.
    Perform binary XOR on these inputs and return result also as a string.
    >>> string_xor('010', '110')
    '100'
    """
    assert len(a) == len(b), "Input strings must be equal length"
    xor = ""
    for i in range(len(a)):
        if a[i] != b[i]:
            xor += "1"
        else:
            xor += "0"
    return xor

def bitwise_xor_list(a: List[int], b: List[int]) -> List[int]:
    """ Apply the XOR operation to each pair of elements from input lists `a` and `b`.
    The function returns a list with the results.
    >>> bitwise_xor_list([1,2,3],[4,5,6])
    [7, 7, 5]
    """
    return [x ^ y for x, y in zip(a, b)]
3. from typing import List

def string_xor(a: str, b: str) -> str:
    """ Input are two strings a and b consisting only of 1s and 0s.
    Perform binary XOR on these inputs and return result also as a string.
    >>> string_xor('010', '110')
    '100'
    """
    assert len(a) == len(b), "Input strings must be equal length"
    xor = ""
    for i in range(len(a)):
        if a[i] != b[i]:
            xor += "1"
        else:
            xor += "0"
    return xor

def bitwise_xor_list(a: List[int], b: List[int]) -> List[int]:
    """ Apply the XOR operation to each pair of elements from input lists `a` and `b`.
    The function returns a list with the results.
    >>> bitwise_xor_list([1,2,3],[4,5,6])
    [7, 7, 5]
    """
    return [x ^ y for x, y in zip(a, b)]
turboderp commented 7 months ago

This could be a result of right-aligning and padding the inputs in the batch. The rotary embeddings shouldn't care what the starting token ID is as long as any prior tokens are masked out during attention, and I haven't had issues with it before, even mixing very long and very short prompts in a batch. CodeLlama does use a very large rotary embedding base (1e6), so that could be amplifying any numerical precision issues.

I'll have a look and see whether the padding approach is insufficient and another channel is needed for communicating position offsets to the forward pass. In the meantime, to rule out a simple tokenization issue, you could verify that the attention mask looks reasonable in generate_simple():

mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None
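For reference, what a "right-aligned and proper" mask should look like can be sketched in plain Python (a hypothetical illustration, not the exllamav2 internals): shorter prompts are left-padded so every row ends at the same position, and the mask is False exactly over the pad tokens.

```python
# Hypothetical illustration (plain Python, not the exllamav2 internals):
# with right-aligned batching, shorter prompts are left-padded so every row
# ends at the same position, and the padding mask should be False exactly
# over the pad tokens.
PAD_ID = 0

def right_align(seqs, pad_id=PAD_ID):
    """Left-pad each sequence to the batch max length; return (ids, mask)."""
    max_len = max(len(s) for s in seqs)
    ids = [[pad_id] * (max_len - len(s)) + list(s) for s in seqs]
    mask = [[False] * (max_len - len(s)) + [True] * len(s) for s in seqs]
    return ids, mask

ids, mask = right_align([[5, 6, 7, 8], [9, 10]])
# ids  == [[5, 6, 7, 8], [0, 0, 9, 10]]
# mask == [[True, True, True, True], [False, False, True, True]]
```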
ziadloo commented 7 months ago

Thanks for the quick reply. I checked, and I can confirm that the mask is right-aligned and generated properly. Looking forward to a fix. For now I'll batch copies of the same prompt together; I was going for pass@10 anyway.

Thanks again; this code is very stable, unlike other packages I've worked with.

turboderp commented 7 months ago

Well, I pushed an update that shifts the RoPE position IDs according to the length of each item in the batch. This at least ensures that the first token in a sequence will be encoded as position zero regardless of how many padding tokens precede it. It doesn't fix the issue, though, which comes down to the fact that Flash Attention doesn't support attention masking. So all those padding tokens are still attended to, and you simply get wrong output when they're present.
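That kind of per-row position shift can be sketched like this (illustrative plain Python, not the actual exllamav2 code): each row's position IDs are offset by its padding length so the first real token is encoded at position zero.

```python
# Illustrative sketch (not the actual exllamav2 code): shift each row's RoPE
# position IDs by the number of leading padding tokens, so the first real
# token is encoded at position 0 regardless of how much padding precedes it.
def shifted_position_ids(prompt_lens, max_len):
    """One list of position IDs per batch row, left-padding included."""
    rows = []
    for n in prompt_lens:
        pad = max_len - n
        # Padding positions get 0 here; they are meant to be masked out anyway.
        rows.append([0] * pad + list(range(n)))
    return rows

pos = shifted_position_ids([4, 2], max_len=4)
# pos == [[0, 1, 2, 3], [0, 0, 0, 1]]
```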

You can fix it for now with config.no_flash_attn = True, and I seem to be getting the expected output in your example using that option.

While flash-attn does provide ways to process variable-length sequences in a batch, rewriting the code to use those would make flash-attn a requirement rather than an option, and it's still problematic for Windows and AMD users. So it could take some time to come up with a proper solution.
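The variable-length interface referred to above expects the batch packed into one long token stream with cumulative per-sequence offsets instead of padding; a sketch of that packing in plain Python (the name `cu_seqlens` follows flash-attn's convention, but this is not a flash-attn call):

```python
# Sketch of the packing flash-attn's variable-length interface expects
# (the name cu_seqlens follows that library's convention; this is plain
# Python, not a flash-attn call): all sequences concatenated end to end,
# plus cumulative offsets marking where each one starts and ends.
def pack_varlen(seqs):
    packed, cu_seqlens = [], [0]
    for s in seqs:
        packed.extend(s)                      # no padding tokens at all
        cu_seqlens.append(cu_seqlens[-1] + len(s))
    return packed, cu_seqlens

packed, cu = pack_varlen([[5, 6, 7, 8], [9, 10]])
# packed == [5, 6, 7, 8, 9, 10]
# cu     == [0, 4, 6]   # row i occupies packed[cu[i]:cu[i+1]]
```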

ziadloo commented 7 months ago

On the last point in your message, I'm not sure I follow; perhaps I've just misunderstood. But isn't it possible to have both implementations (the current one and the fix) available in the code and switch between the two based on the value of config.no_flash_attn? For instance, a completely new class that implements the fix, so that when instantiating an object you choose the class based on whether the user has set no_flash_attn. That way, by resetting this flag, the user can switch back to the current implementation if they cannot run flash attention.

Perhaps, I'm missing something and it's not that simple. In any case, thanks for the reply.

turboderp commented 7 months ago

The fix I pushed today doesn't really relate to flash-attn. The fix is just always on, shifting position IDs whenever you generate in a batch.

But regardless of that, flash-attn won't currently work with batches of varying sequence length because of the padding mask issue. So you can disable it at runtime with the no_flash_attn flag, and then batching will work correctly.

turboderp commented 7 months ago

Turns out there was a bug in how the position embeddings were being applied. I'm getting consistent results with the latest commit, although still only with flash-attn disabled.

Now looking at PR #240, which looks promising as a solution for flash-attn.

turboderp commented 1 month ago

This should be fully addressed with the dynamic generator which does not require attention masking and starts all sequences in a batch at position zero, regardless of uneven lengths.
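Conceptually, that means each batch item is tracked independently, so positions always start at zero and there are no padding rows to mask; a toy sketch of that bookkeeping (illustrative only, not the actual dynamic generator):

```python
# Toy sketch (not the actual dynamic generator): each batch item keeps its
# own token list, so its positions run 0..len-1 with no padding rows and
# therefore nothing to mask out, regardless of uneven lengths.
class Sequence:
    def __init__(self, prompt_ids):
        self.ids = list(prompt_ids)

    def positions(self):
        # Positions always start at zero for every sequence in the batch.
        return list(range(len(self.ids)))

batch = [Sequence([5, 6, 7, 8]), Sequence([9, 10])]
# [s.positions() for s in batch] == [[0, 1, 2, 3], [0, 1]]
```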