simonw / llm-llama-cpp

LLM plugin for running models using llama.cpp
Apache License 2.0

Support for logprobs #17

Open · simonw opened 11 months ago

simonw commented 11 months ago

Similar to this feature:

It looks like both llama.cpp and llama-cpp-python support outputting logprobs as well.
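
For reference, here is a minimal sketch (untested, with a placeholder model path) of what that looks like against llama-cpp-python directly: logits_all=True keeps the per-token logits around, and the logprobs argument asks for the top N alternatives per token.

from llama_cpp import Llama

# Placeholder path: point this at any local GGUF model file
llm_model = Llama(model_path="/path/to/llama-2-13b.Q8_0.gguf", logits_all=True)

# Ask for the top 3 candidate tokens (with log probabilities) at each step
completion = llm_model("The capital of France is", max_tokens=8, logprobs=3)
print(completion["choices"][0]["logprobs"]["top_logprobs"])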

simonw commented 11 months ago

Got a prototype working:

diff --git a/llm_llama_cpp.py b/llm_llama_cpp.py
index f2fc977..44097e5 100644
--- a/llm_llama_cpp.py
+++ b/llm_llama_cpp.py
@@ -1,11 +1,16 @@
 import click
 import httpx
-import io
 import json
 import llm
 import os
 import pathlib
 import sys
+from typing import Optional
+try:
+    from pydantic import field_validator, Field  # type: ignore
+except ImportError:
+    from pydantic.fields import Field
+    from pydantic.class_validators import validator as field_validator  # type: ignore [no-redef]

 try:
     from llama_cpp import Llama
@@ -91,8 +96,8 @@ def register_commands(cli):
     )
     def download_model(url, aliases, llama2_chat):
         "Download and register a model from a URL"
-        if not url.endswith(".bin"):
-            raise click.BadParameter("URL must end with .bin")
+        if not url.endswith(".bin") and not url.endswith(".gguf"):
+            raise click.BadParameter("URL must end with .bin or .gguf")
         with httpx.stream("GET", url, follow_redirects=True) as response:
             total_size = response.headers.get("content-length")

@@ -170,6 +175,11 @@ def register_commands(cli):
 class LlamaModel(llm.Model):
     class Options(llm.Options):
         verbose: bool = False
+        logprobs: Optional[int] = Field(
+            description="Include the log probabilities of most likely N per token",
+            default=None,
+            le=5,
+        )

     def __init__(self, model_id, path, is_llama2_chat: bool = False):
         self.model_id = model_id
@@ -226,7 +236,11 @@ class LlamaModel(llm.Model):
     def execute(self, prompt, stream, response, conversation):
         with SuppressOutput(verbose=prompt.options.verbose):
             llm_model = Llama(
-                model_path=self.path, verbose=prompt.options.verbose, n_ctx=4000
+                model_path=self.path,
+                verbose=prompt.options.verbose,
+                n_ctx=4000,
+                n_gpu_layers=1,
+                logits_all=bool(prompt.options.logprobs)
             )
             if self.is_llama2_chat:
                 prompt_bits = self.build_llama2_chat_prompt(prompt, conversation)
@@ -234,13 +248,19 @@ class LlamaModel(llm.Model):
                 response._prompt_json = {"prompt_bits": prompt_bits}
             else:
                 prompt_text = prompt.prompt
-            stream = llm_model(prompt_text, stream=True)
+            kwargs = {}
+            if prompt.options.logprobs is not None:
+                kwargs["logprobs"] = prompt.options.logprobs
+            stream = llm_model(prompt_text, stream=True, **kwargs)
+            bits = []
             for item in stream:
                 # Each item looks like this:
                 # {'id': 'cmpl-00...', 'object': 'text_completion', 'created': .., 'model': '/path', 'choices': [
                 #   {'text': '\n', 'index': 0, 'logprobs': None, 'finish_reason': None}
                 # ]}
+                bits.append(item)
                 yield item["choices"][0]["text"]
+            response.response_json = {"bits": bits}

 def human_size(num_bytes):
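
With that patch applied, the option can be exercised through the llm CLI; a hypothetical invocation (the model alias here is just a placeholder for whichever GGUF model has been registered) looks like:

llm -m llama-2-13b 'Say hello to me' -o logprobs 3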

This resulted in the following JSON being written to the database (truncated):

{
    "bits": [
        {
            "id": "cmpl-9f6acebf-bb00-4e92-8c04-5c6b7c6970a9",
            "object": "text_completion",
            "created": 1695235636,
            "model": "/Users/simon/Library/Application Support/io.datasette.llm/llama-cpp/models/llama-2-13b.Q8_0.gguf",
            "choices": [
                {
                    "text": "\n",
                    "index": 0,
                    "logprobs": {
                        "tokens": [
                            "\n"
                        ],
                        "text_offset": [
                            552
                        ],
                        "token_logprobs": [
                            -0.4442370610922866
                        ],
                        "top_logprobs": [
                            {
                                "\n": -0.4442370610922866,
                                " [": -3.552636452266603,
                                " <<": -3.6009209829062514
                            }
                        ]
                    },
                    "finish_reason": null
                }
            ]
        },
        {
            "id": "cmpl-9f6acebf-bb00-4e92-8c04-5c6b7c6970a9",
            "object": "text_completion",
            "created": 1695235636,
            "model": "/Users/simon/Library/Application Support/io.datasette.llm/llama-cpp/models/llama-2-13b.Q8_0.gguf",
            "choices": [
                {
                    "text": "say",
                    "index": 0,
                    "logprobs": {
                        "tokens": [
                            "say"
                        ],
                        "text_offset": [
                            553
                        ],
                        "token_logprobs": [
                            -2.2131922683492076
                        ],
                        "top_logprobs": [
                            {
                                "\n": -1.159477516151942,
                                "<": -1.7465556106343638,
                                "say": -2.2131922683492076
                            }
                        ]
                    },
                    "finish_reason": null
                }
            ]
        },
simonw commented 11 months ago

I need to figure out how to condense that format a bit, like I did for the OpenAI ones here: https://github.com/simonw/llm/blob/bf229945fe57036fa75e8105e59d9e506a720156/llm/default_plugins/openai_models.py#L387-L400
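
One possibility, sketched here rather than implemented: since every streamed chunk in the output above carries exactly one token, the per-chunk logprobs structures could be flattened into a single compact list before being written to response_json. The condense_logprobs name and the exact output shape below are hypothetical, loosely modelled on the OpenAI handling linked above.

def condense_logprobs(bits):
    # Each item in bits is one streamed chunk from llama-cpp-python; in the
    # output shown above every chunk holds a single token, so index [0] below
    # flattens the per-chunk lists down to that one token.
    condensed = []
    for item in bits:
        choice = item["choices"][0]
        logprobs = choice.get("logprobs")
        if not logprobs:
            continue
        condensed.append(
            {
                "text": choice["text"],
                "logprob": logprobs["token_logprobs"][0],
                "top_logprobs": logprobs["top_logprobs"][0],
            }
        )
    return condensed

execute() could then store response.response_json = {"logprobs": condense_logprobs(bits)} instead of the raw chunks, which would drop the id/object/created/model fields that are currently repeated in every chunk.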