neulab / awesome-align

A neural word aligner based on multilingual BERT
https://arxiv.org/abs/2101.08231
BSD 3-Clause "New" or "Revised" License
325 stars 47 forks source link

Repeated single-sentence inferences on an in-memory model? #58

Status: Open — opened by pseudomonas 1 year ago

pseudomonas commented 1 year ago

Ideally, I'd like to keep the model loaded in memory and call it with syntax similar to SimAlign's:

# Desired usage (SimAlign-style API): construct the aligner once, so the model
# is loaded a single time and stays in memory between requests.
myaligner = SentenceAligner(model="model_path", token_type="bpe", **model_parameters)

# ... and later ...

while True:
    # One interactive request: both sentences arrive already split into words
    # (lists of strings), and alignments are computed on demand, not in a batch.
    alignments = myaligner.get_word_aligns(src_sentence_as_list_of_strings, trg_sentence_as_list_of_strings)
    # ... wait until next request comes in ...

Is there a way to do this? The use case is a user requesting alignments interactively from a GUI, so they can't be precomputed in a batch.

zdou0830 commented 1 year ago

Here's an example script for your use case. You are welcome to create a PR and make the code more modular.

import torch
import itertools
from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer
from awesome_align.modeling_utils import PreTrainedModel

class AwesomeAligner:
    """Keep an awesome-align model in memory and align sentence pairs on demand.

    The config, tokenizer and model are loaded once in ``__init__``. After that,
    ``get_word_aligns`` (list-of-words API) and ``align`` (``'src ||| tgt'``
    line API) can be called repeatedly, e.g. once per interactive request.
    """

    def __init__(self, model_name_or_path='bert-base-multilingual-cased', device=None,
                 align_layer=None, extraction=None, softmax_threshold=None):
        """Load the config, tokenizer and model.

        Args:
            model_name_or_path: name or local path passed to ``from_pretrained``;
                a value containing ``.ckpt`` is loaded with ``from_tf=True``.
            device: torch device (or device string) to run on. ``None`` selects
                CUDA when available, otherwise CPU.
            align_layer, extraction, softmax_threshold: optional overrides
                forwarded as keyword arguments to ``get_aligned_word``. ``None``
                leaves that argument out, so ``get_aligned_word``'s own default
                applies.
        """
        config = BertConfig.from_pretrained(model_name_or_path)
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        # These module-level globals of awesome_align.modeling are set from the
        # tokenizer before the model is built. This mirrors upstream run_align.py;
        # presumably the model code reads them when handling special tokens.
        modeling.PAD_ID = tokenizer.pad_token_id
        modeling.CLS_ID = tokenizer.cls_token_id
        modeling.SEP_ID = tokenizer.sep_token_id
        self.model = BertForMaskedLM.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=config,
        )
        self.tokenizer = tokenizer

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # self.device is handed to get_aligned_word for every call, so the
        # weights must live there too. Nothing else moves the model, so without
        # this a CUDA device would presumably mismatch the CPU-resident weights.
        self.model.to(self.device)
        # Inference only: make sure dropout is off so repeated calls are stable.
        self.model.eval()

        # Only the overrides the caller actually set are forwarded.
        self._align_kwargs = {
            name: value
            for name, value in (
                ('align_layer', align_layer),
                ('extraction', extraction),
                ('softmax_threshold', softmax_threshold),
            )
            if value is not None
        }

    @staticmethod
    def _split_line(line):
        """Split a ``'src words ||| tgt words'`` line into two word lists.

        Raises:
            ValueError: if ``' ||| '`` does not occur exactly once.
        """
        parts = line.split(' ||| ')
        if len(parts) != 2:
            raise ValueError(f"expected exactly one ' ||| ' separator in {line!r}")
        src, tgt = parts
        # str.split() with no argument also drops leading/trailing whitespace.
        return src.split(), tgt.split()

    @staticmethod
    def _bpe2word_map(subwords_per_word):
        """Map every subword position to the index of the word it came from."""
        return [i for i, pieces in enumerate(subwords_per_word) for _ in pieces]

    @staticmethod
    def _check_encoded(ids_src, ids_tgt):
        """Reject a pair where either side encodes to special tokens only.

        Raises:
            ValueError: if either id row has length <= 2.
        """
        # A row of length 2 holds nothing but the tokens prepare_for_model adds
        # around the sequence (presumably [CLS]/[SEP]), so nothing can be aligned.
        # Either side being empty is an error; the former assert combined the two
        # tests with ``or`` and therefore only fired when *both* were empty.
        if len(ids_src[0]) <= 2 or len(ids_tgt[0]) <= 2:
            raise ValueError('source or target sentence produced no subword tokens')

    def _encode(self, words):
        """Tokenize *words* into subwords and build the model's input ids.

        Returns:
            ``(subwords_per_word, input_ids)``, where ``input_ids`` is the
            ``'input_ids'`` entry of ``prepare_for_model`` (a 'pt' tensor).
        """
        subwords = [self.tokenizer.tokenize(word) for word in words]
        piece_ids = [self.tokenizer.convert_tokens_to_ids(pieces) for pieces in subwords]
        input_ids = self.tokenizer.prepare_for_model(
            list(itertools.chain.from_iterable(piece_ids)),
            return_tensors='pt',
            max_length=self.tokenizer.max_len,
        )['input_ids']
        return subwords, input_ids

    def get_word_aligns(self, src_words, tgt_words):
        """Align two sentences that are already split into words.

        Args:
            src_words: source sentence as a sequence of word strings.
            tgt_words: target sentence as a sequence of word strings.

        Returns:
            The first (only) element of ``get_aligned_word``'s result for this
            one-pair batch. This is presumably a set of ``(src_word_idx,
            tgt_word_idx)`` pairs; confirm against ``awesome_align.modeling``.

        Raises:
            ValueError: if either sentence is empty or yields no subword tokens.
        """
        src_words, tgt_words = list(src_words), list(tgt_words)
        if not src_words or not tgt_words:
            raise ValueError('source and target sentences must both be non-empty')

        token_src, ids_src = self._encode(src_words)
        token_tgt, ids_tgt = self._encode(tgt_words)
        self._check_encoded(ids_src, ids_tgt)

        # Pure inference: guarantee no autograd graph is built across repeated calls.
        with torch.no_grad():
            word_aligns = self.model.get_aligned_word(
                ids_src, ids_tgt,
                [self._bpe2word_map(token_src)], [self._bpe2word_map(token_tgt)],
                self.device, 0, 0, test=True, **self._align_kwargs,
            )
        return word_aligns[0]

    def align(self, line):
        """Align one ``'source words ||| target words'`` line.

        Both sides are split on whitespace. See ``get_word_aligns`` for the
        return value.

        Raises:
            ValueError: on a malformed line or an empty source/target side.
        """
        sent_src, sent_tgt = self._split_line(line)
        return self.get_word_aligns(sent_src, sent_tgt)

# Demo: load the default aligner once, then align a single whitespace-tokenized
# pair whose source and target sides are separated by ' ||| '.
model = AwesomeAligner()
print(
    model.align('order , please .   ||| a le ordre .')
)