DreamInvoker / GAIN

Source code for EMNLP 2020 paper: Double Graph Based Reasoning for Document-level Relation Extraction
MIT License

Training with custom data #6

Closed alejandrojcastaneira closed 3 years ago

alejandrojcastaneira commented 3 years ago

Hello,

I have put my dataset in the DocRED format and also created the corresponding files train_annotated.json, dev.json, test.json, ner2id.json and rel2id.json to train a BERT-type architecture. However, in my dataset, the number of relations and entities differs from what is used in DocRED. I would like to know which files/parameters I would need to modify in order to be able to train with custom data.

Best regards

DreamInvoker commented 3 years ago

The configuration is set in this file.

You could use --relation_nums in the shell script to set the number of relation types for your custom dataset.

Best
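For readers unsure what such an override looks like: the flag name --relation_nums comes from the repo, but the surrounding argparse setup below is a guess at how a config like this is typically wired (the default of 97 corresponds to DocRED's 96 relation types plus the NA class).

```python
# Sketch of an argparse-based config with a --relation_nums override.
# Only the flag name comes from the repo; the rest is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--relation_nums', type=int, default=97,
                    help='number of relation types (DocRED: 96 relations + NA)')

# simulate passing the flag from a shell script, e.g. for a 42-relation dataset
opt = parser.parse_args(['--relation_nums', '42'])
print(opt.relation_nums)  # 42
```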

alejandrojcastaneira commented 3 years ago

Hi, thanks for the fast reply.

I did it as you suggested, I was able to read and start training on my custom data. I am getting this at the end of the first iteration.

  File "train.py", line 231, in <module>
    train(opt)
  File "train.py", line 188, in train
    ign_f1, ign_auc, pr_x, pr_y = test(model, dev_loader, model_name, id2rel=id2rel)
  File "/home/ale/PycharmProjects/GAIN/code/test.py", line 23, in test
    for cur_i, d in enumerate(dataloader):
  File "/home/ale/PycharmProjects/GAIN/code/data.py", line 687, in __iter__
    ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
IndexError: index 1722 is out of bounds for dimension 1 with size 1722

I assume this happens when the evaluation phase starts after each iteration.

DreamInvoker commented 3 years ago

1722, which is the maximum number of entity pairs per document in DocRED, is hard-coded. You could increase it for your dataset in that line.
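That constant comes from the dataset: DocRED documents have at most 42 entities, and 42 × 41 = 1722 ordered head/tail pairs. Rather than guessing a bigger number, you can compute the right bound from your own data. A small sketch (the `vertexSet` field is the entity list in the DocRED JSON format):

```python
def max_entity_pairs(docs):
    """Largest number of ordered head/tail entity pairs, n * (n - 1)
    for n entities, across a list of DocRED-style documents."""
    best = 0
    for doc in docs:
        n = len(doc['vertexSet'])
        best = max(best, n * (n - 1))
    return best

# toy example: documents with 3 and 2 entities -> max is 3 * 2 = 6 pairs
docs = [{'vertexSet': [[], [], []]}, {'vertexSet': [[], []]}]
print(max_entity_pairs(docs))  # 6
```

Running this over your train/dev/test JSON files gives the value to put in place of 1722.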

alejandrojcastaneira commented 3 years ago

Thank you very much, it works. I ran a quick experiment using the same data for the training, validation and test sets. I am getting these results:

2020-10-27 10:12:59.926594 ALL  : Theta 0.9957 | F1 0.8710 | AUC 0.9123
2020-10-27 10:13:00.174016 Ignore ma_f1 0.0000 | input_theta 0.9957 test_result P 0.0000 test_result R 0.8923 test_result F1 0.0000 | AUC 0.0000
2020-10-27 10:13:00.221839 | epoch  50 | time: 12.11s

Is it normal behaviour that the test F1 and AUC are equal to 0, or could I have some problem in the creation of the .json data?

If I wanted to manually load the model and check the predictions, for example by running inference on a single sample within a Python script, what would be the necessary steps?

DreamInvoker commented 3 years ago

It seems that the F1 score (0.8710) is OK, but the Ignore F1 score is zero due to the zero value of the Ignore Precision. The reason may be that the correct and correct_in_train variables here are equal after the for-loop. Are you using the same dataset for training and evaluating? The Ignore F1 metric excludes the relation instances shared by the training set and the dev/test set when computing Ignore Precision, as described in the original DocRED paper.
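To make the exclusion concrete, here is a simplified sketch of the Ignore Precision idea (the fact tuples and function shape are illustrative, not the repo's actual evaluation code): correct predictions already seen in training are removed from both numerator and denominator.

```python
def ignore_precision(predictions, gold_facts, train_facts):
    """Simplified Ignore Precision in the spirit of DocRED evaluation.
    Each fact is a (title, head, tail, relation) tuple (illustrative shape).
    Correct predictions that also appear in the training set are excluded."""
    correct = sum(1 for p in predictions if p in gold_facts)
    correct_in_train = sum(1 for p in predictions
                           if p in gold_facts and p in train_facts)
    denom = len(predictions) - correct_in_train
    return (correct - correct_in_train) / denom if denom else 0.0

# When the training and dev sets are identical, every correct prediction
# is also "in train", so Ignore Precision collapses to 0.
gold = {("doc1", 0, 1, "P17"), ("doc1", 1, 2, "P131")}
train = gold
preds = list(gold)
print(ignore_precision(preds, gold, train))  # 0.0
```

This is exactly the situation in the experiment above: correct == correct_in_train, so the numerator is zero.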

alejandrojcastaneira commented 3 years ago

Yes, exactly.

I was using the same samples for training and validation as an initial experiment, before properly defining the training and evaluation sets. That must be the reason the Ignore Precision equals 0: it is excluding all relations.

DreamInvoker commented 3 years ago

Yes, it is okay now.

Good luck with your experiments!

alejandrojcastaneira commented 3 years ago

Thank you! I would like to know how to run inference on a single sample within a Python script: where do I read the last saved checkpoint from, and what inputs would be needed? If there is an example available for this, that would be great!

Maybe I'm asking too many questions :-)

DreamInvoker commented 3 years ago

The save directory of the best checkpoint is configured in config.py. You could use the test.py code (in its main block) to run inference on your sample, after transforming the sample into the same format as dev.json in the dataset; the input format is identical to dev.json.
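For reference, wrapping a single sample into that format looks roughly like the sketch below. The field names (title, sents, vertexSet, labels) follow the public DocRED JSON schema; the concrete sentence and entities are invented for illustration, and labels is left empty since test.py is meant to predict the relations.

```python
import json

# A minimal DocRED-style document: each sentence is a list of tokens,
# and vertexSet groups mentions of the same entity (illustrative data).
sample = {
    "title": "my_single_sample",
    "sents": [["Alice", "works", "at", "Acme", "."]],
    "vertexSet": [
        [{"name": "Alice", "sent_id": 0, "pos": [0, 1], "type": "PER"}],
        [{"name": "Acme", "sent_id": 0, "pos": [3, 4], "type": "ORG"}],
    ],
    "labels": [],  # empty at inference time; the model predicts relations
}

# test.py expects a JSON list of documents, as in dev.json
with open("my_dev.json", "w") as f:
    json.dump([sample], f)
```

Pointing the dev-set path in the config at a file like this should let test.py load the checkpoint and print predictions for the sample.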