Open — issue opened by pseudomonas 1 year ago.
Here is an example script for your use case. You are welcome to open a PR that makes the code more modular.
import torch
import itertools
from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer
from awesome_align.modeling_utils import PreTrainedModel
class AwesomeAligner:
    """Keep an awesome-align model in memory and align sentence pairs on demand.

    The checkpoint is loaded once in ``__init__``; every call to :meth:`align`
    reuses it, so interactive callers (e.g. a GUI) do not pay the load cost
    per request.
    """

    def __init__(self, model_name_or_path='bert-base-multilingual-cased'):
        """Load config, tokenizer and model from *model_name_or_path*.

        Args:
            model_name_or_path: model identifier or local path passed to
                ``from_pretrained``. A value containing ``.ckpt`` is loaded
                with ``from_tf=True``.
        """
        config_class, model_class, tokenizer_class = BertConfig, BertForMaskedLM, BertTokenizer
        config = config_class.from_pretrained(model_name_or_path)
        tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
        # awesome_align.modeling keeps special-token ids in module-level
        # globals; set them from this tokenizer before the model is used.
        modeling.PAD_ID = tokenizer.pad_token_id
        modeling.CLS_ID = tokenizer.cls_token_id
        modeling.SEP_ID = tokenizer.sep_token_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model_class.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=config,
        )
        # self.device is handed to get_aligned_word (presumably to place the
        # inputs), so the weights must live on the same device; eval()
        # disables dropout so repeated calls give the same alignment.
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = tokenizer

    @staticmethod
    def _split_pair(line, tgt=None):
        """Return ``(source_words, target_words)`` for one sentence pair.

        Args:
            line: ``'source ||| target'`` when *tgt* is None, otherwise the
                source sentence alone.
            tgt: optional target sentence.

        Each sentence may be a string (split on whitespace) or a sequence of
        already-tokenized words.

        Raises:
            ValueError: if *line* lacks exactly one ``' ||| '`` separator, or
                either side contains no words.
        """
        if tgt is None:
            parts = line.split(' ||| ')
            if len(parts) != 2:
                raise ValueError("expected 'source ||| target', got %r" % (line,))
            src, tgt = parts
        else:
            src = line
        sent_src = src.split() if isinstance(src, str) else list(src)
        sent_tgt = tgt.split() if isinstance(tgt, str) else list(tgt)
        if not sent_src or not sent_tgt:
            raise ValueError('source and target must both be non-empty: %r ||| %r' % (src, tgt))
        return sent_src, sent_tgt

    @staticmethod
    def _bpe2word_map(pieces):
        """Map every word piece to the index of the word it was split from."""
        return [word_idx for word_idx, word_pieces in enumerate(pieces) for _ in word_pieces]

    def _encode(self, words):
        """Tokenize *words*; return (model input ids, word-piece -> word index map)."""
        pieces = [self.tokenizer.tokenize(word) for word in words]
        piece_ids = [self.tokenizer.convert_tokens_to_ids(p) for p in pieces]
        ids = self.tokenizer.prepare_for_model(
            list(itertools.chain.from_iterable(piece_ids)),
            return_tensors='pt',
            max_length=self.tokenizer.max_len,
        )['input_ids']
        return ids, self._bpe2word_map(pieces)

    def align(self, line, tgt=None):
        """Align the words of a source sentence with those of a target sentence.

        Args:
            line: ``'source words ||| target words'`` (awesome-align's input
                format) or, when *tgt* is given, just the source sentence.
            tgt: optional target sentence, enabling a SimAlign-like
                ``align(src, tgt)`` call. Sentences may be whitespace-separated
                strings or lists of pre-tokenized words.

        Returns:
            The first (only) batch entry produced by ``get_aligned_word`` —
            presumably (source word index, target word index) pairs; TODO
            confirm against awesome_align.modeling.

        Raises:
            ValueError: if the input is malformed, a side is empty, or a side
                yields no word pieces after tokenization.
        """
        sent_src, sent_tgt = self._split_pair(line, tgt)
        ids_src, bpe2word_map_src = self._encode(sent_src)
        ids_tgt, bpe2word_map_tgt = self._encode(sent_tgt)
        # Length 2 means only the special tokens added by prepare_for_model
        # (presumably [CLS] and [SEP]) remain. Reject if EITHER side is empty;
        # the original `or`-assert only caught both sides being empty.
        if len(ids_src[0]) == 2 or len(ids_tgt[0]) == 2:
            raise ValueError('source or target produced no word pieces: %r ||| %r' % (sent_src, sent_tgt))
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            word_aligns = self.model.get_aligned_word(
                ids_src, ids_tgt, [bpe2word_map_src], [bpe2word_map_tgt],
                self.device, 0, 0, test=True,
            )[0]
        return word_aligns
if __name__ == "__main__":
    # Demo only: guarded so importing this module (e.g. from a GUI app) does
    # not load the model and print. Load once, then call align() per request.
    aligner = AwesomeAligner()
    print(aligner.align('order , please . ||| a le ordre .'))
Ideally, I'd like to keep the model in memory and call it with syntax similar to SimAlign's (e.g. `aligner.get_word_aligns(src_sentence, trg_sentence)`):
Is there a way to do this? The use case is a user requesting alignments from a GUI, so the alignments cannot be precomputed in a batch.