Using the train-watbertv6-squad-2ep model with only the QA head mapped to RobertaForQuestionAnswering, and evaluating with evalem.evaluators.QAEvaluator() on SQuAD v2.
The model mapping logic looks like:
from collections import OrderedDict

import torch
from transformers import AutoConfig, RobertaForQuestionAnswering


def load_model(
    config_path,
    pth_path,
) -> RobertaForQuestionAnswering:
    config = AutoConfig.from_pretrained(config_path)
    state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
    model = RobertaForQuestionAnswering(config)
    required_keys = model.state_dict().keys()

    state_dict_new = OrderedDict()
    for key, val in state_dict.items():
        # map the PrimeQA task-head weights onto the HF qa_outputs layer
        if key == "task_heads.qa_head.qa_outputs.weight":
            key = "qa_outputs.weight"
        elif key == "task_heads.qa_head.qa_outputs.bias":
            key = "qa_outputs.bias"
        # drop any checkpoint keys RobertaForQuestionAnswering doesn't expect
        if key not in required_keys:
            continue
        state_dict_new[key] = val

    model.load_state_dict(state_dict=state_dict_new, strict=False)
    return model
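A hypothetical call, just to show how this loader is meant to be used (the paths below are placeholders, not the actual checkpoint locations):

# hypothetical paths, for illustration only
model = load_model(
    config_path="path/to/config.json",
    pth_path="path/to/pytorch_model.bin",
)
model.eval()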
validation
[{'AccuracyMetric': {'total_items': 498,
'empty_items': 0,
'accuracy': {'score': 0.8749428449729656}},
'ExactMatchMetric': {'total_items': 1486,
'empty_items': 0,
'exact_match': 0.6292059219380888,
'flattened': True},
'F1Metric': {'total_items': 498,
'empty_items': 0,
'f1': {'score': 0.8959647784007141}}}]
train
[{'AccuracyMetric': {'total_items': 4652,
'empty_items': 0,
'accuracy': {'score': 0.8999344308604981}},
'ExactMatchMetric': {'total_items': 4652,
'empty_items': 0,
'exact_match': 0.7818142734307825,
'flattened': True},
'F1Metric': {'total_items': 4652,
'empty_items': 0,
'f1': {'score': 0.9333546560845649}}}]
With the deepset/roberta-base-squad2 model trained on SQuAD v2, the metrics are:
validation
[{'AccuracyMetric': {'total_items': 498,
'empty_items': 0,
'accuracy': {'score': 0.8984785462195101}},
'ExactMatchMetric': {'total_items': 1486,
'empty_items': 0,
'exact_match': 0.6608344549125168,
'flattened': True},
'F1Metric': {'total_items': 498,
'empty_items': 0,
'f1': {'score': 0.9145569508553953}}}]
train
[{'AccuracyMetric': {'total_items': 4652,
'empty_items': 0,
'accuracy': {'score': 0.900906346324613}},
'ExactMatchMetric': {'total_items': 4652,
'empty_items': 0,
'exact_match': 0.7818142734307825,
'flattened': True},
'F1Metric': {'total_items': 4652,
'empty_items': 0,
'f1': {'score': 0.9349720046558909}}}]
cc: @muthukumaranR @xhagrg
Using the full validation set with the train-watbertv6-squad-2ep model:
[{'AccuracyMetric': {'total_items': 5928,
'empty_items': 0,
'accuracy': {'score': 0.8915889391748646}},
'ExactMatchMetric': {'total_items': 20302,
'empty_items': 0,
'exact_match': 0.6127967687912521,
'flattened': True},
'F1Metric': {'total_items': 5928,
'empty_items': 0,
'f1': {'score': 0.9091273137464949}}}]
Using DistilBertForQuestionAnswering on the same dataset (default QuestionAnsweringHFPipelineWrapper()):
[{'AccuracyMetric': {'total_items': 5928,
'empty_items': 0,
'accuracy': {'score': 0.8492370725944365}},
'ExactMatchMetric': {'total_items': 20302,
'empty_items': 0,
'exact_match': 0.5774307949955669,
'flattened': True},
'F1Metric': {'total_items': 5928,
'empty_items': 0,
'f1': {'score': 0.8677135502536183}}}]
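For reference, a minimal sketch of that default-wrapper run; the wrapper and pipeline calls mirror the longer snippet later in this thread, and this is not the exact script that produced the numbers above:

from evalem.evaluators import QAEvaluator
from evalem.pipelines import SimpleEvaluationPipeline
from evalem.models import QuestionAnsweringHFPipelineWrapper
from evalem.misc.datasets import get_squad_v2

# default wrapper, i.e. no model/tokenizer supplied, which falls back to the
# DistilBERT QA pipeline mentioned above
wrapped_model = QuestionAnsweringHFPipelineWrapper()

data = get_squad_v2("validation", nsamples=None)
eval_pipe = SimpleEvaluationPipeline(model=wrapped_model, evaluators=QAEvaluator())
results = eval_pipe(data["inputs"], data["references"], model_params=dict(batch_size=64))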
Using RobertaForQuestionAnswering (deepset/roberta-base-squad2):
wrapped_model = QuestionAnsweringHFPipelineWrapper(
    model="deepset/roberta-base-squad2",
    predictions_postprocessor=lambda xs: list(map(lambda x: x["answer"], xs)),
)
[{'AccuracyMetric': {'total_items': 5928,
'empty_items': 1,
'accuracy': {'score': 0.8961578651210724}},
'ExactMatchMetric': {'total_items': 20302,
'empty_items': 4,
'exact_match': 0.6172036653857523,
'flattened': True},
'F1Metric': {'total_items': 5928,
'empty_items': 1,
'f1': {'score': 0.9126484736361342}}}]
Using the NASA v6 train-watbertv6-squad-2ep model:
{'AccuracyMetric': {'accuracy': {'score': 0.8707644373263648},
'empty_items': 0,
'total_items': 86821},
'ExactMatchMetric': {'empty_items': 0,
'exact_match': 0.7325992559403831,
'flattened': True,
'total_items': 86821},
'F1Metric': {'empty_items': 0,
'f1': {'score': 0.9208171457265248},
'total_items': 86821}}
Using the train-watbertv6-squad-2ep model loaded as RobertaModelForDownstreamTasks:
[{'AccuracyMetric': {'total_items': 5928,
'empty_items': 0,
'accuracy': {'score': 0.8915889391748646}},
'ExactMatchMetric': {'total_items': 20302,
'empty_items': 0,
'exact_match': 0.6127967687912521,
'flattened': True},
'F1Metric': {'total_items': 5928,
'empty_items': 0,
'f1': {'score': 0.909127313746495}}}]
{'AccuracyMetric': {'total_items': 86821,
'empty_items': 0,
'accuracy': {'score': 0.8707644373263648}},
'ExactMatchMetric': {'total_items': 86821,
'empty_items': 0,
'exact_match': 0.7325992559403831,
'flattened': True},
'F1Metric': {'total_items': 86821,
'empty_items': 0,
'f1': {'score': 0.9208171457265248}}}
Re: The code to evaluate is something like:
from transformers import AutoConfig, AutoTokenizer

from primeqa.mrc.models.task_model import ModelForDownstreamTasks
from primeqa.mrc.models.heads.extractive import EXTRACTIVE_HEAD

from evalem.evaluators import QAEvaluator
from evalem.pipelines import SimpleEvaluationPipeline
from evalem.models import QuestionAnsweringHFPipelineWrapper
from evalem.misc.datasets import get_squad_v2

config = AutoConfig.from_pretrained(<path_to_config_json>)
model = ModelForDownstreamTasks.from_config(
    config,
    pretrained_model_name_or_path=<path_to_pth_bin>,
    task_heads=EXTRACTIVE_HEAD,
)
model.set_task_head("qa_head")

tokenizer = AutoTokenizer.from_pretrained(<path_to_tokenizer>)

wrapped_model = QuestionAnsweringHFPipelineWrapper(
    model=model,
    tokenizer=tokenizer,
    device="mps",
)

data = get_squad_v2("validation", nsamples=None)

# sort by combined context+question length so that similarly sized examples
# are batched together, which reduces padding and speeds up inference
inputs, references = zip(
    *sorted(
        zip(data["inputs"], data["references"]),
        key=lambda x: len(x[0]["context"] + x[0]["question"]),
    )
)
data = dict(inputs=inputs, references=references)

evaluators = QAEvaluator()
eval_pipe = SimpleEvaluationPipeline(
    model=wrapped_model,
    evaluators=evaluators,
)

results = eval_pipe(
    data["inputs"],
    data["references"],
    model_params=dict(batch_size=64),
)
Using the V6 model and running a barebones QA pipeline:
[{'AccuracyMetric': {'total_items': 117,
'empty_items': 0,
'accuracy': {'score': 0.6485626505207466}},
'ExactMatchMetric': {'total_items': 117,
'empty_items': 0,
'exact_match': 0.3162393162393162,
'flattened': True},
'F1Metric': {'total_items': 117,
'empty_items': 0,
'f1': {'score': 0.752370960575804}}}]
Using only the is_impossible=True questions/contexts, the classification report is:
[{'AccuracyMetric': {'total_items': 78,
'empty_items': 0,
'accuracy': {'score': 0.0}},
'F1Metric': {'total_items': 78, 'empty_items': 0, 'f1': {'score': 0.0}},
'PrecisionMetric': {'total_items': 78,
'empty_items': 0,
'precision': {'score': 0.0}},
'RecallMetric': {'total_items': 78,
'empty_items': 0,
'recall': {'score': 0.0}},
'ConfusionMatrix': {'confusion_matrix': array([[ 0, 0],
[78, 0]]),
'labels': ['False', 'True'],
'flattened': True,
'total_items': 78,
'empty_items': 0}}]
This specific evaluation was done through a new custom classifier:
class QAClf(QuestionAnsweringHFPipelineWrapper):
    """Wrapper that predicts whether each question is unanswerable."""

    def _predict(self, inputs, **kwargs):
        accum = []
        for inp in inputs:
            res = []
            # run every preprocessed feature/span through the model and
            # collect its is_impossible flag
            for preprocessed in self.pipeline.preprocess(inp):
                model_outputs = self.pipeline.forward(preprocessed)
                res.append(model_outputs["is_impossible"])
            # mark the question as unanswerable if any span says so
            accum.append(any(res))
        return accum
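A hypothetical way to turn those boolean predictions into a classification accuracy; the constructor arguments mirror the wrapper usage earlier in this thread, and impossible_inputs stands for the 78 is_impossible=True samples:

# hypothetical usage; `model`, `tokenizer`, and `impossible_inputs` come
# from the earlier snippets
clf = QAClf(model=model, tokenizer=tokenizer)
preds = clf._predict(impossible_inputs)  # one boolean per question

# every reference here is "unanswerable", so accuracy is just the
# fraction of True predictions
accuracy = sum(preds) / len(preds)
print(f"classification accuracy: {accuracy:.2%}")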
Adding plausible_answers along with answers to the references for each question/context pair, we have:
[{'AccuracyMetric': {'total_items': 139,
'empty_items': 0,
'accuracy': {'score': 0.5858646795838852}},
'ExactMatchMetric': {'total_items': 139,
'empty_items': 0,
'exact_match': 0.26618705035971224,
'flattened': True},
'F1Metric': {'total_items': 139,
'empty_items': 0,
'f1': {'score': 0.6872523842725322}}}]
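For context, a rough sketch of building these augmented reference texts, assuming the data is stored in SQuAD-style JSON where plausible_answers is attached to is_impossible questions; the file path is a placeholder:

import json

# placeholder path to a SQuAD-style JSON file
with open("dev-v2.0.json") as f:
    squad = json.load(f)

references = []
for article in squad["data"]:
    for para in article["paragraphs"]:
        for qa in para["qas"]:
            # answerable questions carry "answers"; impossible ones carry
            # "plausible_answers" instead
            texts = [a["text"] for a in qa.get("answers", [])]
            texts += [a["text"] for a in qa.get("plausible_answers", [])]
            references.append(texts)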
Using only the plausible_answers set:
[{'AccuracyMetric': {'total_items': 22,
'empty_items': 0,
'accuracy': {'score': 0.25242547051057684}},
'ExactMatchMetric': {'total_items': 22,
'empty_items': 0,
'exact_match': 0.0,
'flattened': True},
'F1Metric': {'total_items': 22,
'empty_items': 0,
'f1': {'score': 0.34092211879481804}}}]
[{'BertScore': {'total_items': 22,
'empty_items': 0,
'bertscore': {'score': 0.5337053449316458,
'precision': 0.5626290399919857,
'recall': 0.53028473935344,
'f1': 0.5337053449316458,
'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.27.4)'}},
'RougeMetric': {'total_items': 22,
'empty_items': 0,
'rouge': {'rouge1': 0.2746537599478776,
'rouge2': 0.194448408734123,
'rougeL': 0.2734842019601378,
'rougeLsum': 0.2784270397906761}},
'MeteorMetric': {'total_items': 22,
'empty_items': 0,
'meteor': {'score': 0.25063900377028847}}}]
Here's the prediction dump for these 22 predictions: es_prediction_plausible_answers.csv
Using primeqa.mrc.trainers.mrc.MRCTrainer on the 117-example train set and evaluating on the train set itself, we get significantly better performance:
[{'AccuracyMetric': {'total_items': 117,
'empty_items': 0,
'accuracy': {'score': 0.7754180480704133}},
'ExactMatchMetric': {'total_items': 117,
'empty_items': 0,
'exact_match': 0.39316239316239315,
'flattened': True},
'F1Metric': {'total_items': 117,
'empty_items': 0,
'f1': {'score': 0.8674849685441913}}},
{'BertScore': {'total_items': 117,
'empty_items': 0,
'bertscore': {'score': 0.8790242371396122,
'precision': 0.9054728822830396,
'recall': 0.8679911741334149,
'f1': 0.8790242371396122,
'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.24.0)'}},
'RougeMetric': {'total_items': 117,
'empty_items': 0,
'rouge': {'rouge1': 0.8215218996507581,
'rouge2': 0.733657321947087,
'rougeL': 0.822142299088887,
'rougeLsum': 0.8234616367561287}},
'MeteorMetric': {'total_items': 117,
'empty_items': 0,
'meteor': {'score': 0.7719667782949707}}}]
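For reference, the fine-tuning call was roughly of this shape. This is a sketch only: it assumes MRCTrainer accepts the standard transformers.Trainer arguments, train_features stands for the 117 examples already tokenized into start/end-position features (preprocessing omitted), and the hyperparameters and output directory are placeholders:

from transformers import TrainingArguments, default_data_collator
from primeqa.mrc.trainers.mrc import MRCTrainer

training_args = TrainingArguments(
    output_dir="v6-es-finetune",        # placeholder output directory
    per_device_train_batch_size=8,      # placeholder hyperparameters
    learning_rate=3e-5,
    num_train_epochs=3,
)
trainer = MRCTrainer(
    model=model,                        # the V6 model loaded earlier
    args=training_args,
    train_dataset=train_features,       # 117 preprocessed examples
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)
trainer.train()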
cc: @xhagrg @muthukumaranR
Here, we fine-tune the V6 model on only the unanswerable questions (start/end indices are -1). This is to check whether the model can learn unanswerable questions (TargetType.NO_ANSWER). We use all 78 unanswerable question/context pairs and fine-tune for 25 epochs. After running inference on the same 78 data points, the classification accuracy is only 16.67%, with the following prediction distribution:
Counter({'TargetType.SPAN_ANSWER': 65, 'TargetType.NO_ANSWER': 13})
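A rough sketch of how the unanswerable-only subset can be selected and labeled, assuming the data is stored in SQuAD-style JSON and using the start/end = -1 convention described above (the file path is a placeholder):

import json

with open("es_dataset.json") as f:     # placeholder path to the ES data
    data = json.load(f)

unanswerable = []
for article in data["data"]:
    for para in article["paragraphs"]:
        for qa in para["qas"]:
            if qa.get("is_impossible", False):
                unanswerable.append(
                    dict(
                        question=qa["question"],
                        context=para["context"],
                        # no-answer convention used for fine-tuning
                        start_position=-1,
                        end_position=-1,
                    )
                )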
Evalem result
[{'AccuracyMetric': {'total_items': 117,
'empty_items': 0,
'accuracy': {'score': 0.7898480621066585}},
'ExactMatchMetric': {'total_items': 117,
'empty_items': 0,
'exact_match': 0.49572649572649574,
'flattened': True},
'F1Metric': {'total_items': 117,
'empty_items': 0,
'f1': {'score': 0.8693825445385333}}},
{'BertScore': {'total_items': 117,
'empty_items': 0,
'bertscore': {'score': 0.8698257719859098,
'precision': 0.8987084087143596,
'recall': 0.856557988967651,
'f1': 0.8698257719859098,
'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.24.0)'}},
'BartScore': {'bartscore': {'score': -2.0845768372727256,
'model_checkpoint': 'bartscore-large-cnn',
'model_weights': None,
'total_items': 117,
'flattened': True}},
'RougeMetric': {'total_items': 117,
'empty_items': 0,
'rouge': {'rouge1': 0.8129755845780602,
'rouge2': 0.7172201468956072,
'rougeL': 0.8098182612872958,
'rougeLsum': 0.811912951013968}},
'MeteorMetric': {'total_items': 117,
'empty_items': 0,
'meteor': {'score': 0.7576563081672764}}}]
Using the vanilla v6 model, the classification accuracy was 0%.
Classification accuracy of 6.4%.
Evalem result
[{'AccuracyMetric': {'total_items': 117,
'empty_items': 0,
'accuracy': {'score': 0.8461192003756203}},
'ExactMatchMetric': {'total_items': 117,
'empty_items': 0,
'exact_match': 0.49572649572649574,
'flattened': True},
'F1Metric': {'total_items': 117,
'empty_items': 0,
'f1': {'score': 0.9094455238860134}}},
{'BertScore': {'total_items': 117,
'empty_items': 0,
'bertscore': {'score': 0.925027038806524,
'precision': 0.95449644072443,
'recall': 0.9068491787482531,
'f1': 0.925027038806524,
'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.24.0)'}},
'BartScore': {'bartscore': {'score': -1.8274125056898491,
'model_checkpoint': 'bartscore-large-cnn',
'model_weights': None,
'total_items': 117,
'flattened': True}},
'RougeMetric': {'total_items': 117,
'empty_items': 0,
'rouge': {'rouge1': 0.8809201723626585,
'rouge2': 0.7944845266233733,
'rougeL': 0.8778573242297767,
'rougeLsum': 0.8815012393897714}},
'MeteorMetric': {'total_items': 117,
'empty_items': 0,
'meteor': {'score': 0.8142538435596486}}}]
cc: @muthukumaranR
We fine-tune the NASA v6 model on the 113k SQuAD v2 dataset to learn impossible answers.
Frozen: both the RoBERTa encoder and the LLM decoder
Trainable: classification head
Using the fine-tuned model, we test the accuracy on unanswerable questions with the 78 ES samples that have is_impossible=True.
For the other evalem metrics, we use the 117 answerable samples.
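A minimal PyTorch sketch of that freezing scheme; the name filter for the head parameters is illustrative and may not match the exact parameter names in the checkpoint:

# freeze everything (RoBERTa encoder + decoder), then re-enable gradients
# only for the classification-head parameters
for param in model.parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    if "classification_head" in name:   # illustrative name filter
        param.requires_grad = True

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_trainable}")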
checkpoint | imp_count | impossibilities | AccuracyMetric | ExactMatchMetric | F1Metric | BertScore | BartScore | rouge1 | rouge2 | rougeL | rougeLsum | MeteorMetric |
---|---|---|---|---|---|---|---|---|---|---|---|---|
checkpoint-679 | 18 | 0.23076923076923078 | 0.7237731120382738 | 0.37209302325581395 | 0.8267948424845446 | 0.8407947708700978 | -3.546640660772976 | 0.7687902386496759 | 0.6438440333129436 | 0.7704737196647233 | 0.770284747750026 | 0.6873938086854574 |
checkpoint-1359 | 14 | 0.1794871794871795 | 0.7492086639099467 | 0.38271604938271603 | 0.8504774277786713 | 0.8531065649456449 | -3.6118257450751767 | 0.7922125739112171 | 0.6794054615118694 | 0.7951915301959707 | 0.793547946796979 | 0.7174579109342103 |
checkpoint-2039 | 12 | 0.15384615384615385 | 0.7506705464023693 | 0.3625 | 0.8495404239398365 | 0.8549549337476492 | -3.657258931133482 | 0.7968581258350768 | 0.6912613231270669 | 0.7977621006730667 | 0.7984314509784622 | 0.7251602700722106 |
checkpoint-2719 | 12 | 0.15384615384615385 | 0.7445566306364579 | 0.36585365853658536 | 0.8411176666487777 | 0.8507185727357864 | -3.6051308423535438 | 0.7900940365512097 | 0.6882017598899468 | 0.7927470161820624 | 0.7935478891918374 | 0.7189063610460591 |
checkpoint-3395 | 12 | 0.15384615384615385 | 0.7423968573795459 | 0.36904761904761907 | 0.8401860227787743 | 0.8502603943149248 | -3.557788890396428 | 0.7875391062055703 | 0.6765373472484382 | 0.7902922212189262 | 0.792395070897573 | 0.7116466365898267 |
The epoch-1 checkpoint (checkpoint-679) yields the following evalem metrics on the SQuAD v2 held-out sets:
[{'AccuracyMetric': {'total_items': 5928,
'empty_items': 1,
'accuracy': {'score': 0.8868717802639483}},
'ExactMatchMetric': {'total_items': 20302,
'empty_items': 4,
'exact_match': 0.6030150753768844,
'flattened': True},
'F1Metric': {'total_items': 5928,
'empty_items': 1,
'f1': {'score': 0.9054936590556527}}},
{'BertScore': {'total_items': 5928,
'empty_items': 1,
'bertscore': {'score': 0.9359343827188126,
'precision': 0.9382273859236159,
'recall': 0.9395118997441249,
'f1': 0.9359343827188126,
'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.28.1)'}},
'BartScore': {'bartscore': {'score': -2.58893137244553,
'model_checkpoint': 'bartscore-large-cnn',
'model_weights': None,
'total_items': 20302,
'flattened': True}},
'RougeMetric': {'total_items': 5928,
'empty_items': 1,
'rouge': {'rouge1': 0.9057619208070267,
'rouge2': 0.5958566106015495,
'rougeL': 0.9051587096760276,
'rougeLsum': 0.9054059014325857}},
'MeteorMetric': {'total_items': 5928,
'empty_items': 1,
'meteor': {'score': 0.7401154836791255}}}]
checkpoint | imp_count | impossibilities | AccuracyMetric | ExactMatchMetric | F1Metric | BertScore | BartScore | rouge1 | rouge2 | rougeL | rougeLsum | MeteorMetric |
---|---|---|---|---|---|---|---|---|---|---|---|---|
checkpoint-1 | 0 | 0.000000 | 0.648563 | 0.316239 | 0.752371 | 0.799980 | -2.935075 | 0.699827 | 0.586938 | 0.697715 | 0.700365 | 0.620332 |
checkpoint-3 | 0 | 0.000000 | 0.752506 | 0.382979 | 0.849651 | 0.858997 | -3.200798 | 0.800872 | 0.671984 | 0.794605 | 0.797221 | 0.721241 |
checkpoint-5 | 0 | 0.000000 | 0.822445 | 0.445783 | 0.895975 | 0.902979 | -3.212394 | 0.864582 | 0.749611 | 0.861217 | 0.860737 | 0.799884 |
checkpoint-7 | 0 | 0.000000 | 0.777219 | 0.405660 | 0.870073 | 0.876644 | -2.568347 | 0.821760 | 0.718866 | 0.818622 | 0.817586 | 0.762200 |
checkpoint-9 | 0 | 0.000000 | 0.812671 | 0.427273 | 0.895768 | 0.897784 | -2.244789 | 0.855293 | 0.766594 | 0.854760 | 0.854608 | 0.799363 |
checkpoint-10 | 0 | 0.000000 | 0.813520 | 0.427273 | 0.895905 | 0.896490 | -2.238462 | 0.854913 | 0.761107 | 0.855206 | 0.853965 | 0.798013 |
checkpoint-12 | 4 | 0.051282 | 0.838136 | 0.448598 | 0.912923 | 0.907513 | -2.247307 | 0.875378 | 0.780678 | 0.873396 | 0.874263 | 0.815640 |
checkpoint-14 | 59 | 0.756410 | 0.847073 | 0.472222 | 0.917019 | 0.915354 | -2.159167 | 0.882529 | 0.784328 | 0.882285 | 0.880711 | 0.817893 |
checkpoint-16 | 60 | 0.769231 | 0.837538 | 0.459459 | 0.911673 | 0.913235 | -2.111236 | 0.874106 | 0.780827 | 0.873819 | 0.872260 | 0.808995 |
checkpoint-18 | 47 | 0.602564 | 0.837957 | 0.460177 | 0.902840 | 0.916226 | -2.040956 | 0.870217 | 0.795396 | 0.869844 | 0.870228 | 0.807535 |
checkpoint-19 | 49 | 0.628205 | 0.831072 | 0.447368 | 0.899297 | 0.913466 | -2.035210 | 0.865367 | 0.788855 | 0.864640 | 0.864606 | 0.803030 |
checkpoint-21 | 58 | 0.743590 | 0.834757 | 0.452174 | 0.901473 | 0.912336 | -1.974991 | 0.870022 | 0.798309 | 0.869298 | 0.868682 | 0.807074 |