jind11 / TextFooler

A Model for Natural Language Attack on Text Classification and Inference
MIT License
485 stars, 79 forks

Regarding No. of classes in AG News #11

Closed RishabhMaheshwary closed 4 years ago

RishabhMaheshwary commented 4 years ago

On AG News I get the following error when loading the checkpoint:

Error(s) in loading state_dict for BertForSequenceClassification:
    size mismatch for classifier.weight: copying a param with shape torch.Size([4, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).
    size mismatch for classifier.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).

Is this because the model is built with 2 classes instead of the 4 in AG News? If so, how did you map the 4 class labels to 2?

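EDIT: nclasses defaults to 2 in attack_classification.py, so the model was built with a 2-class head while the AG News checkpoint has 4. No label mapping is involved; nclasses has to be set to 4 for AG News.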
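
For anyone hitting the same mismatch, here is a minimal sketch of how the class count can be read from the checkpoint before the model is built. It is not code from the repository: `infer_num_classes`, `check_nclasses`, and `fake_checkpoint` are hypothetical names, and the stand-in dictionary only mimics the `.shape` attribute that real tensors loaded from a checkpoint would expose. It assumes the classifier weight is stored under `classifier.weight`, as the error message above suggests.

```python
from types import SimpleNamespace


def infer_num_classes(state_dict, key="classifier.weight"):
    """Return the number of output classes stored in a checkpoint.

    The classifier weight has shape (num_classes, hidden_size), so the
    first dimension is the class count the checkpoint was trained with.
    """
    return int(state_dict[key].shape[0])


def check_nclasses(state_dict, nclasses):
    """Raise a readable error before loading the weights would fail."""
    expected = infer_num_classes(state_dict)
    if expected != nclasses:
        raise ValueError(
            f"checkpoint has {expected} classes but nclasses={nclasses}; "
            f"pass nclasses={expected} instead"
        )
    return expected


# Hypothetical stand-in for a loaded checkpoint: only .shape is used here,
# with the shapes reported in the error message.
fake_checkpoint = {
    "classifier.weight": SimpleNamespace(shape=(4, 768)),
    "classifier.bias": SimpleNamespace(shape=(4,)),
}

print(infer_num_classes(fake_checkpoint))  # prints 4
```

With a real checkpoint, the same functions should work on the dictionary of tensors, since only the first entry of each shape is read. Calling `check_nclasses(fake_checkpoint, 2)` raises a `ValueError` that names the value to pass instead.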