bjascob / amrlib

A python library that makes AMR parsing, generation and visualization simple.
MIT License

FAA Aligner Downcasts Token Text #53

Closed plandes closed 1 year ago

plandes commented 1 year ago

First, thank you very much for writing this great software--it's been extremely helpful in my research.

The FAA aligner API appears to return the surface alignment token text as lower case. There is a strip().lower() call in amrlib.alignments.faa_aligner.feat2tree.align which appears to be making this change. Is there a way to return the Penman graph string with case intact?

bjascob commented 1 year ago

See #19. This points to a few places in the code to change if you want to do this yourself or there's another suggested work-around.

At some point in the future I'll probably update the FAA aligner to use some of the newer AMR-3 gold alignments and I'll take a look at this. I won't have time to look into it any time soon.

plandes commented 1 year ago

One idea I had was to reuse the rule based aligner to add in the alignments, using the graph and the alignment string given as space separated <tok>-<path> entries. I started to write it myself, but the 'r' (role) nomenclature alignments are taking more time than I anticipated.
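As a rough illustration of the `<tok>-<path>` format mentioned above, here is a minimal sketch that splits such a string into its parts. It assumes dot separated, 1-based path steps with an optional trailing `r` for a role edge, which matches how the code posted later in this thread reads them. The function name `parse_isi_alignments` is hypothetical and not part of amrlib or penman.

```python
from typing import List, Tuple, Union

PathStep = Union[int, str]


def parse_isi_alignments(align_str: str) -> List[Tuple[str, Tuple[PathStep, ...]]]:
    """Split a space separated alignment string into (token, path) pairs.

    Assumption: each entry looks like ``<tok>-<path>``, where the path is made
    of dot separated 1-based steps and may end in ``r`` for a role edge.
    """
    parsed = []
    for entry in align_str.split():
        tok, path = entry.split('-', 1)
        # convert the 1-based steps to 0-based indexes, keeping the 'r' marker
        steps = tuple(s if s == 'r' else int(s) - 1 for s in path.split('.'))
        parsed.append((tok, steps))
    return parsed


print(parse_isi_alignments('0-1.1 2-1 3-1.2.r'))
# [('0', (0, 0)), ('2', (0,)), ('3', (0, 1, 'r'))]
```

The token part is kept as a string here because the code posted later in the thread passes it on unchanged when it builds the alignment object.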

bjascob commented 1 year ago

The work-around should only be a few lines of code to implement. It might not be the "right" solution but it should work fine. Just lower-case match the output aligned graph's attributes to the original graph's attribs and copy over the corrected case as needed. If you're not familiar with the penman library you'll have to look into that but it should be something like...

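fixed = []
for tnew in graph_new.triples:
    for torig in graph_orig.triples:
        # triples are immutable (source, role, target) tuples, so build a copy
        if tnew[2].lower() == torig[2].lower() and tnew[2] != torig[2]:
            tnew = (tnew[0], tnew[1], torig[2])
            break
    fixed.append(tnew)
# note: graph_new.epidata is keyed by triple and must be re-keyed to match
graph_new.triples = fixed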

Technically you only need to look at attributes so you could add logic to filter for those. See penman graph.py
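A slightly fuller sketch of this work-around, restricted to non-instance triples and carrying the epidata along, might look like the following. It works on plain lists and dicts so it runs without penman. With penman the inputs would presumably be `graph_new.triples`, `graph_orig.triples` and `graph_new.epidata`. The function name `restore_case` is hypothetical, and the code assumes that two targets differing only in case never need to be told apart.

```python
from typing import Dict, List, Optional, Tuple

Triple = Tuple[str, str, Optional[str]]


def restore_case(new_triples: List[Triple], orig_triples: List[Triple],
                 epidata: Dict[Triple, list]) -> Tuple[List[Triple], Dict[Triple, list]]:
    """Copy the original casing of targets onto lower-cased triples."""
    # map each lower-cased original target to its original spelling; concepts
    # (':instance' triples) are skipped since only attributes need fixing
    by_lower: Dict[str, str] = {}
    for _, role, target in orig_triples:
        if role != ':instance' and isinstance(target, str):
            by_lower.setdefault(target.lower(), target)
    fixed_triples: List[Triple] = []
    fixed_epidata: Dict[Triple, list] = {}
    for triple in new_triples:
        source, role, target = triple
        fixed = triple
        if role != ':instance' and isinstance(target, str):
            fixed = (source, role, by_lower.get(target.lower(), target))
        fixed_triples.append(fixed)
        # epidata is keyed by triple, so move the entry to the corrected key
        fixed_epidata[fixed] = epidata.get(triple, [])
    return fixed_triples, fixed_epidata


orig = [('p', ':instance', 'person'), ('p', ':name', 'n'),
        ('n', ':instance', 'name'), ('n', ':op1', '"Obama"')]
new = [('p', ':instance', 'person'), ('p', ':name', 'n'),
       ('n', ':instance', 'name'), ('n', ':op1', '"obama"')]
triples, epis = restore_case(new, orig, {new[3]: ['e.0']})
print(triples[-1])  # ('n', ':op1', '"Obama"')
```

Re-keying the epidata matters because the alignments are stored there under the triple they belong to. Changing only the triples would leave them orphaned.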

plandes commented 1 year ago

@bjascob Thanks for the great suggestion, but I decided to do it the "right" way anyway. I didn't think it was going to take too long, but it's more involved than you might think:

  1. There are a lot of edge cases
  2. FAA isn't consistent in creating inverse roles (from the generated graph string--see below regarding testing)
  3. The Penman tree data structure (at least for me) was challenging to deal with.

I now have a class that takes a Penman graph and populates the alignment in the nodes and roles.
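To give a feel for the tree structure from point 3, here is a minimal sketch that follows a 0-based path through a hand-built stand-in for a Penman tree node. It assumes the nested `(variable, branches)` layout in which the first branch is the `/` concept branch, which is how the code posted later in this thread reads it. The function `follow_path` is hypothetical, it takes the path without the leading root step, and it leaves out the `r` role handling.

```python
from typing import Sequence, Tuple

# Hand-built stand-in for the nested (variable, branches) layout of a Penman
# tree node; it encodes (w / want-01 :ARG0 (b / boy) :ARG1 (g / go-02 :ARG0 b)).
TREE = ('w', [('/', 'want-01'),
              (':ARG0', ('b', [('/', 'boy')])),
              (':ARG1', ('g', [('/', 'go-02'),
                               (':ARG0', 'b')]))])


def follow_path(node, path: Sequence[int]) -> Tuple[str, str, str]:
    """Return the (source, role, target) triple reached by 0-based child steps."""
    var, branches = node
    if len(path) == 0:
        # landed on a concept node: the first branch holds ('/', concept)
        return (var, ':instance', branches[0][1])
    # children are counted after the concept branch
    role, target = branches[1:][path[0]]
    if isinstance(target, str):
        # an attribute or re-entrant variable: the branch itself is the triple
        return (var, role, target)
    return follow_path(target, path[1:])


print(follow_path(TREE, ()))      # ('w', ':instance', 'want-01')
print(follow_path(TREE, (1,)))    # ('g', ':instance', 'go-02')
print(follow_path(TREE, (1, 0)))  # ('g', ':ARG0', 'b')
```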

I have tested it on the following corpora:

It was tested as follows:

  1. Use FAA to create the graph string and alignments string.
  2. Create a list of the alignments with the new code.
  3. For each created alignment in the list, report those missing from the FAA generated graph.

I've written a unit test case for this.

Do you want this code in your package or should I just keep it to myself?

bjascob commented 1 year ago

If you want to drop a tar-ball in here with your code, I'll take a look at some point in the future.

plandes commented 1 year ago

@bjascob it's only two files. Here's the code that populates the alignments:

"""Includes classes to add alignments to AMR graphs using an ISI formatted
alignment string.

"""
__author__ = 'Paul Landes'

from typing import List, Tuple, Union, Dict
from dataclasses import dataclass, field
import logging
import collections
from itertools import chain
from io import StringIO
from typing import Any
from pprint import pprint
import penman
from penman import Graph
from penman.surface import Alignment, RoleAlignment

logger = logging.getLogger(__name__)

@dataclass(order=True)
class PathAlignment(object):
    """An alignment that contains the path and alignment to node, or an edge for
    role alignments.

    """
    index: int = field()
    """The index of this alignment in the ISI formatted alignment string."""

    path: Tuple[int] = field()
    """The 0-indexed path to the node or the edge."""

    alignment_str: str = field()
    """The original unparsed alignment."""

    alignment: Union[Alignment, RoleAlignment] = field()
    """The alignment of the node or edge."""

    triple: Tuple[str, str, str] = field()
    """The triple specifying the node or edge of the alignment."""

    @property
    def is_role(self) -> bool:
        """Whether or not the alignment is a role alignment."""
        return isinstance(self.alignment, RoleAlignment)

    def __str__(self) -> str:
        return (f'{self.index}: {self.alignment_str} -> {self.alignment} @ ' +
                f"{self.triple} (t={'role' if self.is_role else 'alignment'})")

    def __repr__(self) -> str:
        return self.__str__()

@dataclass
class AlignmentPopulator(object):
    """Adds alignments from an ISI formatted string.

    """
    graph: Graph = field()
    """The graph that will be populated with alignments."""

    alignment_key: str = field(default='alignments')
    """The key in the graph's metadata with the ISI formatted alignment string.

    """
    def __post_init__(self):
        self._strip_alignments = False

    def _get_role_src_concept(self, node: Tuple):
        """Return the role, source and concept from the tree node."""
        is_role = node[0][0] == ':'
        if is_role:
            # roles have no concepts so we must traverse to another node
            return node[0], None, None
        elif node[1][0][0] == '/':
            # "/<concept>" found as first child node
            return None, node[0], node[1][0]
        else:
            raise ValueError(f'No concept found in <{node}>')

    def _strip_align(self, s: str) -> str:
        """Return ``s`` with the alignment suffix removed.  This is only useful for
        testing.

        """
        if self._strip_alignments:
            ix = s.find('~')
            if ix > -1:
                s = s[:ix]
        return s

    def _get_node(self, node: Tuple[str, Any], path: Tuple[int],
                  parent: str) -> Tuple[bool, str, str, str]:
        """Return a role flag followed by the triple of a node or role edge, found
        by recursively traversing the tree data structure.

        :param node: the current traversing node

        :param path: used to guide the node being searched for of the remaining
                     subtree

        :param parent: the parent of ``node``

        """
        if logger.isEnabledFor(logging.DEBUG):
            sio = StringIO()
            sio.write(f'traversing node or edge <{node[0]}>, path={path}:\n')
            pprint(node, stream=sio)
            logger.debug(sio.getvalue().strip())

        trip: Tuple[bool, str, str, str]
        plen: int = len(path)
        role_path: bool = plen > 0 and path[0] == 'r'
        role, src, ctup = self._get_role_src_concept(node)
        src: str
        ctup: Tuple[str, str]
        role: str = None if role is None else self._strip_align(role)
        concept: str = None if ctup is None else self._strip_align(ctup[1])
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'source: {src}, role: {role}, concept: {concept}')
        if role is not None:
            if isinstance(node[1], str):
                # the alignment is on value
                trip = (role_path, parent, role, self._strip_align(node[1]))
            elif role_path:
                # either the alignment is on the role edge
                trip = (True, parent, role, node[1][0])
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'role alignment found: {trip}')
            else:
                # or traverse to the concept node as its only and first child
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'traversing role: {role}')
                trip = self._get_node(node[1], path, None)
        elif plen == 0:
            # land on a concept node with the alignment
            trip = (False, src, ':instance', concept)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'concept node found: {trip}')
        else:
            next_idx: Union[int, str] = path[0]
            child = node[1][1:][next_idx]
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'traversing child: {next_idx}, ' +
                             f'children({len(child)}): {str(child)[:60]}')
            trip = self._get_node(child, path[1:], src)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'return trip: {trip}')
            logger.debug('_' * 40)
        return trip

    def get_node(self, path: Tuple[int]) -> Tuple[Tuple[str, str, str], bool]:
        """Get a triple representing a graph node/edge of the given path.

        :param path: a tuple of 0-based index used to get a node or edge.

        :return: the node/edge triple and ``True`` if it is a role edge

        """
        tree: penman.tree.Tree = penman.configure(self.graph)
        node: Tuple = self._get_node(tree.nodes()[0], path[1:], None)
        return node[1:], node[0]

    def _merge_aligns(self, pas: List[PathAlignment]) -> List[PathAlignment]:
        """Merges nodes in ``pas`` with the same triples into alignments with multiple
        indices.

        """
        by_trip_role: Dict[Tuple, PathAlignment] = collections.defaultdict(
            lambda: collections.defaultdict(list))
        colls: List[PathAlignment] = []
        pa: PathAlignment
        for pa in pas:
            by_trip_role[pa.triple][pa.is_role].append(pa)
        coll: List[PathAlignment]
        groups = chain.from_iterable(
            map(lambda r: r.values(), by_trip_role.values()))
        for coll in groups:
            if len(coll) > 1:
                aixs = tuple(chain.from_iterable(
                    sorted(map(lambda pa: pa.alignment.indices, coll))))
                coll[0].alignment.indices = aixs
            colls.append(coll[0])
        colls.sort()
        return colls

    def _fix_inverses(self, pas: List[PathAlignment]):
        """The FAA aligner is not consistent in which nodes are reversed for
        ``:<name>-of`` roles.  At least ``part`` and ``location`` have this
        issue.  This is only used when testing.

        """
        epis: Dict[Tuple[str, str, str], List] = self.graph.epidata
        pa: PathAlignment
        for pa in pas:
            trip: Tuple[str, str, str] = pa.triple
            role: str = trip[1]
            if role[0] == ':' and role.endswith('-of') and trip not in epis:
                pa.triple = (trip[2], role[0:-3], trip[0])

    def get_alignments(self) -> Tuple[PathAlignment]:
        """Return the alignments for the graph."""
        graph = self.graph
        insts = {i.source: i for i in graph.instances()}
        assert len(graph.instances()) == len(insts)
        tree: penman.tree.Tree = penman.configure(graph)
        if logger.isEnabledFor(logging.DEBUG):
            sio = StringIO()
            sio.write('graph:\n')
            sio.write(penman.encode(graph, model=penman.models.noop.model))
            sio.write('\nepis:\n')
            for i in graph.epidata.items():
                sio.write(f'{i}\n')
            logger.debug(sio.getvalue().strip())
        aligns: List[str] = graph.metadata[self.alignment_key].split()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'alignments: {aligns}')
        path_aligns: List[PathAlignment] = []
        align: str
        for paix, align in enumerate(aligns):
            ixs, path = align.split('-')
            path = tuple(map(lambda s: s if s == 'r' else (int(s) - 1),
                             path.split('.')))
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'search alignment: {align}')
            type_targ_trip = self._get_node(tree.nodes()[0], path[1:], None)
            role_align, targ_trip = type_targ_trip[0], type_targ_trip[1:]
            align_cls = RoleAlignment if role_align else Alignment
            align_inst = align_cls.from_string(f'e.{ixs}')
            pa = PathAlignment(paix, path, align, align_inst, targ_trip)
            path_aligns.append(pa)
        if self._strip_alignments:
            # only necessary when testing
            self._fix_inverses(path_aligns)
        path_aligns = self._merge_aligns(path_aligns)
        return tuple(path_aligns)

    def __call__(self) -> Tuple[PathAlignment]:
        """Add the alignments to the graph using the ISI formatted alignment string.

        :return: the alignments added to the graph

        """
        epis: Dict[Tuple[str, str, str], List] = self.graph.epidata
        pas: List[PathAlignment] = self.get_alignments()
        pa: PathAlignment
        for pa in pas:
            epi = epis.get(pa.triple)
            epi.append(pa.alignment)
        return pas

    def get_missing_alignments(self) -> Tuple[PathAlignment]:
        """Find all path alignments not in the graph.  This is done by matching against
        the epi mapping.  This is only useful for testing.

        """
        def filter_align(epi: Any) -> bool:
            return isinstance(epi, (RoleAlignment, Alignment))

        missing: List[PathAlignment] = []
        epis: Dict[Tuple[str, str, str], List] = self.graph.epidata
        pas: List[PathAlignment]
        try:
            self._strip_alignments = True
            pas = self.get_alignments()
        finally:
            # restore the default so later calls keep the alignment suffixes
            self._strip_alignments = False
        for pa in pas:
            targ_trip: Tuple[str, str, str] = pa.triple
            prev_epis: List = epis.get(targ_trip)
            if prev_epis is None:
                raise ValueError(f'Target not found: {targ_trip}')
            prev_aligns: Tuple[Union[RoleAlignment, Alignment]] = \
                tuple(filter(filter_align, prev_epis))
            if pa.alignment not in prev_aligns:
                missing.append(pa)
        return tuple(missing)

plandes commented 1 year ago

Here's the test case:

from typing import List
import unittest
from penman.graph import Graph
from penman.model import Model
import penman.models.noop
from amrlib.graph_processing.amr_loading import load_amr_entries
from amrlib.alignments.faa_aligner import FAA_Aligner
from . import AlignmentPopulator, PathAlignment

class TestAlignmentPopulation(unittest.TestCase):
    def test_align(self):
        path: str = 'data/corpus/amr-bank-struct-v3.0.txt'
        model: Model = penman.models.noop.model
        graph_strs: List[str] = load_amr_entries(str(path))
        ugraphs: List[Graph] = [penman.decode(gs, model=model) for gs in graph_strs]
        sents: List[str] = [g.metadata['snt'] for g in ugraphs]
        inference = FAA_Aligner()
        agraphs, aligns = inference.align_sents(sents, graph_strs)
        tups = zip(sents, ugraphs, agraphs, aligns)
        for i, (sent, ugraph, agraph_str, align) in enumerate(tups):
            agraph: Graph = penman.decode(agraph_str, model=model)
            agraph.metadata['snt'] = sent
            # fast align leaves a space at the end of the alignment string
            agraph.metadata['alignments'] = align.strip()
            ap = AlignmentPopulator(graph=agraph)
            # get_missing_alignments() returns only the alignments that are
            # missing from the FAA generated graph
            try:
                misses = ap.get_missing_alignments()
            except Exception as e:
                print(penman.encode(agraph, model=model))
                self.fail(f'failed parse ({i}): {e}')
            if len(misses) > 0:
                print(f'found {len(misses)} missing alignments (i={i}):')
                for miss in misses:
                    print(f'missing: {miss}')
                print(penman.encode(agraph, model=model))
                self.fail(f'{len(misses)} missing alignments (i={i})')

bjascob commented 1 year ago

Todo - review code for future integration.

bjascob commented 1 year ago

Keeping the code here for now in case someone has a similar need in the future.