epfl-dlab / transformers-CFG

🤗 A specialized library for integrating context-free grammars (CFGs) in EBNF notation with the Hugging Face Transformers library
http://saibo-creator.xyz:7860/
MIT License

beam search doesn't work with transformers_cfg #9

Open minniekabra opened 5 months ago

minniekabra commented 5 months ago
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load json grammar
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    inputs = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)

    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # needed for correct generation with padded batches
        do_sample=False,
        max_length=50,
        num_beams=1,  # this can't be > 1
        logits_processor=[grammar_processor],
        repetition_penalty=5.0,
        num_return_sequences=1,
    )
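    # (Added for illustration, not part of the original report) decode the
    # constrained generations with the standard Transformers batch_decode API:
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)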
Saibo-creator commented 5 months ago

Thanks for raising this issue; support for beam search is still a work in progress.

The error message is shown below:

ValueError: All stacks are empty, so the only token accepted is EOS(2), but got 539
Saibo-creator commented 3 months ago

Here, I describe how to integrate support for beam search with grammar-constrained decoding, in case a volunteer wants to contribute :)

At present, our library uses a logits processor to influence the decoding process. This processor queries an underlying incremental parser to determine the permissible tokens at each step.
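To make this concrete, here is a minimal sketch of the masking idea. The parser interface shown (allowed_token_ids) is an assumption for illustration, not the library's actual API:

import torch
from transformers import LogitsProcessor

class SketchGrammarLogitsProcessor(LogitsProcessor):
    """Illustrative only: mask out every token the grammar does not permit."""

    def __init__(self, parser):
        self.parser = parser  # assumed to expose allowed_token_ids(prefix_ids)

    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        for i, prefix in enumerate(input_ids):
            allowed = self.parser.allowed_token_ids(prefix.tolist())  # hypothetical
            mask[i, allowed] = 0.0
        return scores + mask  # disallowed tokens get -inf and are never chosen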

While effective for most decoding and sampling methods, this approach does not suit constrained beam search.

Exactly why the constrained logits processor is incompatible with beam search is somewhat involved and relates to the mechanics of beam search itself: roughly, beams are expanded and reordered between steps, so per-sequence parser state can fall out of sync with the surviving beams. However, this detail is not central to this feature, as our focus is on employing the Constraint class from Hugging Face.

Credit goes to @chanwkimlab for developing the constrained beam search and providing a robust abstraction along with a comprehensive blog post: https://huggingface.co/blog/constrained-beam-search
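
For orientation, constrained beam search is already exposed through generate's constraints argument; here is a standard example with the built-in PhrasalConstraint, independent of this library:

# Standard Hugging Face constrained beam search (no grammar involved yet):
# force the phrase "is fast" to appear somewhere in the output.
from transformers import AutoModelForCausalLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

constraint = PhrasalConstraint(tokenizer("is fast", add_special_tokens=False).input_ids)
input_ids = tokenizer("The new library", return_tensors="pt").input_ids

output = model.generate(
    input_ids,
    constraints=[constraint],
    num_beams=4,  # constrained beam search requires num_beams > 1
    max_length=30,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))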

The procedure involves:

  1. Creating a GrammarConstraint class (skeleton below) and adding tests.
  2. Using GrammarConstraint instead of GrammarConstrainedLogitsProcessor during inference and testing (see the usage sketch after the skeleton).
from typing import List

from transformers.generation.beam_constraints import Constraint


class GrammarConstraint(Constraint):

    def __init__(self, token_ids: List[int]):
        super().__init__()  # note: super(Constraint, self) would skip Constraint.__init__
        ...

    def advance(self):
        # Return token id(s) that would advance this constraint by one step.
        ...

    def does_advance(self, token_id: int):
        # Return True if generating `token_id` makes progress on the constraint.
        ...

    def update(self, token_id: int):
        # Return (stepped, completed, reset) after observing `token_id`.
        ...

    def reset(self):
        self.completed = False
        self.fulfilled_idx = 0

    def remaining(self):
        # For grammar-constrained decoding, determining the exact number of
        # remaining tokens may be challenging, but it should not pose a
        # significant issue.
        ...

    def copy(self, stateful=False):
        # Required by the Constraint interface: return a copy of this
        # constraint, optionally carrying over the parser state.
        ...
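
Once both steps are done, usage would look roughly like the sketch below; the constructor signature is an assumption, not the final API:

# Hypothetical usage of the finished class (signature is an assumption):
grammar_constraint = GrammarConstraint(grammar_str, "root", tokenizer)

output = model.generate(
    input_ids,
    constraints=[grammar_constraint],  # handled by HF constrained beam search
    num_beams=4,                       # beam search now works
    max_length=50,
)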

Here are some example implementations of Constraint in the HF library: https://github.com/huggingface/transformers/blob/c60749d6a67d223d65a2fb6105c2459f3469a30d/src/transformers/generation/beam_constraints.py#L129

That's it!