allenai / scibert

A BERT model for scientific text.
https://arxiv.org/abs/1903.10676
Apache License 2.0

Relation Extraction in PyTorch #107

Open PetterBerntsson opened 3 years ago

PetterBerntsson commented 3 years ago

Hello, we have fine-tuned our model on the ChemProt dataset you've provided, and have now downloaded the model to run predictions locally. However, loading and running the model presents several issues:

First of all, the model can predict 13 classes, but when we run it over the whole dataset it cannot reproduce the metrics from training (even on the training set itself). The model does not seem to distinguish between the classes; for example, the most frequently predicted index for each class is as follows:

(That raises another question: where can we find which index corresponds to which text label?)
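Our current guess for recovering the index-to-label mapping is to read it out of the AllenNLP vocabulary files. A minimal sketch, assuming the standard archive layout where model.tar.gz contains a vocabulary/ directory with one file per namespace, and where line i of labels.txt holds the label stored at index i (the archive path here is hypothetical):

import tarfile

# Sketch, assuming AllenNLP's standard archive layout.
with tarfile.open("model.tar.gz") as archive:  # hypothetical path
    archive.extractall("extracted_model")

with open("extracted_model/vocabulary/labels.txt") as f:
    index_to_label = {i: line.strip() for i, line in enumerate(f)}
print(index_to_label)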

It also seems that when loading we need to rename all the keys, and we haven't found a simpler way to load the model (as you can see from the hacky solution below). The model was fine-tuned from the AllenNLP scivocab-uncased weights; would AllenNLP be a dependency for running our models in plain PyTorch?

If you have suggestions, or spot a problem in our code, it would be very helpful if you could point it out. So far we have assumed the code works, since it runs, produces consistent results, and fine-tuning changes those results. The metrics, however, are not as good as those reported during fine-tuning with your repository, and no amount of further fine-tuning seems to close the gap.
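For comparison, the AllenNLP-native way of loading would look roughly like the sketch below; the predictor name "text_classifier" and the "sentence" input key are assumptions on our part (AllenNLP defaults, not necessarily what this repository registers), and this is exactly the dependency we were hoping to avoid:

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

# Sketch: query the training archive through its registered predictor
# instead of renaming state-dict keys by hand.
archive = load_archive("model.tar.gz")  # hypothetical path
predictor = Predictor.from_archive(archive, "text_classifier")
print(predictor.predict_json({"sentence": "..."}))

Our current PyTorch-only attempt follows: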

import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
checkpoint = torch.load("./scibert_model/model2/pytorch_model.bin", map_location=device)

# Strip the AllenNLP prefix so the keys match the Hugging Face module names.
renamed_dict = {}
for key in checkpoint:
    name = str(key).replace("text_field_embedder.token_embedder_bert.bert_model.", "")
    renamed_dict[name] = checkpoint[key]

# Pull the classification head out; it has no counterpart in AutoModel.
cfw = renamed_dict.pop("classifier_feedforward.weight")
cfb = renamed_dict.pop("classifier_feedforward.bias")

# Rebuild the 13-way classifier on top of the pooled [CLS] representation.
out_layer = torch.nn.Linear(768, 13)
with torch.no_grad():
    out_layer.weight.copy_(cfw)
    out_layer.bias.copy_(cfb)

model.load_state_dict(renamed_dict)
model.to(device)
out_layer.to(device)  # was missing; the head must live on the same device
model.eval()

text = "Beta-1,4-galactosyltransferase I (beta4Gal-T1) normally transfers Gal from UDP-Gal to GlcNAc in the presence of Mn(2+) ion (Gal-T activity) and also transfers Glc from << UDP-Glc >> to GlcNAc ([[ Glc-T ]] activity), albeit at only 0.3% efficiency."

with torch.no_grad():
    encoded = tokenizer.encode_plus(
        text,
        max_length=128,
        add_special_tokens=True,
        return_token_type_ids=False,
        padding='max_length',   # pad_to_max_length is deprecated
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    # out[1] is the pooled [CLS] output; project to the 13 relation classes.
    preds = torch.softmax(out_layer(out[1]), dim=1)[0]
    print(torch.argmax(preds))
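One sanity check that might expose the problem: loading non-strictly and inspecting the key lists that load_state_dict returns, so that any misalignment between the renamed checkpoint keys and the Hugging Face module names becomes visible.

# Sketch: with strict=False, load_state_dict returns the keys that
# failed to match; non-empty lists mean the rename above is incomplete.
result = model.load_state_dict(renamed_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)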

Thank you in advance