twadada opened 4 years ago
I recently tried it as follows.

Take `model.pt` from https://dl.fbaipublicfiles.com/fairseq/models/spanbert_large_with_head.tar.gz (#44) and `config.json` from https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf.tar.gz, and put them in a directory (e.g. `spanbert_large_with_head`).

Then, convert `model.pt` into `pytorch_model.bin` using the following code so that it can be used from huggingface/transformers:
```python
from collections import OrderedDict

import torch


def convert(input_path, output_path):
    # Strip the fairseq 'decoder.' prefix so the keys match the
    # Hugging Face BertForMaskedLM state dict.
    state_dict = OrderedDict()
    for k, v in torch.load(input_path).items():
        if k.startswith('decoder.'):
            state_dict[k[8:]] = v
    torch.save(state_dict, output_path)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('input_path')
    parser.add_argument('output_path')
    args = parser.parse_args()
    convert(args.input_path, args.output_path)
```
```python
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
model = BertForMaskedLM.from_pretrained('./spanbert_large_with_head/')

seq = "Super Bowl 50 was {m} {m} {m} {m} to determine the champion".format(m=tokenizer.mask_token)
inputs = tokenizer(seq, return_tensors="pt")
outputs = model(**inputs).logits

# Replace each [MASK] position with the model's top prediction and decode.
print(tokenizer.decode([y if x == tokenizer.mask_token_id else x
                        for x, y in zip(inputs.input_ids[0], outputs.argmax(dim=2)[0])]))
```
In my env with transformers==4.1.1, this produced `[CLS] Super Bowl 50 was the first football game to determine the champion [SEP]`.
The Hugging Face tutorial (https://huggingface.co/transformers/task_summary.html#masked-language-modeling) is also helpful.
I hope this will work for you.
@chantera I'm not sure why, but this was my result: `[CLS] Super Bowl 50 was Trilogy trailers Singers 231 to determine the champion [SEP]`. Do you have any idea why?
Hi! Have you got the base model with the head?
No, I haven't -- I don't think it has been published anywhere.
This is indeed helpful. Another thing to figure out is reproducing SpanBERT's masking strategy. Have you figured that out too? SpanBERT differs from the original BERT in two main aspects: 1) MLM prediction with the SBO head, and 2) masking contiguous spans. The SBO head is available here, but I'm still trying to figure out how to add the masking strategy.
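For reference, the SpanBERT paper describes the masking strategy as sampling span lengths from a geometric distribution (p = 0.2, clipped at 10 tokens) and masking contiguous spans until roughly 15% of the tokens are covered. Below is a minimal sketch of that idea; the function name, arguments, and defaults are my own assumptions, not code from this repository:

```python
# Hypothetical sketch of SpanBERT-style span masking. Span lengths follow a
# clipped geometric distribution and whole spans are masked until a ~15%
# token budget is reached. All names/parameters here are illustrative.
import random


def mask_spans(tokens, mask_token="[MASK]", mask_ratio=0.15,
               geo_p=0.2, max_span=10, seed=0):
    """Return a copy of `tokens` with contiguous spans replaced by mask_token."""
    rng = random.Random(seed)
    tokens = list(tokens)
    budget = max(1, int(round(len(tokens) * mask_ratio)))
    masked = set()
    attempts = 0
    while len(masked) < budget and attempts < 100:  # safety cap on resampling
        attempts += 1
        # Sample a span length from Geo(geo_p), clipped at max_span.
        length = 1
        while rng.random() > geo_p and length < max_span:
            length += 1
        length = min(length, budget - len(masked))
        start = rng.randrange(0, len(tokens) - length + 1)
        span = range(start, start + length)
        if any(i in masked for i in span):
            continue  # avoid overlapping an already-masked span
        for i in span:
            tokens[i] = mask_token
            masked.add(i)
    return tokens
```

Note this masks whole-word spans only if `tokens` are words; SpanBERT additionally aligns span boundaries to complete words at the subword level.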
I see that you have shared the large model with the LM/SBO head here (https://dl.fbaipublicfiles.com/fairseq/models/spanbert_large_with_head.tar.gz); would you also provide the file for the base model?
Also, would you kindly provide code that performs MLM prediction using the SBO head (similar to `BertPairTargetPredictionHead` in pretraining/fairseq/models/pair_bert.py)? I'm curious how it compares to the standard MLM head.
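In the meantime, here is a rough sketch of what the SBO head computes, based on the paper's description: each masked token is predicted from the hidden states of the two span boundary tokens plus a relative position embedding, passed through a two-layer feed-forward network with GeLU and layer normalization. The class and argument names are my assumptions, not the repository's actual `BertPairTargetPredictionHead`:

```python
# Hypothetical sketch of the span boundary objective (SBO) head. This is NOT
# the repository's implementation; it only mirrors the paper's description.
import torch
import torch.nn as nn


class SpanBoundaryHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, max_span_len=512):
        super().__init__()
        # Embedding for the target's position relative to the span start.
        self.position_embeddings = nn.Embedding(max_span_len, hidden_size)
        # Two-layer feed-forward with GeLU + LayerNorm, per the paper.
        self.mlp = nn.Sequential(
            nn.Linear(3 * hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
        )
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, left_boundary, right_boundary, positions):
        # left_boundary, right_boundary: (batch, hidden) hidden states of the
        # tokens just outside the masked span; positions: (batch,) relative
        # position of each target token within the span.
        pos = self.position_embeddings(positions)
        h = torch.cat([left_boundary, right_boundary, pos], dim=-1)
        return self.decoder(self.mlp(h))
```

Comparing these logits against the standard MLM head's on the same masked positions would be one way to measure the difference you mention.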