abetlen / llama-cpp-python

Python bindings for llama.cpp
https://llama-cpp-python.readthedocs.io
MIT License

`Llama.embed` crashes when `n_batch` > 512 #1762

Open lsorber opened 1 month ago

lsorber commented 1 month ago

Expected Behavior

When embedding text with a long-context model like BGE-M3 [1], it should be possible to obtain token embeddings for more than 512 tokens (per-token embeddings are of interest for 'late interaction' retrieval [2]).
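
For context, late interaction scores a query against a document using per-token embeddings rather than a single pooled vector: each query token is matched to its most similar document token, and those per-token maxima are summed. A minimal NumPy sketch of this MaxSim scoring (the function name and array shapes are illustrative, not part of llama-cpp-python):

import numpy as np

def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    # query_emb: (num_query_tokens, dim); doc_emb: (num_doc_tokens, dim)
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    d = doc_emb / np.linalg.norm(doc_emb, axis=1, keepdims=True)
    # for each query token, take its best cosine similarity over all
    # document tokens, then sum over the query tokens
    return float((q @ d.T).max(axis=1).sum())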

Instead, llama-cpp-python truncates the input to the first n_batch tokens, where n_batch is 512 by default. The expected behavior is that setting n_batch to a larger value allows computing token embeddings for longer sequences.

[1] https://huggingface.co/BAAI/bge-m3
[2] https://jina.ai/news/what-is-colbert-and-late-interaction-and-why-they-matter-in-search/
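
The truncation can be seen with a minimal sketch, assuming an embedder constructed as in the reproduction below but with the default n_batch=512 (token counts are illustrative):

tokens = embedder.tokenize(text.encode("utf-8"))
print(len(tokens))  # e.g. roughly 2000 tokens for "Hello world" * 1000

embedding = embedder.embed(text)
print(len(embedding))  # with pooling_type=NONE this is one vector per token,
                       # but capped at the first n_batch = 512 tokens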

Current Behavior

The kernel crashes when embedding text with n_batch > 512. The crash is not specific to this embedding model; it also occurs with the few other models I've tried.

Steps to Reproduce

On a Google Colab T4 instance:

%pip install --quiet --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122 llama-cpp-python==0.3.0

from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama

embedder = Llama.from_pretrained(
    repo_id="lm-kit/bge-m3-gguf",
    filename="*F16.gguf",
    n_ctx=0,  # Model context is 8192
    n_gpu_layers=-1,
    n_batch=513,  # ← Any value larger than 512 (the default) causes a crash
    embedding=True,
    pooling_type=LLAMA_POOLING_TYPE_NONE,
    verbose=False
)

text = "Hello world" * 1000
embedding = embedder.embed(text)  # ← Crash 💥
len(embedding)
abetlen commented 1 month ago

Hey @lsorber, thank you for reporting this. A temporary workaround for now is to set n_ubatch as well as n_batch, i.e.:

from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama

embedder = Llama.from_pretrained(
    repo_id="lm-kit/bge-m3-gguf",
    filename="*F16.gguf",
    n_ctx=0,  # Model context is 8192
    n_gpu_layers=-1,
    n_batch=4096,
    n_ubatch=4096,
    embedding=True,
    pooling_type=LLAMA_POOLING_TYPE_NONE,
    verbose=False
)

text = "Hello world" * 1000
embedding = embedder.embed(text)
print(len(embedding))
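
With pooling_type=LLAMA_POOLING_TYPE_NONE, embed returns one embedding per input token, so len(embedding) should now be the full token count instead of being capped at 512.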

More generally, we need to update the defaults here for non-causal models so that they are more intuitive, while also not breaking backwards compatibility.
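
For illustration, one possible shape for such a default (a hypothetical sketch only, not the library's API; the underlying constraint is that non-causal models need the whole sequence to fit in a single ubatch):

def resolve_batch_sizes(n_batch: int, n_ubatch: int | None, embedding: bool):
    # hypothetical helper: if the caller didn't set n_ubatch, derive it
    if n_ubatch is None:
        # non-causal (embedding) models process a whole sequence in one
        # ubatch, so inherit n_batch instead of the fixed 512 default
        n_ubatch = n_batch if embedding else 512
    return n_batch, n_ubatch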

lsorber commented 1 month ago

Good to hear there's already a workaround! The n_ubatch parameter is new in v0.3.0, right? Unfortunately, it seems the Metal wheels for v0.3.0 on the Releases page are broken.

abetlen commented 1 month ago

@lsorber yes it is; I'll try to get that wheel issue resolved as soon as possible!