Spico197 / DocEE

🕹️ A toolkit for document-level event extraction, containing some SOTA model implementations.
https://doc-ee.readthedocs.io/
MIT License

Failed to reproduce the result with inference.py #55

Closed: Shiina18 closed this issue 2 years ago.

Shiina18 commented 2 years ago


Problems: I used the settings and the task dump given in the README, together with the inference script below (modeled on inference.py), but got a low result: F1 0.6856. Precision is as expected, but recall is far too low.

Others: The code below is rather informal, but it should work. It concatenates all sentences of each document into one string, runs dee_task.predict_one on it, and finally scores the predictions with dee_metric.py.

```python
import collections
import json
import os
import pathlib

import dee_metric  # local copy of dee_metric.py (provides agg_event_level_tpfpfn_stats)
from dee.tasks import DEETask, DEETaskSetting

# Ordered argument roles of each event type; the order defines the record layout.
EVENT2ROLES = {
    "EquityOverweight": ["EquityHolder", "TradedShares", "StartDate", "EndDate", "LaterHoldingShares", "AveragePrice"],
    "EquityUnderweight": ["EquityHolder", "TradedShares", "StartDate", "EndDate", "LaterHoldingShares", "AveragePrice"],
    "EquityFreeze": ["EquityHolder", "FrozeShares", "LegalInstitution", "TotalHoldingShares", "TotalHoldingRatio", "StartDate", "EndDate", "UnfrozeDate"],
    "EquityRepurchase": ["CompanyName", "RepurchasedShares", "HighestTradingPrice", "LowestTradingPrice", "ClosingDate", "RepurchaseAmount"],
    "EquityPledge": ["Pledger", "PledgedShares", "Pledgee", "TotalHoldingShares", "TotalHoldingRatio", "TotalPledgedShares", "StartDate", "EndDate", "ReleasedDate"],
}

if __name__ == "__main__":
    # init
    task_dir = "Exps/sct-Tp1CG-with_left_trigger-OtherType-comp_ents-bs64_8"
    cpt_file_name = "TriggerAwarePrunedCompleteGraph"
    # bert_model_dir is used for tokenization; `vocab.txt` must be included in this dir.
    # Setting it to `bert-base-chinese` uses the Hugging Face online cache.
    bert_model_dir = "bert-base-chinese"

    # load settings
    dee_setting = DEETaskSetting.from_pretrained(
        os.path.join(task_dir, f"{cpt_file_name}.task_setting.json")
    )
    dee_setting.local_rank = -1
    dee_setting.filtered_data_types = "o2o,o2m,m2m,unk"
    dee_setting.bert_model = bert_model_dir

    # build task (no dataset is loaded; documents are fed one by one below)
    dee_task = DEETask(
        dee_setting,
        load_train=False,
        load_dev=False,
        load_test=False,
        load_inference=False,
        parallel_decorate=False,
    )

    # load PTPCG parameters from checkpoint 57
    dee_task.resume_cpt_at(57)

    tp = fp = fn = 0
    results = []
    path = pathlib.Path(__file__).parent / "Data" / "test.json"
    data = json.loads(path.read_text(encoding="utf8"))
    for datum in data:
        doc_info = datum[1]

        # Gold records: one list of argument values per event, in role order.
        gold_type2events = collections.defaultdict(list)
        for _guid, event_type, event_dict in doc_info["recguid_eventname_eventdict_list"]:
            gold_type2events[event_type].append(
                [event_dict.get(role) for role in EVENT2ROLES[event_type]]
            )

        # Predict on the whole document; predict_one re-splits it with sent_seg.
        doc = "".join(doc_info["sentences"])
        pred_results = dee_task.predict_one(doc)
        results.append(pred_results)

        pred_type2events = collections.defaultdict(list)
        for pred_event in pred_results["event_list"]:
            event_type = pred_event["event_type"]
            # Keep the first predicted argument of each role; missing roles become None.
            role2arg = {}
            for arg in pred_event["arguments"]:
                role2arg.setdefault(arg["role"], arg["argument"])
            pred_type2events[event_type].append(
                [role2arg.get(role) for role in EVENT2ROLES[event_type]]
            )

        # Accumulate event-level TP/FP/FN over every event type of this document.
        for event_type, roles in EVENT2ROLES.items():
            stat = dee_metric.agg_event_level_tpfpfn_stats(
                pred_type2events[event_type], gold_type2events[event_type], len(roles)
            )
            tp += stat[0]
            fp += stat[1]
            fn += stat[2]

    # Micro-averaged scores over the whole test set; empty denominators give 0.0
    # instead of raising ZeroDivisionError.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    print("tp, fp, fn:", tp, fp, fn)
    print(f"final precision {precision:.4f}, recall {recall:.4f}, f1 {f1:.4f}")

    save_path = pathlib.Path(__file__).parent / "result.json"
    save_path.write_text(json.dumps(results, ensure_ascii=False), encoding="utf8")
```
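For reference, a minimal standalone sketch of the final scoring step in the script: event-level counts, in the (tp, fp, fn) order the script reads from agg_event_level_tpfpfn_stats, are summed over all documents and event types, and precision, recall and F1 are computed once from the totals. The function name micro_prf1 is hypothetical and the counts in the example are made up. Since recall is tp / (tp + fn), normal precision with low recall means the FN total dominates, i.e. many gold records were matched by no prediction.

```python
from typing import Iterable, Tuple


def micro_prf1(stats: Iterable[Tuple[int, int, int]]) -> Tuple[float, float, float]:
    """Sum (tp, fp, fn) over all documents and event types, then score once."""
    tp = fp = fn = 0
    for s_tp, s_fp, s_fn in stats:
        tp += s_tp
        fp += s_fp
        fn += s_fn
    # Guard each denominator: an empty prediction or gold set must not crash.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    # Made-up counts for illustration; not the numbers from this issue.
    demo = [(3, 1, 2), (0, 0, 0), (1, 0, 3)]
    p, r, f = micro_prf1(demo)
    print(f"P={p:.3f} R={r:.3f} F1={f:.3f}")  # P=0.800 R=0.444 F1=0.571
```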
Spico197 commented 2 years ago

Hi there~ Sorry for the late response.

We provided inference.py as an example of how the model can be used to generate event records, since many people requested this feature. So it is not quite the right way to reproduce the results reported in our paper.

Inside dee_task.predict_one, convert_string_to_raw_input splits the concatenated string into sentences with the sent_seg function in dee/helper/__init__.py, and the resulting sentences may not match those in the original dataset. The result you got is quite interesting, though; we haven't tried this before. To reproduce the results reported in our paper, we suggest following the instructions in README.md to re-train the model, or loading the test set and calling the default dee_task.eval function directly.
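One way to see why re-splitting can change the model input: join a document's dataset sentences, split them again, and compare. The sketch below uses a hypothetical naive_sent_seg that cuts only after 。！？；. It is not DocEE's sent_seg, whose rules may differ, and the sentences are made up; the point is only that a line without terminal punctuation (such as a title) gets merged into the next sentence after the round trip.

```python
import re
from typing import List


def naive_sent_seg(text: str) -> List[str]:
    """Hypothetical splitter: cut after each 。！？； and keep the mark.

    This is NOT DocEE's sent_seg; it only illustrates the round trip.
    """
    # Zero-width split right after a terminal mark (requires Python 3.7+).
    parts = re.split(r"(?<=[。！？；])", text)
    return [p for p in parts if p]


def boundaries_match(original: List[str]) -> bool:
    """Join the dataset sentences, re-split them, and compare with the originals."""
    return naive_sent_seg("".join(original)) == original


if __name__ == "__main__":
    # Made-up document: a title line without terminal punctuation, then one sentence.
    sents = ["关于股东股权质押的公告", "本公司股东张三质押1000万股。"]
    print(boundaries_match(sents))  # False: the title is merged into the next sentence
    print(naive_sent_seg("".join(sents)))
```

Running the real segmenter the same way over every document in the test set and counting mismatches would show how often the inputs differ.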

In addition, if you hacked the source code to bypass sent_seg, the upper bound of entity extraction would drop sharply: the default maximum sequence length is 128, so anything past that length in an over-long input is truncated, which lowers overall performance.
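A rough way to estimate this effect, assuming one model position per character, ignoring special tokens, and assuming each input sequence is cut at 128 positions: count how many entity mentions still end within the limit when each sentence is encoded separately versus when the whole document is treated as one sequence. The function reachable_fraction and the Mention layout are hypothetical, and the lengths in the example are made up.

```python
from typing import List, Tuple

# Hypothetical mention layout: (sentence index, start, end), end exclusive,
# with offsets counted in characters inside that sentence.
Mention = Tuple[int, int, int]


def reachable_fraction(
    sent_lens: List[int], mentions: List[Mention], max_len: int = 128, merge: bool = False
) -> float:
    """Fraction of mentions that end within max_len positions of their sequence.

    merge=False: every sentence is its own sequence, truncated separately.
    merge=True:  all sentences form one sequence, so offsets accumulate.
    """
    if not mentions:
        return 1.0
    # Start offset of each sentence once the document is concatenated.
    starts = [0]
    for length in sent_lens[:-1]:
        starts.append(starts[-1] + length)
    kept = 0
    for idx, _start, end in mentions:
        last = starts[idx] + end if merge else end
        if last <= max_len:
            kept += 1
    return kept / len(mentions)


if __name__ == "__main__":
    # Made-up lengths and mentions for illustration only.
    lens = [60, 100, 90]
    ments = [(0, 10, 20), (1, 5, 15), (2, 40, 50)]
    print(reachable_fraction(lens, ments))              # 1.0
    print(reachable_fraction(lens, ments, merge=True))  # 0.666... (last mention lost)
```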

If you have further questions, feel free to leave a message.

Shiina18 commented 2 years ago

Thanks for your reply. I didn't hack dee_task.predict_one (so sent_seg is applied as usual); I only imitated the logic of inference.py and replaced it with the code above. My result can be reproduced in a few minutes. I will take some time to look into how dee_task.eval works.