goodmami / penman

PENMAN notation (e.g. AMR) in Python
https://penman.readthedocs.io/
MIT License

Different triple order = low smatch score. How to avoid automatic inversion? #117

Closed: BramVanroy closed this issue 1 year ago

BramVanroy commented 1 year ago

I am verifying whether linearization plus subword tokenization is non-destructive, i.e., whether the original graph can be recovered after the process. While doing so, I found that the order of the triples matters for how the graph is encoded.

In the example below, you will find two equal graphs (graph_a == graph_b) that differ only in the order of their triples. However, during encoding, they end up with different serializations:

# graph a
(k / kill-01
   :ARG0 (h / hunger-01
            :ARG0 (w / we))
   :ARG1 w)
# graph b
(k / kill-01
   :ARG0 (h / hunger-01
            :ARG0 (w / we
                     :ARG1-of k)))

While these graphs are identical in their underlying representation, their surface forms are evidently different. I would not take issue with that, except that I found it leads to a low smatch score of 0.8571! I think it is reasonable to assume that if graph_a == graph_b, the smatch score should also be 1, but perhaps that is more an issue (or a deliberate choice) of the metric.

Regardless, I am trying to find out whether such automatic inversion with -of can be disabled. I tried the noop model, but that did not seem to have an effect. A full example, including the smatch calculation, is below. I'm thankful for any thoughts on how to deal with the issue!

import penman
from penman import Graph
from penman.models.noop import model as noop_model
from typing import List

import smatch

def calculate_smatch(refs_penman: List[str], preds_penman: List[str]):
    total_match_num = total_test_num = total_gold_num = 0
    n_invalid = 0

    for sentid, (ref_penman, pred_penman) in enumerate(zip(refs_penman, preds_penman), 1):
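        # smatch treats its first argument as the test (predicted) AMR and its
        # second as the gold AMR, so pass the prediction first
        best_match_num, test_triple_num, gold_triple_num = smatch.get_amr_match(
            pred_penman, ref_penman, sent_num=sentid
        )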

        total_match_num += best_match_num
        total_test_num += test_triple_num
        total_gold_num += gold_triple_num
        # clear the matching triple dictionary for the next AMR pair
        smatch.match_triple_dict.clear()

    score = smatch.compute_f(total_match_num, total_test_num, total_gold_num)

    return {
        "smatch_precision": score[0],
        "smatch_recall": score[1],
        "smatch_fscore": score[2],
        "ratio_invalid_amrs": n_invalid / len(preds_penman) * 100,
    }

if __name__ == '__main__':
    triples_a = [('k', ':instance', 'kill-01'),
                 ('k', ':ARG0', 'h'),
                 ('h', ':instance', 'hunger-01'),
                 ('h', ':ARG0', 'w'),
                 ('k', ':ARG1', 'w'),
                 ('w', ':instance', 'we')]
    triples_b = [('k', ':instance', 'kill-01'),
                 ('k', ':ARG0', 'h'),
                 ('h', ':instance', 'hunger-01'),
                 ('h', ':ARG0', 'w'),
                 ('w', ':instance', 'we'),  # switched with below
                 ('k', ':ARG1', 'w')]

    graph_a = Graph(triples_a)
    graph_b = Graph(triples_b)
    print("equal?", graph_a == graph_b)
    # Tried with both the default and the noop model
    penman_a = penman.encode(graph_a, model=noop_model)
    penman_b = penman.encode(graph_b, model=noop_model)
    print(penman_a)
    print(penman_b)
    print(calculate_smatch([penman_a], [penman_b]))
BramVanroy commented 1 year ago

Some extra information for future readers: after a lot more digging, I found that smatch can give low scores even for identical graphs! In the example below, using the same graph as both reference and prediction, the smatch score is only 0.92 (unless my calculate_smatch function is incorrect, but I do not think so).

What is worse, the score does not seem to be deterministic! Sometimes I get 0.9, then 0.92, then 0.8733.

from typing import List

import smatch

def calculate_smatch(refs_penman: List[str], preds_penman: List[str]):
    total_match_num = total_test_num = total_gold_num = 0
    n_invalid = 0

    for sentid, (ref_penman, pred_penman) in enumerate(zip(refs_penman, preds_penman), 1):
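        # smatch treats its first argument as the test (predicted) AMR and its
        # second as the gold AMR, so pass the prediction first
        best_match_num, test_triple_num, gold_triple_num = smatch.get_amr_match(
            pred_penman, ref_penman, sent_num=sentid
        )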

        total_match_num += best_match_num
        total_test_num += test_triple_num
        total_gold_num += gold_triple_num
        # clear the matching triple dictionary for the next AMR pair
        smatch.match_triple_dict.clear()

    score = smatch.compute_f(total_match_num, total_test_num, total_gold_num)

    return {
        "smatch_precision": score[0],
        "smatch_recall": score[1],
        "smatch_fscore": score[2],
        "ratio_invalid_amrs": n_invalid / len(preds_penman) * 100,
    }

s = """(r / result-01
   :ARG1 (c / compete-01
            :ARG0 (w / woman)
            :mod (p / preliminary)
            :time (t / today)
            :mod (p2 / polo
                     :mod (w2 / water)))
   :ARG2 (a / and
            :op1 (d / defeat-01
                    :ARG0 (t2 / team
                              :mod (c2 / country
                                       :wiki +
                                       :name (n / name
                                                :op1 "Hungary")))
                    :ARG1 (t3 / team
                              :mod (c3 / country
                                       :wiki +
                                       :name (n2 / name
                                                 :op1 "Canada")))
                    :quant (s / score-entity
                              :op1 13
                              :op2 7))
            :op2 (d2 / defeat-01
                     :ARG0 (t4 / team
                               :mod (c4 / country
                                        :wiki +
                                        :name (n3 / name
                                                  :op1 "France")))
                     :ARG1 (t5 / team
                               :mod (c5 / country
                                        :wiki +
                                        :name (n4 / name
                                                  :op1 "Brazil")))
                     :quant (s2 / score-entity
                                :op1 10
                                :op2 9))
            :op3 (d3 / defeat-01
                     :ARG0 (t6 / team
                               :mod (c6 / country
                                        :wiki +
                                        :name (n5 / name
                                                  :op1 "Australia")))
                     :ARG1 (t7 / team
                               :mod (c7 / country
                                        :wiki +
                                        :name (n6 / name
                                                  :op1 "Germany")))
                     :quant (s3 / score-entity
                                :op1 10
                                :op2 8))
            :op4 (d4 / defeat-01
                     :ARG0 (t8 / team
                               :mod (c8 / country
                                        :wiki +
                                        :name (n7 / name
                                                  :op1 "Russia")))
                     :ARG1 (t9 / team
                               :mod (c9 / country
                                        :wiki +
                                        :name (n8 / name
                                                  :op1 "Netherlands")))
                     :quant (s4 / score-entity
                                :op1 7
                                :op2 6))
            :op5 (d5 / defeat-01
                     :ARG0 (t10 / team
                                :mod (c10 / country
                                          :wiki +
                                          :name (n9 / name
                                                    :op1 "United"
                                                    :op2 "States")))
                     :ARG1 (t11 / team
                                :mod (c11 / country
                                          :wiki +
                                          :name (n10 / name
                                                     :op1 "Kazakhstan")))
                     :quant (s5 / score-entity
                                :op1 10
                                :op2 5))
            :op6 (d6 / defeat-01
                     :ARG0 (t12 / team
                                :mod (c12 / country
                                          :wiki +
                                          :name (n11 / name
                                                     :op1 "Italy")))
                     :ARG1 (t13 / team
                                :mod (c13 / country
                                          :wiki +
                                          :name (n12 / name
                                                     :op1 "New"
                                                     :op2 "Zealand")))
                     :quant (s6 / score-entity
                                :op1 12
                                :op2 2))))
"""

if __name__ == "__main__":
    for _ in range(5):
        smatch_score = calculate_smatch([s], [s])
        print(smatch_score)

Output

{'smatch_precision': 0.8866666666666667, 'smatch_recall': 0.8866666666666667, 'smatch_fscore': 0.8866666666666667, 'ratio_invalid_amrs': 0.0}
{'smatch_precision': 0.88, 'smatch_recall': 0.88, 'smatch_fscore': 0.88, 'ratio_invalid_amrs': 0.0}
{'smatch_precision': 0.8666666666666667, 'smatch_recall': 0.8666666666666667, 'smatch_fscore': 0.8666666666666667, 'ratio_invalid_amrs': 0.0}
{'smatch_precision': 0.9266666666666666, 'smatch_recall': 0.9266666666666666, 'smatch_fscore': 0.9266666666666666, 'ratio_invalid_amrs': 0.0}
{'smatch_precision': 0.8533333333333334, 'smatch_recall': 0.8533333333333334, 'smatch_fscore': 0.8533333333333335, 'ratio_invalid_amrs': 0.0}

I submitted this issue over at smatch (https://github.com/snowblink14/smatch/issues/43).

flipz357 commented 1 year ago

Hi,

I implemented SMATCH++, which also contains an ILP solver. It should be very simple to run.

Please do:

pip install mip==1.13.0
pip install smatchpp

Then simply run:

python -m smatchpp -a largeamr.txt -b largeamr.txt -solver ilp

The output is

F1: 100.0 Precision: 100.0 Recall: 100.0
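Smatch scores a pair of graphs by the best one-to-one mapping between their variables. The reference smatch implementation approximates that optimum with hill-climbing from random starting mappings, which is consistent with the run-to-run variation reported above, whereas an exact solver such as the ILP in SMATCH++ finds the optimum. The sketch below is a brute-force stand-in for an exact solver (not SMATCH++'s code), feasible only for tiny graphs. It works on penman-style (source, role, target) triples and ignores the graph top; the names `variables`, `matched`, and `exact_f1` are hypothetical.

```python
from itertools import permutations
from typing import Dict, List, Optional, Tuple

Triple = Tuple[str, str, str]


def variables(triples: List[Triple]) -> List[str]:
    """Variables are the sources of :instance triples."""
    return sorted({src for src, role, _ in triples if role == ":instance"})


def matched(test: List[Triple], gold: List[Triple],
            mapping: Dict[str, Optional[str]]) -> int:
    """Count test triples found in gold after renaming test variables.
    A variable mapped to None never matches anything."""
    gold_set = set(gold)
    count = 0
    for src, role, tgt in set(test):
        renamed = (mapping.get(src, src), role, mapping.get(tgt, tgt))
        if renamed in gold_set:
            count += 1
    return count


def exact_f1(test: List[Triple], gold: List[Triple]) -> float:
    """Exhaustively search one-to-one variable mappings (tiny graphs only)."""
    test_vars, gold_vars = variables(test), variables(gold)
    # None lets a test variable stay unmapped; each gold variable is used once at most.
    pool: List[Optional[str]] = list(gold_vars) + [None] * len(test_vars)
    best = 0
    for choice in permutations(pool, len(test_vars)):
        best = max(best, matched(test, gold, dict(zip(test_vars, choice))))
    if best == 0:
        return 0.0
    precision = best / len(set(test))
    recall = best / len(set(gold))
    return 2 * precision * recall / (precision + recall)


gold = [("k", ":instance", "kill-01"), ("k", ":ARG0", "h"),
        ("h", ":instance", "hunger-01"), ("h", ":ARG0", "w"),
        ("w", ":instance", "we"), ("k", ":ARG1", "w")]
# The same graph with different variable names.
test = [("x", ":instance", "kill-01"), ("x", ":ARG0", "y"),
        ("y", ":instance", "hunger-01"), ("y", ":ARG0", "z"),
        ("z", ":instance", "we"), ("x", ":ARG1", "z")]
print(exact_f1(test, gold))  # 1.0
```

The search is exponential in the number of variables, which is why practical tools fall back on heuristics or ILP solvers for real AMRs like the one above.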

goodmami commented 1 year ago

@flipz357 thanks for the suggestion

@BramVanroy this seems to be an issue with Smatch, not with Penman, unless I'm mistaken. I think Smatch deinverts most triples, too. Shall we close this issue, then?

I also think we should move away from depending on the -of inversions for downstream processing of the graph, as they are merely an artifact of configuring the graph to be tree-like for serialization. In Goodman (2019), near the end of Section 2.2, I noted that annotators may prefer particular serializations to capture nuance or to align with surface strings, but I think this should be the exception rather than the rule, and such differences in nuance should instead be captured explicitly.

BramVanroy commented 1 year ago

@goodmami Thanks for getting back. Sorry for cluttering the thread. There is a specific penman question hidden in my first post. Can the automatic inversion of -of be disabled one way or another?

goodmami commented 1 year ago

Can the automatic inversion of -of be disabled one way or another?

The noop model prevents the automatic de-inversion when parsing PENMAN:

>>> import penman
>>> from penman.models.noop import model as noop_model
>>> # default
>>> penman.decode("(a / alpha :ARG0-of (b / beta))").edges()
[Edge(source='b', role=':ARG0', target='a')]
>>> # with noop.model
>>> penman.decode("(a / alpha :ARG0-of (b / beta))", model=noop_model).edges()
[Edge(source='a', role=':ARG0-of', target='b')]

But it sounds like you want to stop the inversion of triples when encoding? The only thing the NoOpModel class does is override the default deinvert() method so it does not de-invert triples on parsing, so it would not help here. You could change the graph's top or its epigraph to try and get the serialization you want:

>>> g0 = penman.decode("(a / alpha :ARG0-of (b / beta))")
>>> g0.triples
[('a', ':instance', 'alpha'), ('b', ':ARG0', 'a'), ('b', ':instance', 'beta')]
>>> print(penman.encode(penman.Graph(g0.triples)))
(a / alpha
   :ARG0-of (b / beta))
>>> print(penman.encode(penman.Graph(g0.triples, top="b")))
(b / beta
   :ARG0 (a / alpha))

But note that these operations are not easily automated, and changes to top are semantic changes, at least according to Smatch and the AMR Specification.

Finally, in general, it's not practical to assume that the triples provided to penman.Graph will always be encoded as-is. Some inversion may be necessary to allow it to be encoded in PENMAN. E.g., all 3 of the following serializations have at least 1 inverted triple:

>>> g1 = penman.decode("(a / alpha :ARG0-of (b / beta) :ARG0-of (g / gamma))")
>>> print(penman.encode(penman.Graph(g1.triples, top="a")))
(a / alpha
   :ARG0-of (b / beta)
   :ARG0-of (g / gamma))
>>> print(penman.encode(penman.Graph(g1.triples, top="b")))
(b / beta
   :ARG0 (a / alpha
            :ARG0-of (g / gamma)))
>>> print(penman.encode(penman.Graph(g1.triples, top="g")))
(g / gamma
   :ARG0 (a / alpha
            :ARG0-of (b / beta)))
BramVanroy commented 1 year ago

This makes a lot of sense. After closer inspection of what I was trying to do, my approach seems nonsensical/not semantically plausible. So thank you for the clarification!