allenai / scibert

A BERT model for scientific text.
https://arxiv.org/abs/1903.10676
Apache License 2.0

Scibert for text classification #95

Open InesArous opened 4 years ago

InesArous commented 4 years ago

Hi,

Thanks for your awesome work! I would like to use SciBERT for text classification. I managed to get some results by directly using the script train_allennlp_local.sh after modifying the task field as described in the readme file. However, I am not able to get the same results using Huggingface's framework. Are there any available resources/tutorials on how to make the two setups equivalent? Thanks!

amandalmia14 commented 3 years ago

@InesArous I was able to train/finetune BERT for text classification. However, when I replace the BERT sequence-classification model and tokenizer as shown below,

from:

```python
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)
```

to:

```python
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
```

I get the following error while training the model:

```
TypeError                                 Traceback (most recent call last)
in
     17 }
     18
---> 19 outputs = model(**inputs)
     20
     21 loss = outputs[0]

d:\multi_class_text_classification\venv\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724             _global_forward_hooks.values(),

TypeError: forward() got an unexpected keyword argument 'labels'
```

ibeltagy commented 3 years ago

@InesArous, you can try to follow one of the classification examples in the HF code https://github.com/huggingface/transformers/tree/master/examples/text-classification, e.g. the `run_pl_glue.py` one.
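
In case it helps, here is a rough sketch (not an official recipe) of the same idea using the Transformers `Trainer` API instead of the Lightning script; `texts`, `labels`, `num_labels`, and the training arguments below are placeholders for your own data and settings:

```python
# Minimal sketch: fine-tune SciBERT for text classification with the HF Trainer API.
import torch
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)

texts = ["Example scientific sentence.", "Another abstract snippet."]  # your data
labels = [0, 1]                                                        # your labels
num_labels = 2

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModelForSequenceClassification.from_pretrained(
    'allenai/scibert_scivocab_uncased', num_labels=num_labels)

class SimpleDataset(torch.utils.data.Dataset):
    """Pairs tokenized texts with integer labels."""
    def __init__(self, texts, labels):
        self.encodings = tokenizer(texts, truncation=True, padding=True)
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

train_dataset = SimpleDataset(texts, labels)

training_args = TrainingArguments(
    output_dir='./scibert-classifier',   # placeholder output path
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

Trainer(model=model, args=training_args, train_dataset=train_dataset).train()
```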

ibeltagy commented 3 years ago

@amandalmia14, you need to use `AutoModelForSequenceClassification` instead of `AutoModel`.
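
For example, a minimal sketch of the corrected snippet (reusing the `label_dict` from the code above; the extra keyword arguments are optional):

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
# AutoModelForSequenceClassification adds a classification head on top of SciBERT,
# so forward() accepts the `labels` keyword and returns the loss as the first output.
model = AutoModelForSequenceClassification.from_pretrained(
    'allenai/scibert_scivocab_uncased',
    num_labels=len(label_dict),
    output_attentions=False,
    output_hidden_states=False,
)

# outputs = model(**inputs)   # `inputs` may now include 'labels'
# loss = outputs[0]
```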