noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Batch decoding with regex ends in 'no allowed tokens' state when pad token not decoded to empty string. #86

Closed JoshC8C7 closed 2 months ago

JoshC8C7 commented 3 months ago

When running the example below, I get 'Parser reached state with no allowed tokens'. I believe this happens because one sequence in the batch finishes and subsequently has pad tokens generated for it, and these pad tokens are applied here: https://github.com/noamgat/lm-format-enforcer/blob/142a5a6834b1119c308e0f3fefc92f8d5206bc61/lmformatenforcer/tokenenforcer.py#L150. That drives the parser into a non-finishing state in which there are no allowable tokens (whereas the eos token was available before this token was erroneously applied).

Surely these pad tokens shouldn't be applied, though? The issue is that they decode to </s> rather than an empty string, because self.decoder() runs with skip_special_tokens=False: https://github.com/noamgat/lm-format-enforcer/blob/fbcf5afb048ece85b549762843117cf158eb9f9e/lmformatenforcer/integrations/transformers.py#L71
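
To illustrate the difference (my own snippet, using the TinyLlama tokenizer from the repro below, whose pad token is </s>; the commented outputs are what I'd expect from the standard transformers decode behaviour):

# Illustration only: how the pad token decodes with and without special-token skipping.
pad_id = tokenizer.convert_tokens_to_ids("</s>")
tokenizer.decode([pad_id], skip_special_tokens=False)  # '</s>'
tokenizer.decode([pad_id], skip_special_tokens=True)   # ''  (empty string)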

What should the behaviour be once an eos token has been generated in one sequence of the batch but not the others, under a regex parser? Shouldn't the subsequent pad tokens decode to nothing, letting the sequence continue as a ForceStopParser? A sketch of what I mean follows.
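
Concretely, something like this (a rough sketch only; apply_token and its plumbing are hypothetical names of mine, not the actual tokenenforcer.py code):

# Hypothetical sketch of the suggested behaviour, not the real tokenenforcer.py code:
# a token that decodes to the empty string (e.g. a skipped pad token) should be a
# no-op, leaving the parser free to accept eos / stop exactly as before.
def apply_token(parser, new_text):
    if new_text == "":
        return parser
    for char in new_text:
        parser = parser.add_character(char)
    return parser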

I believe this is a bug but could just be a misconfiguration on my side.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
    build_transformers_prefix_allowed_tokens_fn,
)

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"
# Fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

regex_string = r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))"

tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
parser = RegexParser(regex_string)

prompts = ["Generate a string", "Generate a strang"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer_data, parser)
outputs = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_function)
for idx in range(len(prompts)):
    output_text = tokenizer.decode(outputs[idx], skip_special_tokens=True)
    print(output_text)

EDIT: I can also reproduce this with prompts = ["Generate a string", "Generate a strang", "generate the number one below 8"] and model_id = "google/gemma-2b-it".
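
In the meantime, here is a workaround sketch that sidesteps the problem for me (my own wrapper, not lm-format-enforcer API; it assumes tokenizer.pad_token_id is set, e.g. via tokenizer.pad_token = tokenizer.eos_token):

# Workaround sketch (my own wrapper, not part of lm-format-enforcer): once a
# sequence has already finished, allow only the pad token instead of consulting
# the enforcer, so padding of finished sequences never reaches the parser.
def make_padding_safe_prefix_fn(inner_fn, eos_token_id, pad_token_id):
    def prefix_fn(batch_id, sent_ids):
        last_token = sent_ids[-1].item()
        if last_token in (eos_token_id, pad_token_id):
            return [pad_token_id]
        return inner_fn(batch_id, sent_ids)
    return prefix_fn

# Wrap the enforcer's prefix function from the repro above before generating.
prefix_function = make_padding_safe_prefix_fn(
    prefix_function, tokenizer.eos_token_id, tokenizer.pad_token_id
)
outputs = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_function)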

noamgat commented 2 months ago

What you're specifying as the desired behavior sounds right. Does the attached sample consistently reproduce the problem? Please verify that it does, and if so I will look at it in the coming days.

JoshC8C7 commented 2 months ago

Yes, it does with both tested models.