allenai / SciREX

Data/Code Repository for https://api.semanticscholar.org/CorpusID:218470122
Apache License 2.0

Key error in predict_salient_mentions.py #13

Open viswavi opened 4 years ago

viswavi commented 4 years ago

I trained the baseline SciREX model and then tried to generate predictions from it, but ran into a problem at the salient-mention prediction step. When I tried running:

python scirex/predictors/predict_salient_mentions.py $scirex_archive $test_output_folder/ner_predictions.jsonl $test_output_folder/salient_mentions_predictions.jsonl $cuda_device

I got the error:

  span_ix = span_mask.view(-1).nonzero().squeeze(1).long()
36it [00:02, 15.88it/s]
Traceback (most recent call last):
  File "scirex/predictors/predict_salient_mentions.py", line 92, in <module>
    main()
  File "scirex/predictors/predict_salient_mentions.py", line 88, in main
    predict(archive_folder, test_file, output_file, cuda_device)
  File "scirex/predictors/predict_salient_mentions.py", line 57, in predict
    metadata = output_res['metadata']
KeyError: 'metadata'

It seems that every so often, the output of the salient-mention prediction decoder is missing the metadata field.

As a stopgap, I'm just skipping batches with this issue, but I'm concerned that doing so could accidentally bias my evaluation somehow.
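That stopgap can be sketched roughly as follows (a minimal illustration, not the actual SciREX code; `predict_safe` and the `decode` call are hypothetical stand-ins for the loop in `predict_salient_mentions.py`):

```python
# Sketch of the stopgap: skip any batch whose decoded output lacks the
# 'metadata' key instead of crashing with a KeyError.
def predict_safe(model, batches):
    results = []
    for batch in batches:
        output_res = model.decode(batch)  # hypothetical decode call
        if 'metadata' not in output_res:
            # Decoder produced no metadata for this batch; skip it.
            continue
        results.append(output_res['metadata'])
    return results
```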

successar commented 4 years ago

Did you reduce the batch size to train the model?

viswavi commented 4 years ago

I did. I reduced the batch size to 2 (but training still completed all 19 epochs after 3 days without complaint).

successar commented 4 years ago

This is fine. The output doesn't contain any metadata because the input doesn't actually have any predicted spans, so you can safely ignore those batches.
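This explanation can be seen at the tensor level (a toy illustration of the indexing step quoted in the original warning, not SciREX's actual decoder): with zero predicted spans the span mask is all zeros, so the nonzero-index lookup comes back empty and there is nothing to attach metadata to.

```python
import torch

# Toy mask for one document with no predicted spans.
span_mask = torch.zeros(1, 5)

# The same indexing step shown in the warning from the original report:
span_ix = span_mask.view(-1).nonzero().squeeze(1).long()

print(span_ix.numel())  # 0 -- no span indices, so the decoder emits no metadata
```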

Laxmaan commented 4 years ago

I'm getting the same error later on, in predict_n_ary_relations.py. I used a batch size of 4 for training and implemented the fix for the salient mentions as described above, but I run into an error at this stage as well.

Traceback (most recent call last):
  File "scirex/predictors/predict_n_ary_relations.py", line 109, in <module>
    predict(argv[1], argv[2], argv[3], argv[4], int(argv[5]))
  File "scirex/predictors/predict_n_ary_relations.py", line 77, in predict
    output_res = model.decode_relations(batch)
  File "/storage/home/lpb5347/scratch/scirex/SciREX/scirex/models/scirex_model.py", line 385, in decode_relations
    res["n_ary_relation"] = self._cluster_n_ary_relation.decode(output_n_ary_relation)
  File "/storage/home/lpb5347/scratch/scirex/SciREX/scirex/models/relations/entity_relation.py", line 211, in decode
    "metadata" : output_dict['metadata']

jeremyadamsfisher commented 3 years ago

Is there a resolution to this? I'm struggling with this issue in predict_n_ary_relations.py even with a batch size of 1.

viswavi commented 3 years ago

@jeremyadamsfisher a stopgap resolution is to just skip these documents (here's how I did it in my branch).
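For predict_n_ary_relations.py, that skip can be sketched as a guard around the decode_relations call shown in the traceback above (decode_all is an illustrative name, not SciREX code):

```python
# Hypothetical sketch of the workaround: skip documents whose relation
# output lacks 'metadata' instead of letting the KeyError propagate.
def decode_all(model, batches):
    decoded = []
    for batch in batches:
        try:
            res = model.decode_relations(batch)  # method name from the traceback
        except KeyError as err:
            if err.args == ('metadata',):
                continue  # document had no predicted spans; skip it
            raise
        decoded.append(res)
    return decoded
```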

However, even with this fix, I get quite different results on the relation-prediction metrics for "End-to-End (gold salient clustering)", which may be an issue if you care about that metric. I don't think the code currently matches the paper for this particular evaluation.

With this fix, here are the results I was able to reproduce vs. the originally reported results (attached screenshot: scirex_reproduced).

viswavi commented 3 years ago

On an unrelated note, I added a few metrics (particularly for relation extraction) as well as significance-testing scripts in my branch here: https://github.com/viswavi/SciREX, in case you find them useful.

jeremyadamsfisher commented 3 years ago

Thanks @viswavi!