epfl-dlab / transformers-CFG

🤗 A specialized library for integrating context-free grammars (CFG) in EBNF with the Hugging Face Transformers
http://saibo-creator.xyz:7860/
MIT License

Enhancement/refactor grammar encoding #99

Open nathanrchn opened 6 days ago

nathanrchn commented 6 days ago

This PR refactors the parser logic by introducing multiple classes to represent different parts of the grammar:

  1. GrammarElement: An abstract class representing a single element. It has two subclasses:

    • TerminatedElement: Represents literals, ranges, and similar constructs.
    • ReferenceElement: Represents references, with a single attribute: reference_id.
  2. AlternativeElements: A class representing an alternative. It contains a list of symbols, each of which is itself a list of GrammarElement objects. Grouping elements into symbols is useful for repetition operators, which apply not just to the previous GrammarElement but to the entire previous symbol.

  3. GrammarRule: A class representing a rule, containing a list of AlternativeElements objects.
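To make the nesting concrete, here is a minimal sketch of how these classes could fit together. It is not the PR's actual code: the `ranges` field, the `alternatives` field name, and the `last_symbol` helper are assumptions made for illustration. Only the class names and `reference_id` come from the description above.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


class GrammarElement:
    """Base class for a single grammar element (abstract in the PR)."""


@dataclass
class TerminatedElement(GrammarElement):
    # Hypothetical field: inclusive code-point ranges, e.g. [(97, 97)] for "a".
    ranges: List[Tuple[int, int]]


@dataclass
class ReferenceElement(GrammarElement):
    reference_id: int  # id of the rule being referenced


@dataclass
class AlternativeElements:
    # Each symbol is a list of elements, so a repetition operator can
    # target the whole preceding symbol rather than a single element.
    symbols: List[List[GrammarElement]] = field(default_factory=list)

    def last_symbol(self) -> List[GrammarElement]:
        # Hypothetical helper: the group a trailing repetition would apply to.
        return self.symbols[-1]


@dataclass
class GrammarRule:
    # Hypothetical field name for the list of alternatives.
    alternatives: List[AlternativeElements] = field(default_factory=list)


# The literal "ab" is one symbol made of two elements, followed by a reference.
literal_ab = [TerminatedElement([(97, 97)]), TerminatedElement([(98, 98)])]
alt = AlternativeElements(symbols=[literal_ab, [ReferenceElement(reference_id=1)]])
rule = GrammarRule(alternatives=[alt])
```

With this layout, a repetition written after the literal would operate on `literal_ab` as a whole, not only on its last character.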

Each of these classes implements the Codable interface and provides an implementation of the serialize method:

def serialize(self) -> List[int]:
    # Implementation details
    ...

The serialize method generates a list of integers, maintaining consistency with the original grammar_encoding list and the rest of the codebase.
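As a rough illustration of the pattern, the sketch below shows a `Codable` interface where a container serializes itself by concatenating the output of its children. The integer layout used here (a length prefix followed by the children's codes) is a toy encoding chosen for the example and is an assumption, not the layout of the real grammar_encoding list.

```python
from abc import ABC, abstractmethod
from typing import List


class Codable(ABC):
    @abstractmethod
    def serialize(self) -> List[int]:
        """Return this object as a flat list of integers."""


class ReferenceElement(Codable):
    def __init__(self, reference_id: int) -> None:
        self.reference_id = reference_id

    def serialize(self) -> List[int]:
        # Toy encoding (assumption): a reference is just its rule id.
        return [self.reference_id]


class AlternativeElements(Codable):
    def __init__(self, symbols: List[List[Codable]]) -> None:
        self.symbols = symbols

    def serialize(self) -> List[int]:
        # Toy encoding (assumption): flatten every element of every symbol
        # in order, then prefix the result with its length.
        body = [
            code
            for symbol in self.symbols
            for element in symbol
            for code in element.serialize()
        ]
        return [len(body)] + body


alt = AlternativeElements(
    [[ReferenceElement(3)], [ReferenceElement(7), ReferenceElement(7)]]
)
encoded = alt.serialize()
```

Because every level returns a plain list of integers, the top-level encoding can be built by concatenation alone, which is what makes it straightforward to compare against a list produced by another parser.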

I am currently trying to add a graph method to the ParseState class to be able to create nice graphics to visualise the grammar. Update: The graph method is now fully functional.

nathanrchn commented 3 days ago

I tested my changes against the old parser to ensure full compatibility. To accomplish this, I compared the results of processing all EBNF files in the repository:

def test_refactor_parser(self):
    # parse_ebnf (the refactored parser) is assumed to be imported at module level
    import os
    from transformers_cfg.old_parser import parse_ebnf as old_parse_ebnf

    # collect all the EBNF files in examples/grammars and its subdirectories
    files = []
    for root, _, filenames in os.walk("examples/grammars"):
        for filename in filenames:
            if filename.endswith(".ebnf"):
                files.append(os.path.join(root, filename))

    # both parsers must produce the same grammar encoding for every file
    for file in files:
        with open(file, "r") as f:
            src = f.read()

        old_state = old_parse_ebnf(src)
        new_state = parse_ebnf(src)
        self.assertListEqual(
            old_state.grammar_encoding,
            new_state.grammar_encoding,
            f"The grammar encoding of {file} is different",
        )

Here, old_parser.py is simply a copy of the original parser.py from the main branch.
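One possible variation on this check, sketched below, is to collect every mismatching file instead of stopping at the first failed assertion. The helper names `find_ebnf_files` and `differing_files` are hypothetical, and `fake_parse` is a stand-in used only so the example runs without the library. In the real test, the two callables would be the old and new parse_ebnf functions.

```python
import tempfile
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List


def find_ebnf_files(root: str) -> List[Path]:
    """Return all .ebnf files under root, sorted for a reproducible order."""
    return sorted(Path(root).rglob("*.ebnf"))


def differing_files(root: str, old_parse: Callable, new_parse: Callable) -> List[Path]:
    """Return the files whose grammar_encoding differs between the two parsers."""
    mismatches = []
    for path in find_ebnf_files(root):
        src = path.read_text(encoding="utf-8")
        if old_parse(src).grammar_encoding != new_parse(src).grammar_encoding:
            mismatches.append(path)
    return mismatches


def fake_parse(src: str) -> SimpleNamespace:
    # Stand-in parser (assumption): only mimics the grammar_encoding attribute.
    return SimpleNamespace(grammar_encoding=[ord(ch) for ch in src])


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "sub").mkdir()
    (Path(tmp) / "sub" / "a.ebnf").write_text('root ::= "a"', encoding="utf-8")
    (Path(tmp) / "notes.txt").write_text("ignored", encoding="utf-8")
    found = [p.name for p in find_ebnf_files(tmp)]
    mismatches = differing_files(tmp, fake_parse, fake_parse)
```

A test could then assert that the returned list is empty and print it on failure, so a single run reports every grammar that regressed.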