epfl-dlab / transformers-CFG

🤗 A specialized library for integrating context-free grammars (CFG) in EBNF with the Hugging Face Transformers
http://saibo-creator.xyz:7860/
MIT License

beam search doesn't work with transformers_cfg #9

Open minniekabra opened 8 months ago

minniekabra commented 8 months ago
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load json grammar
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    input_ids = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]

    output = model.generate(
        input_ids,
        do_sample=False,
        max_length=50,
        num_beams=1,  # setting num_beams > 1 raises the ValueError quoted below
        logits_processor=[grammar_processor],
        repetition_penalty=5.0,
        num_return_sequences=1,
    )
Saibo-creator commented 8 months ago

Thanks for raising this issue; support for beam search is still in progress.

The error message is below:

ValueError: All stacks are empty, so the only token accepted is EOS(2), but got 539
Saibo-creator commented 6 months ago

Here, I describe how to integrate support for beam search with grammar-constrained decoding, in case a volunteer wants to contribute :)

At present, our library uses a logits processor (GrammarConstrainedLogitsProcessor) to steer the decoding process. This processor relies on an underlying parser to determine which tokens are permissible at each step.

While effective for various decoding/sampling methods, it doesn't suit constrained beam search.

The incompatibility of the constrained logit processor with beam search is complex and relates to the mechanics of beam search itself. However, this detail is not central to this feature, as our focus is on employing the Constraint class from Hugging Face.
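To make the logits-processor approach above concrete, here is a minimal plain-Python sketch of what such a processor does at a single decoding step. It assumes a hypothetical seven-token vocabulary and a hand-written prefix check for the small tuples.ebnf grammar posted later in this thread (root ::= triple triple, triple ::= "[" object object object "]", object ::= "A" | "B" | "C"); it does not use transformers_cfg's real parser or GrammarConstrainedLogitsProcessor.

```python
import math
from typing import List

# Hypothetical toy vocabulary (index = token id); not GPT-2's real vocabulary.
VOCAB = ["[", "]", "A", "B", "C", "x", "<eos>"]
EOS_ID = VOCAB.index("<eos>")
ROOT_LEN = 10  # every string in the tuples grammar has exactly 10 characters


def is_valid_prefix(text: str) -> bool:
    """True if `text` can still be extended to a full string of the tuples grammar."""
    if len(text) > ROOT_LEN:
        return False
    for i, ch in enumerate(text):
        slot = i % 5  # each triple is "[", three objects, then "]"
        if slot == 0 and ch != "[":
            return False
        if slot == 4 and ch != "]":
            return False
        if slot in (1, 2, 3) and ch not in "ABC":
            return False
    return True


def allowed_token_ids(text: str) -> List[int]:
    """Ids of the tokens the (toy) parser accepts after the already generated `text`."""
    if len(text) == ROOT_LEN:  # root is complete, so only EOS may follow
        return [EOS_ID]
    return [tid for tid, tok in enumerate(VOCAB)
            if tid != EOS_ID and is_valid_prefix(text + tok)]


def mask_logits(text: str, logits: List[float]) -> List[float]:
    """The processor's job at one step: keep allowed scores, set all others to -inf."""
    allowed = set(allowed_token_ids(text))
    return [s if tid in allowed else -math.inf for tid, s in enumerate(logits)]
```

For example, after the text "[A" only the ids of "A", "B" and "C" keep their scores, and once ten characters have been generated only EOS survives.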

Credit goes to @chanwkimlab for developing the constrained beam search and providing a robust abstraction along with a comprehensive blog post: https://huggingface.co/blog/constrained-beam-search

The procedure involves:

  1. Creating a GrammarConstraint class and testing it.
  2. Using GrammarConstraint instead of GrammarConstrainedLogitsProcessor during inference, and testing it.
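from typing import List

from transformers.generation.beam_constraints import Constraint


class GrammarConstraint(Constraint):

    def __init__(self, token_ids: List[int]):
        super(Constraint, self).__init__()
        ...

    def advance(self):
        # Return the token(s) that move the grammar one step closer to completion.
        ...

    def does_advance(self, token_id: int):
        # Return True if token_id is accepted by the current grammar state.
        ...

    def update(self, token_id: int):
        # Consume token_id and return the tuple (stepped, completed, reset).
        ...

    def reset(self):
        self.completed = False
        self.fulfilled_idx = 0

    def remaining(self):
        # For grammar constrained decoding, determining the exact number of remaining
        # tokens may be challenging, but it should not pose a significant issue.
        ...

    def copy(self, stateful=False):
        # Also abstract in Constraint: return a new instance, optionally with this state.
        ...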
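Before wiring anything into transformers, the method contract of the skeleton can be exercised on a toy case. The sketch below is a hypothetical ToyGrammarConstraint for the tuples.ebnf grammar posted later in this thread, over an assumed five-character vocabulary. It keeps the parser state as a single position counter, follows the return conventions of HF's Constraint methods (for example, update returns (stepped, completed, reset)), and deliberately does not subclass Constraint so that it runs without transformers installed.

```python
from typing import List, Optional

# Hypothetical character-level vocabulary; token ids are indices into this list.
VOCAB = ["[", "]", "A", "B", "C"]
ROOT_LEN = 10  # root ::= triple triple, and each triple is exactly 5 tokens


class ToyGrammarConstraint:
    """Same methods as HF's Constraint, for tuples.ebnf; not a Constraint subclass."""

    def __init__(self):
        self.pos = 0  # parser state: number of tokens consumed so far
        self.completed = False
        self.tokens: List[int] = []

    def _allowed(self) -> List[int]:
        slot = self.pos % 5  # "[", object, object, object, "]"
        chars = "[" if slot == 0 else "]" if slot == 4 else "ABC"
        return [VOCAB.index(c) for c in chars]

    def advance(self) -> Optional[List[int]]:
        # Tokens that move the grammar one step closer to a complete root.
        return None if self.completed else self._allowed()

    def does_advance(self, token_id: int) -> bool:
        return not self.completed and token_id in self._allowed()

    def update(self, token_id: int):
        # Returns (stepped, completed, reset); a rejected token resets the state.
        if self.does_advance(token_id):
            self.pos += 1
            self.tokens.append(token_id)
            self.completed = self.pos == ROOT_LEN
            return True, self.completed, False
        self.reset()
        return False, False, True

    def reset(self):
        self.pos = 0
        self.completed = False
        self.tokens = []

    def remaining(self) -> int:
        # Exact only because this grammar has a fixed length.
        return ROOT_LEN - self.pos

    def copy(self, stateful: bool = False) -> "ToyGrammarConstraint":
        new = ToyGrammarConstraint()
        if stateful:
            new.pos = self.pos
            new.completed = self.completed
            new.tokens = list(self.tokens)  # independent list, not a shared reference
        return new
```

For this fixed-length grammar remaining() is exact; for grammars with recursion or optional parts it generally is not, which matches the caveat in the skeleton's comment. As far as we understand the blog post linked above, HF's constrained beam search was built to force phrases to appear in the output, so handing it a constraint like this does not by itself forbid off-grammar tokens.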

Here are some example implementations of Constraints in the HF library: https://github.com/huggingface/transformers/blob/c60749d6a67d223d65a2fb6105c2459f3469a30d/src/transformers/generation/beam_constraints.py#L129

That's it!

HichemAK commented 3 months ago

Hello! Is this still an active issue, or has a workaround been found?

I can take a shot at coding the GrammarConstraint class.

Saibo-creator commented 3 months ago

Hey, the feature is not yet done. Go ahead and I will be happy to merge it :)

(Sent by email in reply to Hichem Ammar Khodja's comment of 13 Aug 2024, 20:10.)

HichemAK commented 2 months ago

Hello, unfortunately I couldn't make it work. This constraint feature lacks documentation, and it's difficult to understand how it works behind the scenes. When coding, I tried to follow the same format as the constraints found in the transformers library.

transformers version: 4.44.0

Here is my best attempt:


from transformers.generation.beam_constraints import Constraint
from transformers_cfg.grammar_utils import IncrementalTokenRecognizer, IncrementalGrammarConstraint
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class GrammarConstraint(Constraint):
    def __init__(self, token_recognizer : IncrementalTokenRecognizer):
        super(Constraint, self).__init__()
        self.token_recognizer = token_recognizer
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.seqlen = float('inf')
        self.tokens = []

    @property
    def text(self):
        return self.token_recognizer.tokenizer.decode(self.tokens)

    def advance(self):
        # Return the next set of tokens that would be accepted by the current grammar state
        if self.completed:
            return []
        acceptance = self.valid_tokens
        return acceptance.nonzero(as_tuple=False).squeeze(-1).tolist()

    def does_advance(self, token_id: int):
        # Check if the given token_id is accepted by the current grammar state
        acceptance = self.valid_tokens
        return acceptance[token_id]

    def update(self, token_id: int):
        # Update the state with the given token_id and return the progress indicators
        if self.does_advance(token_id):
            new_state = self.token_recognizer._update_state_with_token_id(token_id, self.current_state)
            self.current_state = new_state

            stepped = True
            completed = not bool(new_state.stacks)  # If stacks are empty, the constraint is completed
            self.tokens.append(token_id)
            if not completed:
                self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
            reset = False
        else:
            # The token_id was not accepted, reset the state
            self.reset()
            stepped = False
            completed = False
            reset = True

        self.completed = completed
        return stepped, completed, reset

    def reset(self):
        # Reset the state of this constraint to its initialization
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.tokens.clear()

    def remaining(self):
        # Return the number of remaining steps; this is more complex for a grammar constraint
        # and might not be easily quantifiable. For simplicity, we return 1 if not completed.
        return 0 if self.completed else 1

    def copy(self, stateful=False):
        # Create a new instance of this constraint
        new_constraint = GrammarConstraint(
            self.token_recognizer
        )
        if stateful:
            new_constraint.current_state = self.current_state
            new_constraint.valid_tokens = self.valid_tokens
            new_constraint.completed = self.completed
            new_constraint.tokens = self.tokens.copy()
        return new_constraint

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint

if __name__ == "__main__":
    # Detect if GPU is available, otherwise use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "gpt2"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the tuples grammar
    with open("tuples.ebnf", "r") as file:
        grammar_str = file.read()

    token_recognizer = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True)
    grammar = GrammarConstraint(token_recognizer)

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)  # Load model to defined device
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Generate
    prefix1 = "Tuples:"
    input_ids = tokenizer([prefix1], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"].to(device)
    max_new_tokens = 50
    # grammar.seqlen = max_new_tokens
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        constraints=[grammar],
        num_beams=3,
        do_sample=False
    )
    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

Here is the content of tuples.ebnf:

root   ::= triple triple

triple ::= "[" object object object "]"

object ::= "A" | "B" | "C"
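Because tuples.ebnf has no recursion, its whole language is regular, which gives a cheap way to check the generations printed by the script above without running a parser. The helper name matches_tuples_grammar and its default prefix handling are illustrative assumptions, not part of transformers_cfg.

```python
import re

# root ::= triple triple ; triple ::= "[" object object object "]" ; object ::= "A" | "B" | "C"
TUPLES_RE = re.compile(r"(?:\[[ABC]{3}\]){2}")


def matches_tuples_grammar(generation: str, prefix: str = "Tuples:") -> bool:
    """True if the text after `prefix` is exactly one string of the grammar (no spaces)."""
    body = generation[len(prefix):] if generation.startswith(prefix) else generation
    return TUPLES_RE.fullmatch(body) is not None
```

Running it over the decoded `generations` list makes it easy to spot beams that escaped the grammar.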
Saibo-creator commented 2 months ago

@HichemAK Thanks for your effort! After diving deeper into beam search, I found that the implementation of constrained beam search in HF is quite convoluted and too closely tied to the existing constraints to be general enough. Implementing it directly is indeed not the best way: the resulting code would be very ugly and inefficient.

It might be better to avoid that approach and work directly with beam search, but that would require modifying the HF codebase.
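One possible reading of "work directly with beam search" is sketched below: every hypothesis carries its own grammar state (here simply its text), and candidate extensions are filtered by the grammar before the beams are ranked. This is a hypothetical pure-Python illustration with an assumed character vocabulary and a made-up toy_log_probs "model"; it is not the design in the linked commit and does not touch the HF codebase.

```python
import math
from typing import Callable, Dict, List, Tuple

VOCAB = ["[", "]", "A", "B", "C"]  # hypothetical character-level vocabulary
SLOTS = ["[", "ABC", "ABC", "ABC", "]"]  # allowed characters per position in a triple


def grammar_allows(prefix: str) -> bool:
    """Prefix test for tuples.ebnf: at most 10 chars, each in its slot's allowed set."""
    return len(prefix) <= 10 and all(ch in SLOTS[i % 5] for i, ch in enumerate(prefix))


def grammar_beam_search(
    log_probs: Callable[[str], Dict[str, float]], num_beams: int, max_len: int
) -> List[Tuple[str, float]]:
    """Beam search where each hypothesis keeps its own grammar state (its text) and
    only grammar-valid extensions are scored and ranked."""
    beams: List[Tuple[str, float]] = [("", 0.0)]
    for _ in range(max_len):
        candidates = [
            (text + tok, score + lp)
            for text, score in beams
            for tok, lp in log_probs(text).items()
            if grammar_allows(text + tok)  # filter per hypothesis, before ranking
        ]
        if not candidates:  # no beam can be extended within the grammar
            break
        candidates.sort(key=lambda cand: cand[1], reverse=True)
        beams = candidates[:num_beams]
    return beams


def toy_log_probs(text: str) -> Dict[str, float]:
    """Made-up 'model' that ignores its input and prefers "B" (probabilities sum to 1)."""
    return {tok: math.log(0.6 if tok == "B" else 0.1) for tok in VOCAB}


print(grammar_beam_search(toy_log_probs, num_beams=2, max_len=10))
```

Because the grammar check is applied per hypothesis, reordering or pruning beams never pairs a parser state with the wrong prefix. A fuller version would also set finished hypotheses aside and handle EOS; in this fixed-length toy grammar all beams finish at the same step, so that bookkeeping is omitted.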

I’ve sketched out how I plan to implement this beam search. For those interested, feel free to check it out. I’ll likely start working on it myself in the next few days.

GitHub Commit