amazon-science / wqa_tanda

This repo provides code and data used in our TANDA paper.

Reproducing roberta-large #9

Open cosmoshsv opened 3 years ago

cosmoshsv commented 3 years ago

I have used the data preprocessing code from run_glue and the MAP/MRR metrics posted by the author in issue #8; however, the numbers I obtain from the pretrained ASNQ -> WikiQA models are slightly higher than the MAP/MRR reported in the paper. Can the evaluation code (preprocessing + inference) be pushed to the repo?
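For reference, here is a rough sketch of the preprocessing step I mean: converting the official WikiQA TSV into the sentence-pair TSV format that run_glue consumes. This is my own script, not code from this repo; the column names are from the standard WikiQA release.

```python
# Sketch: convert WikiQA TSV rows into (question, candidate, label) pairs
# for run_glue. My own code, not from this repo; WikiQA column names are
# from the official release.
import csv

def wikiqa_to_glue(in_path, out_path):
    """Write (question, candidate sentence, label) rows for run_glue."""
    with open(in_path, newline="", encoding="utf-8") as fin, \
         open(out_path, "w", newline="", encoding="utf-8") as fout:
        reader = csv.DictReader(fin, delimiter="\t")
        writer = csv.writer(fout, delimiter="\t")
        for row in reader:
            writer.writerow([row["Question"], row["Sentence"], row["Label"]])

wikiqa_to_glue("WikiQA-test.tsv", "wikiqa_test_pairs.tsv")
```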

cxx-cindy commented 3 years ago

> I have used the data preprocessing code from run_glue and the MAP/MRR metrics posted by the author in issue #8; however, the numbers I obtain from the pretrained ASNQ -> WikiQA models are slightly higher than the MAP/MRR reported in the paper. Can the evaluation code (preprocessing + inference) be pushed to the repo?

Excuse me, were you able to run the code successfully?

cosmoshsv commented 3 years ago

Yes, the code runs successfully. I've also set up a baseline with RoBERTa-large, using a few steps of WikiQA training to warm-start the model, and that reports higher MAP/MRR too.
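Roughly, the warm-start baseline looks like this. A minimal sketch with Hugging Face Trainer; the hyper-parameters and the toy dataset are my own placeholders, not the paper's settings.

```python
# Sketch of a RoBERTa-large warm-start on a few WikiQA steps.
# Hyper-parameters and the toy example below are placeholders.
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large", num_labels=2)

# Placeholder for tokenized WikiQA (question, candidate, label) pairs.
pairs = [("how are glacier caves formed?",
          "A glacier cave is a cave formed within the ice of a glacier.", 1)]
enc = tokenizer([q for q, _, _ in pairs], [a for _, a, _ in pairs],
                truncation=True, padding=True)
train_dataset = [{"input_ids": enc["input_ids"][i],
                  "attention_mask": enc["attention_mask"][i],
                  "labels": pairs[i][2]} for i in range(len(pairs))]

args = TrainingArguments(
    output_dir="wikiqa_warmstart",
    max_steps=100,                     # only a few steps, just to warm-start
    per_device_train_batch_size=32,
    learning_rate=1e-5,
)
Trainer(model=model, args=args, train_dataset=train_dataset).train()
```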

cxx-cindy commented 3 years ago

Wow, you're really great

liudonglei commented 2 years ago

> Yes, the code runs successfully. I've also set up a baseline with RoBERTa-large, using a few steps of WikiQA training to warm-start the model, and that reports higher MAP/MRR too.

Hi, I want to know how to use this code https://github.com/alexa/wqa_tanda/issues/8#issuecomment-853225088. Should it go in a standalone Python file, or be appended to one of the existing files?

lucadiliello commented 2 years ago

I was also able to obtain better results than the ones in the paper, for example by increasing the transfer batch size to 2048. I think the results in the paper are more like a lower bound on what can be obtained with this method.
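For what it's worth, one way to reach an effective batch size of 2048 on limited GPU memory is gradient accumulation. A sketch with Hugging Face TrainingArguments (the per-device size here is an assumption; only the product matters):

```python
# Sketch: effective transfer batch size of 2048 via gradient accumulation.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="asnq_transfer",
    per_device_train_batch_size=64,   # whatever fits in GPU memory
    gradient_accumulation_steps=32,   # 64 * 32 = 2048 effective batch size
)
```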

cosmoshsv commented 2 years ago

The mapk and mrr code can be called as mapk(questions, flat_true_labels, flat_predictions), where questions is an array of question identifiers, flat_true_labels is a flattened list of gold labels, and flat_predictions is the flattened output of np.argmax(logits, axis=1). In my code, this was appended to the inference notebook.
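For anyone who wants a concrete reference, below is my own reconstruction consistent with that call signature (not the exact code from issue #8). It assumes binary labels with 1 = correct answer and ranks the candidates of each question by the prediction value; following common WikiQA practice, questions without any correct candidate are skipped.

```python
# Reconstruction of MAP/MRR matching the call signature above; my own
# sketch, not the exact code from issue #8. Assumes binary labels (1 = correct).
import numpy as np
from collections import defaultdict

def _ranked_groups(questions, labels, preds):
    """Group (prediction, label) pairs per question, ranked by prediction."""
    groups = defaultdict(list)
    for q, y, p in zip(questions, labels, preds):
        groups[q].append((p, y))
    return [sorted(g, key=lambda x: x[0], reverse=True)
            for g in groups.values()]

def mapk(questions, flat_true_labels, flat_predictions):
    """Mean Average Precision over questions with >= 1 correct answer."""
    aps = []
    for ranked in _ranked_groups(questions, flat_true_labels, flat_predictions):
        hits, precisions = 0, []
        for rank, (_, y) in enumerate(ranked, start=1):
            if y == 1:
                hits += 1
                precisions.append(hits / rank)
        if precisions:  # skip questions without a correct candidate
            aps.append(np.mean(precisions))
    return float(np.mean(aps))

def mrr(questions, flat_true_labels, flat_predictions):
    """Mean Reciprocal Rank of the first correct answer per question."""
    rrs = []
    for ranked in _ranked_groups(questions, flat_true_labels, flat_predictions):
        for rank, (_, y) in enumerate(ranked, start=1):
            if y == 1:
                rrs.append(1.0 / rank)
                break
    return float(np.mean(rrs))
```

One caveat: ranking by the argmax output gives only a coarse 0/1 ordering within each question; ranking by the positive-class logit or softmax score gives a finer-grained (and usually slightly different) MAP/MRR.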