Closed: pfZhu closed this issue 6 months ago.
The `prefix_function` that is generated is stateful: the first time it receives a token sequence, it assumes the input token list is the prompt that precedes the generation. It "marks" it and starts the operation.
I'm not exactly sure what you're trying to do here, but if you first send it the tokens that decode to "ab" (without the curly brace), it will treat them as the prompt and should behave accordingly afterwards.
@noamgat Hey, I uploaded a picture of my program to clarify my question. Please check it.
Do you mean that all five times I call the function, the input is treated as the initial prompt? I guess only the first input should be treated as the initial prompt, right?
I think that when I first build the `prefix_allowed_tokens_fn`, one `TokenEnforcer` instance is initialized, and only on the first call to `prefix_allowed_tokens_fn` is the input sequence treated as the initial prompt.
@noamgat Hi, please check the updated question above. Looking forward to your reply!
The way the token enforcer works is as follows:
When a token sequence `[t1, t2, ..., tn-1, tn]` arrives, it checks whether `[t1, ..., tn-1]` was already seen. If it was not, it assumes that `[t1, ..., tn]` is a prompt sequence and starts from the initial parsing state (in the JSON object case, expecting `{`). If it was, it finds the parsing state for `[t1, ..., tn-1]`, advances it by the characters that `tn` decodes to, and continues parsing from there.
Therefore, I think what you should do is first call it with the token sequence `a,b`, initializing that as the "prompt", and then call it with `a,b,{` as you currently do.
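To make the lookup concrete, here is a minimal, self-contained sketch of the behavior described above. It is not lmformatenforcer's code: `ToyTokenEnforcer`, `VOCAB`, `TARGET` and `allowed_after` are hypothetical names, a five-token vocabulary stands in for the tokenizer, and a toy "parser" that only accepts prefixes of the string `{"a"}` stands in for `JsonSchemaParser`. It also assumes that a sequence which was already seen simply reuses its stored state.

```python
from typing import Dict, List, Tuple

# Hypothetical toy vocabulary (token id -> decoded text); not a real tokenizer.
VOCAB: Dict[int, str] = {0: "a", 1: "b", 2: "{", 3: '"', 4: "}"}

# Toy "schema": the generated text must stay a prefix of this string.
TARGET = '{"a"}'


def allowed_after(consumed: str) -> List[int]:
    """Return the token ids whose text keeps `consumed` a prefix of TARGET."""
    return [tid for tid, text in VOCAB.items() if TARGET.startswith(consumed + text)]


class ToyTokenEnforcer:
    """Hypothetical sketch of the prefix-keyed lookup, not the library class."""

    def __init__(self) -> None:
        # Maps a full token sequence to the text generated since its prompt.
        self.states: Dict[Tuple[int, ...], str] = {}

    def get_allowed_tokens(self, tokens: List[int]) -> List[int]:
        seq = tuple(tokens)
        # Assumption of this sketch: an already-seen sequence reuses its state.
        if seq not in self.states:
            parent = seq[:-1]
            if parent in self.states:
                # Known [t1..tn-1]: advance its state by the text of tn.
                self.states[seq] = self.states[parent] + VOCAB[seq[-1]]
            else:
                # Unknown [t1..tn-1]: treat all of [t1..tn] as a new prompt.
                self.states[seq] = ""
        return allowed_after(self.states[seq])
```

With the suggested order (`a,b` first, then `a,b,{`), the sketch expects `"` after `{`. If `a,b,{` is sent first, the `{` becomes part of the prompt, so the toy parser still expects `{`, and a following `"` leaves no valid continuation. In this toy the state stored for `a,b,{` is also reused after `a,b` is sent later; whether the real enforcer caches in exactly the same way is an assumption here.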
@noamgat Thank you for your reply. So this also explains my other issue, "How does the function successfully deal with batch inputs" (#39): the `prefix_allowed_tokens_fn` stores, or "remembers", all the input sequences in the whole batch and parses the state incrementally from the newly decoded tokens, right?
Yes. This approach was used because it handles many use cases (batches, beams, multiple requests with one token enforcer instance, etc.) with a simple, unified approach.
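The same idea can be sketched for batches and beams. Below, a hypothetical `make_prefix_fn` factory (same toy vocabulary and toy target string as a stand-in for a real tokenizer and parser, not the library's API) keeps a single table keyed by the full token sequence. In this sketch `batch_id` is ignored, because the sequence alone identifies which parser state to advance, so rows with different prompts and beams that branch from a shared prefix do not interfere with each other.

```python
from typing import Callable, Dict, List, Tuple

VOCAB: Dict[int, str] = {0: "a", 1: "b", 2: "{", 3: '"', 4: "}"}  # hypothetical toy vocabulary
TARGET = '{"a"}'  # toy language: generated text must stay a prefix of this string


def make_prefix_fn() -> Callable[[int, List[int]], List[int]]:
    """Hypothetical factory mimicking a stateful prefix_allowed_tokens_fn."""
    # One table shared by every batch row and every beam.
    states: Dict[Tuple[int, ...], str] = {}

    def prefix_fn(batch_id: int, sent: List[int]) -> List[int]:
        # batch_id is unused here: the full token sequence identifies the state.
        seq = tuple(sent)
        if seq not in states:
            parent = states.get(seq[:-1])
            # Unknown parent -> new prompt; known parent -> advance by the new token.
            states[seq] = "" if parent is None else parent + VOCAB[seq[-1]]
        consumed = states[seq]
        return [tid for tid, text in VOCAB.items() if TARGET.startswith(consumed + text)]

    return prefix_fn
```

For example, row 0 with prompt `a` and row 1 with prompt `b,b` can be advanced in the same calls, and two beams that extend row 0's `a,{` with `"` and `}` each get their own continuation (`a` and nothing, respectively).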
If the answers here solve your issues, can you please close them?
Hi, I am very interested in this powerful project; however, I am not very clear about how the character parser works. The code seems complicated to me because of the recursive function calls. Here is the code of an unexpected case that I tested to simulate a weird generation process, and I wonder why the final returned results are wrong:
```python
import torch
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
from transformers import AutoTokenizer

print('test prefix function')
parser = JsonSchemaParser(None)  # no schema: any JSON object
tokenizer = AutoTokenizer.from_pretrained('./Llama-2-7b-Chat-GPTQ')
# The returned function is stateful: it remembers the sequences it has been called with
prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)


def query(tokens):
    """Call prefix_function on the given tokens and print the allowed continuations."""
    prefix = torch.tensor(tokenizer.convert_tokens_to_ids(tokens), dtype=torch.int64)
    allowed_ids = prefix_function(batch_id=0, sent=prefix)
    print('prefix:', tokenizer.convert_ids_to_tokens(prefix.tolist()))
    print('decode prefix:', tokenizer.decode(prefix.tolist()))
    print('allowed_ids[:1000]:', len(allowed_ids), allowed_ids[:1000])
    print('allowed_tokens[:1000]:', tokenizer.convert_ids_to_tokens(allowed_ids[:1000]))


# The five calls of the simulated generation, in this order
query(['a', 'b', '{'])
query(['a', 'b', '{'])
query(['a', 'b', '{', '"'])
query(['a', 'b'])
query(['a', 'b', '{'])
```
The outputs are as follows:
It seems the final allowed tokens are wrong. Is it because this simulated generation process puts the character-level parser into a wrong state? Thanks, and looking forward to your reply!