turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

inference_json Example not working #513

Closed rjmehta1993 closed 1 week ago

rjmehta1993 commented 1 week ago

The example given for producing JSON output is not working. The only modification I made was changing the model from Mistral to Qwen2/Llama3.

import sys, os

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel, conlist
from typing import Literal
import json

model_dir = "/mnt/str/models/llama3-8b-instruct/"
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)

print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)

# Initialize the generator with all default parameters

generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)

# JSON schema

class SuperheroAppearance(BaseModel):
    title: str
    issue_number: int
    year: int

class Superhero(BaseModel):
    name: str
    secret_identity: str
    superpowers: conlist(str, max_length = 5)
    first_appearance: SuperheroAppearance
    gender: Literal["male", "female"]

schema_parser = JsonSchemaParser(Superhero.schema())

# Create prompts with and without filters

i_prompts = [
    "Here is some information about Superman:\n\n",
    "Here is some information about Batman:\n\n",
    "Here is some information about Aquaman:\n\n",
]

prompts = []
filters = []

for p in i_prompts:
    prompts.append(p)
    filters.append(None)
    prompts.append(p)
    filters.append([
        ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
        ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"])
    ])

# Generate

print("Generating...")

outputs = generator.generate(
    prompt = prompts,
    filters = filters,
    filter_prefer_eos = True,
    max_new_tokens = 300,
    add_bos = True,
    stop_conditions = [tokenizer.eos_token_id],
    completion_only = True
)

# Print outputs:

for i in range(len(i_prompts)):

    print("---------------------------------------------------------------------------------")
    print(i_prompts[i].strip())
    print()
    print("Without filter:")
    print("---------------")
    print(outputs[i * 2])
    print()
    print("With filter:")
    print("------------")
    print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4).strip())
    print()

OUTPUT:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 70
     66 # Generate
     68 print("Generating...")
---> 70 outputs = generator.generate(
     71     prompt = prompts,
     72     filters = filters,
     73     filter_prefer_eos = True,
     74     max_new_tokens = 2300,
     75     add_bos = True,
     76     stop_conditions = [tokenizer.eos_token_id],
     77     completion_only = True
     78 )
     80 # Print outputs:
     82 for i in range(len(i_prompts)):

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/dynamic.py:598, in ExLlamaV2DynamicGenerator.generate(self, prompt, max_new_tokens, min_new_tokens, seed, gen_settings, token_healing, encode_special_tokens, decode_special_tokens, stop_conditions, add_bos, abort_event, completion_only, filters, filter_prefer_eos, **kwargs)
    595 completions = [""] * batch_size
    597 while self.num_remaining_jobs():
--> 598     results = self.iterate()
    599     for r in results:
    600         idx = order[r["serial"]]

File ~/exl19/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/dynamic.py:853, in ExLlamaV2DynamicGenerator.iterate(self)
    848     self.iterate_gen(results, draft_tokens)
    850 # Regular generation
    851 
    852 else:
--> 853     self.iterate_gen(results)
    855 # Finished iteration
    857 return results

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/dynamic.py:1069, in ExLlamaV2DynamicGenerator.iterate_gen(self, results, draft_tokens)
   1066 job_logits = batch_logits[a:b, i:i+1, :]
   1067 if i == 0 and mt_sample:
   1068     next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
-> 1069     futures.popleft().result()
   1070 else:
   1071     next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
   1072     job.receive_logits(job_logits)

File /usr/local/lib/python3.10/concurrent/futures/_base.py:451, in Future.result(self, timeout)
    449     raise CancelledError()
    450 elif self._state == FINISHED:
--> 451     return self.__get_result()
    453 self._condition.wait(timeout)
    455 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

File /usr/local/lib/python3.10/concurrent/futures/thread.py:58, in _WorkItem.run(self)
     55     return
     57 try:
---> 58     result = self.fn(*self.args, **self.kwargs)
     59 except BaseException as exc:
     60     self.future.set_exception(exc)

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/dynamic.py:1541, in ExLlamaV2DynamicJob.receive_logits(self, logits)
   1537     else:
   1538         blocked_tokens = list(self.stop_tokens)
   1540 next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
-> 1541 ExLlamaV2Sampler.sample(
   1542     logits,
   1543     self.gen_settings,
   1544     self.sequences[0].sequence_ids.torch(),
   1545     self.rng.random(),
   1546     self.generator.tokenizer,
   1547     self.prefix_token if self.new_tokens == -1 else None,
   1548     self.return_top_tokens,
   1549     blocked_tokens = blocked_tokens,
   1550     filters = self.filters if self.new_tokens >= 0 else None,
   1551     filter_prefer_eos = self.filter_prefer_eos,
   1552     # sync = True
   1553 )
   1555 return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/sampler.py:222, in ExLlamaV2Sampler.sample(logits, settings, sequence_ids, random, tokenizer, prefix_token, return_top_tokens, blocked_tokens, filters, filter_prefer_eos, sync)
    219 end_tokens = None
    220 for f in filters:
--> 222     pt, et = f.next()
    223     print(pt,et)
    224     if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt

File ~/exl19/lib/python3.10/site-packages/exllamav2/generator/filters/prefix.py:61, in ExLlamaV2PrefixFilter.next(self)
     57 rem_str = self.prefix_string[self.offset:]
     59 # Use prefix dict if string could be completed by one token
---> 61 if rem_str in prefix_to_ids:
     62     pass_tokens = set(prefix_to_ids[rem_str])
     63 else:

TypeError: unhashable type: 'list'
turboderp commented 1 week ago

The line numbers in your stack trace line up with v0.1.3 or earlier, and the prefix filter was updated in v0.1.4+ to take a list of prefixes rather than a string.

I think you have an old version of ExLlamaV2 installed and you're using examples from a later version of the repo.
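
To confirm which build is actually installed (a quick check, not part of the original exchange), something like this should tell you:

# Check the installed ExLlamaV2 version against the examples being used.
from importlib.metadata import version
print(version("exllamav2"))  # e.g. 0.1.3 or earlier would explain the error above

# If it's older than the repo checkout the examples came from, upgrading
# should resolve the mismatch:
#   pip install -U exllamav2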

waterangel91 commented 1 week ago

Hi, I would like to ask about the prefix feature: is there any technical benefit to using it instead of just appending the prefix to the end of my prompt?

The reason I am asking is that my existing code currently just appends the prefix to the prompt, and I am trying to see whether there is any benefit to retrofitting it to use the new feature.

turboderp commented 1 week ago

Mainly it's because the JSON filter needs to select { as its first token. If you add the opening bracket to the prompt instead, the filter doesn't know that it's meant to start one level deep in the JSON schema. I.e. if you supply a prompt like And the answer, in JSON format, is: { the filter still has to constrain what follows to fit the whole JSON schema, so it produces something like {"answer":"no"}, duplicating the opening bracket.

On the other hand, if you don't use the prefix constraint, LMFE will allow whitespace before the first opening bracket since that's still technically valid JSON. {"answer":"no"} and \n\n\n\n\n {"answer":"no"} both satisfy the same schema. I believe LMFE only allows up to a certain number of whitespace characters (?), but the amount you want is probably zero, and the prefix filter is a way of ensuring that.

You can also use it in other ways, with longer prefixes or whatever, but in any case the point is that it imposes a constraint on top of the JSON filter, which adding text to the prompt wouldn't achieve.
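
As a rough sketch of that difference (reusing model, tokenizer and schema_parser from the example at the top of this issue; the prompt text is only a placeholder):

# (a) Bracket appended to the prompt: the JSON filter still starts at the top
#     of the schema, so the generated output repeats the opening bracket.
prompt_a  = "Here is some information about Superman:\n\n{"
filters_a = [ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)]

# (b) Prefix filter instead: the prompt stays clean, and the first generated
#     characters are constrained to be the opening bracket itself.
prompt_b  = "Here is some information about Superman:\n\n"
filters_b = [
    ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
    ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"]),
]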

Of course the prefix filter can be used on its own, too. If you want to chat with multiple bots but let the model decide whose turn it is to speak, you could use something like:

filters = [ExLlamaV2PrefixFilter(model, tokenizer, [bot + ": " for bot in list_of_bots])]
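
Expanded into a full call for illustration (the bot names and prompt here are made up; generator, model and tokenizer as set up in the example at the top of the issue):

bots = ["Alice", "Bob", "Charlie"]

output = generator.generate(
    prompt = "Alice: Hi everyone!\nBob: Hey, Alice.\n",
    filters = [ExLlamaV2PrefixFilter(model, tokenizer, [bot + ": " for bot in bots])],
    max_new_tokens = 100,
    stop_conditions = [tokenizer.eos_token_id],
    completion_only = True,
)
print(output)  # the completion is forced to begin with "Alice: ", "Bob: " or "Charlie: "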

Possibilities are endless. [: