noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Returned wrong allowed tokens #37

Closed · pfZhu closed this issue 6 months ago

pfZhu commented 7 months ago

Hi, I am very interested in this powerful project, but I am not quite clear on how the character parser works. The code seems complicated to me because of the recursive function calls. Below is a test script that simulates a weird generation process, and I wonder why the final returned results are wrong:

```python
import torch
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
from transformers import AutoTokenizer

print('test prefix function')
parser = JsonSchemaParser(None)
tokenizer = AutoTokenizer.from_pretrained('./Llama-2-7b-Chat-GPTQ')

prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)

# Call 1: 'a', 'b', '{'
prefix = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b', '{']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=prefix)
print('prefix:', tokenizer.convert_ids_to_tokens(prefix.numpy().tolist()))
print('decode prefix:', tokenizer.decode(prefix.numpy().tolist()))
print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))

# Call 2: 'a', 'b', '{' again
prefix = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b', '{']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=prefix)
print('prefix:', tokenizer.convert_ids_to_tokens(prefix.numpy().tolist()))
print('decode prefix:', tokenizer.decode(prefix.numpy().tolist()))
print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))

# Call 3: 'a', 'b', '{', '"'
prefix = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b', '{', '"']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=prefix)
print('prefix:', tokenizer.convert_ids_to_tokens(prefix.numpy().tolist()))
print('decode prefix:', tokenizer.decode(prefix.numpy().tolist()))
print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))

# Call 4: 'a', 'b' (shorter than the previously seen prefixes)
prefix = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=prefix)
print('prefix:', tokenizer.convert_ids_to_tokens(prefix.numpy().tolist()))
print('decode prefix:', tokenizer.decode(prefix.numpy().tolist()))
print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))

# Call 5: 'a', 'b', '{' once more; this is where the returned allowed tokens look wrong
prefix = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b', '{']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=prefix)
print('prefix:', tokenizer.convert_ids_to_tokens(prefix.numpy().tolist()))
print('decode prefix:', tokenizer.decode(prefix.numpy().tolist()))
print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))
```

The outputs are as follows:

[screenshot of the printed output]

It seems the final output allowed_tokens are wrong. Is it because this simulated generation process leads to a wrong state of the character-level parser? Thanks, and looking forward to your reply!

noamgat commented 7 months ago

The prefix_function that is generated is stateful: the first time it receives a token sequence, it assumes the input token list to be the prompt that precedes the generation. It "marks" it and starts the operation from there.

I'm not exactly sure what you're trying to do here, but if you first send it the decoded output of "ab" (without the curly brace), it will treat it as the prompt, and should behave accordingly afterwards.
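
For example, reusing the variables from the script above, the intended call order would look roughly like this (a sketch of the call order, not new API):

```python
# First call: establish the decoded 'ab' as the prompt.
prompt = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b']), dtype=torch.int64)
prefix_function(batch_id=0, sent=prompt)

# Later calls extend a previously seen prefix by one token,
# so the parser advances from the cached state instead of resetting.
step = torch.tensor(tokenizer.convert_tokens_to_ids(['a', 'b', '{']), dtype=torch.int64)
allowed_ids = prefix_function(batch_id=0, sent=step)
```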

pfZhu commented 7 months ago

@noamgat Hey, I uploaded a picture of my program to clarify my question; please take a look.

Do you mean that on all five calls to the function, the inputs are treated as the initial prompt? I would guess only the first call should be treated that way, right?

My understanding is that when I first build the prefix_allowed_tokens_fn, a single TokenEnforcer instance is initialized, and only on the first call to the prefix_allowed_tokens_fn is the input sequence treated as the initial prompt.

pfZhu commented 6 months ago

> The prefix_function that is generated is stateful: the first time it receives a token sequence, it assumes the input token list to be the prompt that precedes the generation. It "marks" it and starts the operation from there.
>
> I'm not exactly sure what you're trying to do here, but if you first send it the decoded output of "ab" (without the curly brace), it will treat it as the prompt, and should behave accordingly afterwards.

@noamgat Hi, please check the updated question above; looking forward to your reply!

noamgat commented 6 months ago

The way the token enforcer works is as follows:

When a token sequence [t1, t2, ..., tn-1, tn] arrives, it checks whether [t1, ..., tn-1] has already been seen. If it has not, it assumes that [t1, ..., tn] is a prompt sequence and starts from the initial parsing state (in the JSON object case, expecting {). If it has, it looks up the parsing state for [t1, ..., tn-1], advances it by the characters that tn decodes to, and continues parsing from there.

Therefore, I think what you should be doing is first calling with the token sequence of the decoded a,b, initializing that as the "prompt", and then calling with a,b,{ as you currently do.
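
In code, that bookkeeping can be sketched roughly like this (illustrative names only; SketchTokenEnforcer is not the library's actual class, and the parser state is assumed to expose an add_character method like the character-level parsers do):

```python
from typing import Callable, Dict, List, Tuple

class SketchTokenEnforcer:
    """Illustrative sketch only; not lm-format-enforcer's real implementation."""

    def __init__(self,
                 decode_token: Callable[[int], str],               # token id -> decoded text
                 initial_state,                                    # character-level parser state
                 allowed_ids_for: Callable[[object], List[int]]):  # state -> allowed token ids
        self.decode_token = decode_token
        self.initial_state = initial_state
        self.allowed_ids_for = allowed_ids_for
        self.states: Dict[Tuple[int, ...], object] = {}  # every prefix seen so far

    def get_allowed_tokens(self, token_ids: List[int]) -> List[int]:
        key = tuple(token_ids)
        parent = key[:-1]
        if parent in self.states:
            # [t1, ..., tn-1] was seen before: advance its cached parser
            # state by the characters that tn decodes to.
            state = self.states[parent]
            for ch in self.decode_token(key[-1]):
                state = state.add_character(ch)
        else:
            # Unseen prefix: treat the whole sequence as the prompt and
            # start from the initial parsing state (for JSON, expecting '{').
            state = self.initial_state
        self.states[key] = state
        return self.allowed_ids_for(state)
```

With this bookkeeping, any sequence whose immediate prefix was never seen is re-initialized as a prompt, which is why calling the function with arbitrary, non-incremental prefixes can leave the parser in an unexpected state.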

pfZhu commented 6 months ago

> The way the token enforcer works is as follows:
>
> When a token sequence [t1, t2, ..., tn-1, tn] arrives, it checks whether [t1, ..., tn-1] has already been seen. If it has not, it assumes that [t1, ..., tn] is a prompt sequence and starts from the initial parsing state (in the JSON object case, expecting {). If it has, it looks up the parsing state for [t1, ..., tn-1], advances it by the characters that tn decodes to, and continues parsing from there.
>
> Therefore, I think what you should be doing is first calling with the token sequence of the decoded a,b, initializing that as the "prompt", and then calling with a,b,{ as you currently do.

@noamgat Thank you for your reply. So this also explains my other issue, "How does the function successfully deal with batch inputs" (#39): the prefix_allowed_tokens_fn stores, or "remembers", every input sequence across the whole batch, and advances each parsing state by the incrementally decoded new tokens, right?

noamgat commented 6 months ago

Yes. This approach was chosen because it handles many use cases (batches, beams, multiple requests with one token enforcer instance, etc.) in a simple, unified way.
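
For reference, this mirrors how transformers drives the hook during generation: the prefix_allowed_tokens_fn is invoked once per batch (or beam) row at each decoding step, so a single stateful enforcer can track every sequence. A simplified sketch of the calling pattern (not transformers' exact code):

```python
# input_ids: (batch_size, seq_len_so_far) tensor at one decoding step
for batch_id, sent in enumerate(input_ids):
    # Each row's prefix is looked up (or registered) independently.
    allowed = prefix_function(batch_id=batch_id, sent=sent)
    # The generation loop then masks row batch_id's logits to `allowed`.
```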

noamgat commented 6 months ago

If the answers here solve your issues, can you please close them?