huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

Building a tokenizer for tokenizing Java code #1446

Closed nimanthadilz closed 3 months ago

nimanthadilz commented 5 months ago

Hi, I am using the tokenizers library to build a tokenizer that can be used to tokenize Java code into valid Java tokens. This tokenizer will be used in a transformer model which can fix bugs in Java code.

So far, what I've done is: I have used the javalang library to identify valid Java tokens. I've created a custom pre-tokenizer which uses javalang to split the input into valid Java tokens. As the tokenizer's model, I've used WordLevel, since I don't need subword tokenization.

from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))

# Wrap the custom Python pre-tokenizer and decoder so the pipeline can call them
tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JavaLangImprovedPreTokenizer())

tokenizer.decoder = decoders.Decoder.custom(CustomDecoder())

special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]

trainer = trainers.WordLevelTrainer(special_tokens=special_tokens, show_progress=True)
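
Training then looks roughly like this (java_files stands in for an iterator over Java source strings):

# Train the word-level vocabulary on a corpus of Java source strings
tokenizer.train_from_iterator(java_files, trainer=trainer)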

This tokenizer can now split a piece of Java code into valid Java tokens. For example:

tokenizer.encode("private int getAge() { return age; }").tokens

# Output
# ['private', 'int', 'getAge', '(', ')', '{', 'return', 'age', ';', '}']

I need to split identifiers like getAge (method names, variable names) which are camelCase into separate tokens. When I do that, I have to add some symbol (like "#") to mark that the split tokens were originally one token, so that I can concatenate them again later.
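
For illustration, the kind of splitting I mean, written as plain Python outside the tokenizers pipeline (the regex and the split_camel_case helper are just a sketch):

import re

def split_camel_case(identifier: str):
    # Break on lower-to-upper boundaries: "getAge" -> ["get", "Age"]
    parts = re.split(r"(?<=[a-z0-9])(?=[A-Z])", identifier)
    # Prefix continuation pieces with "#" so they can be re-joined later
    return [parts[0]] + ["#" + part for part in parts[1:]]

split_camel_case("getAge")         # ['get', '#Age']
split_camel_case("maxRetryCount")  # ['max', '#Retry', '#Count']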

But I can't find a way to do this in my custom pre-tokenizer. There, we get a NormalizedString as the input. I tried to add a symbol when splitting camelCase tokens, but it didn't work. Is there a way to achieve that, or is there a better way to do this than what I've done?

My custom pre-tokenizer is below:

from typing import List

import javalang
from tokenizers import NormalizedString, PreTokenizedString


class JavaLangImprovedPreTokenizer:
    def javalang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        string = str(normalized_string)
        javalang_tokens = list(javalang.tokenizer.tokenize(string))
        splits = []
        original_pos = 0
        for javalang_token in javalang_tokens:
            # javalang drops whitespace, so scan forward to find where this
            # token actually starts in the original string.
            length = len(javalang_token.value)
            while javalang_token.value != string[original_pos:original_pos + length] and original_pos < len(string):
                original_pos += 1
            if original_pos >= len(string):
                raise ValueError(f'Could not find token "{javalang_token.value}" in string "{string}"')

            token_type = type(javalang_token).__name__

            if token_type in ("DecimalInteger", "DecimalFloatingPoint"):
                # Split numeric literals into single characters, one slice each
                for offset in range(length):
                    splits.append(normalized_string[original_pos + offset:original_pos + offset + 1])
            else:
                splits.append(normalized_string[original_pos:original_pos + length])
            original_pos += length
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.javalang_split)
ArthurZucker commented 4 months ago

Sorry for coming back to this late! BERT has a similar process and uses WordPiece with continuing_subword_prefix="##", which is probably what you are looking for, no?
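
A minimal sketch of that setup, assuming you keep the custom javalang pre-tokenizer from above and switch the model from WordLevel to WordPiece (java_corpus is a placeholder for a corpus iterator; whether getAge actually splits as get / ##Age depends on the learned vocabulary):

from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JavaLangImprovedPreTokenizer())
# The WordPiece decoder strips the "##" prefix and re-joins subwords
tokenizer.decoder = decoders.WordPiece(prefix="##")

trainer = trainers.WordPieceTrainer(
    special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
    continuing_subword_prefix="##",  # marks pieces that continue a word, e.g. "get", "##Age"
)
tokenizer.train_from_iterator(java_corpus, trainer=trainer)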

github-actions[bot] commented 3 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.