bjascob / amr_coref

A Python library / model for creating co-references between AMR graph nodes.
MIT License

Multiprocessing Error #4

Open BWAAEEEK opened 1 year ago

BWAAEEEK commented 1 year ago

Hello,

I am a student studying GNNs, and I have recently been working on a project that uses AMR graphs.

While doing so, I came across an issue when running the code uploaded to this repository.

I tried fixing it myself, but I'd like some feedback to make sure the changes I made are correct.

Eventually, I found that the error occurs because module-level global variables are not shared with the worker processes when using multiprocessing.

So I used functools.partial to pass the arguments to the worker function directly, which resolved the issue.

I'm not sure whether the uploaded code was originally correct.
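
For illustration, here is a minimal standalone sketch of the behavior I ran into (the names here are made up and unrelated to the library). Under the spawn start method, the default on Windows and macOS, each worker process re-imports the module, so a global assigned in the parent is not visible in the children, while an argument bound with functools.partial is pickled and shipped to the workers with each task.

from multiprocessing import Pool
from functools import partial

gvalue = None    # set by the parent before the pool is created

def worker_via_global(x):
    # Under the spawn start method the child re-imports this module, so
    # gvalue is still None here and this line raises a TypeError.
    return gvalue * x

def worker_via_arg(x, value):
    # value is pickled and sent along with the task, so this always works
    return value * x

if __name__ == '__main__':
    gvalue = 10
    with Pool(2) as pool:
        print(pool.map(partial(worker_via_arg, value=gvalue), [1, 2, 3]))    # [10, 20, 30]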

Below is the code that I modified.

Could you please confirm if the code below is correct? Thank you!

import re
import logging
from   multiprocessing import Pool
from   functools import partial
from   tqdm import tqdm

# Note: CorefFeaturizer is defined in the same module as this function in the
# original library source, so no import is needed for it here.
gfeaturizer, gmax_dist = None, None    # module-level globals, kept for compatibility;
                                       # the workers now receive these via functools.partial
def build_coref_features(mdata, model, **kwargs):
    chunksize = kwargs.get('feat_chunksize',          200)
    maxtpc    = kwargs.get('feat_maxtasksperchild',   200)
    processes = kwargs.get('feat_processes',         None)    # None = use os.cpu_count()
    show_prog = kwargs.get('show_prog',              True)
    global gfeaturizer, gmax_dist

    gfeaturizer = CorefFeaturizer(mdata, model)
    gmax_dist   = model.config.max_dist if model.config.max_dist is not None else 999999999
    # Build the list of doc_names and mention indexes for multiprocessing and the output container
    idx_keys = [(dn, idx) for dn, mlist in gfeaturizer.mdata.mentions.items() for idx in range(len(mlist))]
    feat_data = {}
    for dn, mlist in gfeaturizer.mdata.mentions.items():
        feat_data[dn] = [None]*len(mlist)
    # Loop through and get the pair features for all antecedents
    pbar = tqdm(total=len(idx_keys), ncols=100, disable=not show_prog)
    with Pool(processes=processes, maxtasksperchild=maxtpc) as pool:
        worker_with_args = partial(worker, gfeaturizer=gfeaturizer, gmax_dist=gmax_dist)
        for fdata in pool.imap_unordered(worker_with_args, idx_keys, chunksize=chunksize):
            dn, midx, sspans, dspans, words, sfeats, pfeats, slabels, plabels = fdata
            feat_data[dn][midx] = {'sspans':sspans,   'dspans':dspans, 'words':words,
                                   'sfeats':sfeats,   'pfeats':pfeats,
                                   'slabels':slabels, 'plabels':plabels}
            pbar.update(1)
    pbar.close()
    # Error check
    for dn, feat_list in feat_data.items():
        assert None not in feat_list
    return feat_data

def worker(idx_key, gfeaturizer, gmax_dist):
    # gfeaturizer and gmax_dist are bound with functools.partial in the parent
    # and pickled over to the worker processes, so no global state is needed here
    doc_name, midx = idx_key
    mlist       = gfeaturizer.mdata.mentions[doc_name]
    mention     = mlist[midx]               # the head mention
    antecedents = mlist[:midx]              # all antecedents up to (not including) the head mention
    antecedents = antecedents[-gmax_dist:]  # truncate earlier values so the list is at most max_dist long
    # Process the single and pair data
    sspan_vector = gfeaturizer.get_sentence_span_vector(mention)
    dspan_vector = gfeaturizer.get_document_span_vector(mention)
    word_indexes = gfeaturizer.get_word_indexes(mention)
    sfeats       = gfeaturizer.get_single_features(mention)
    pfeats       = gfeaturizer.get_pair_features(mention, antecedents)
    # Build target labels.  Note that if there are no clusters in the mention data this will still
    # return a list of targets, though all singles will be 1 and pairs 0
    slabels, plabels = gfeaturizer.build_targets(mention, antecedents)
    return doc_name, midx, sspan_vector, dspan_vector, word_indexes, sfeats, pfeats, slabels, plabels
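
An alternative I considered (only a sketch, not tested against the library) is to keep worker() reading the module-level globals but set them once per child process with Pool's initializer argument. That way the featurizer is pickled once per worker process rather than being shipped with every task chunk:

def _init_globals(featurizer, max_dist):
    # Runs once in every child process the pool creates and populates the
    # globals that worker() reads, so worker() needs no extra parameters.
    global gfeaturizer, gmax_dist
    gfeaturizer = featurizer
    gmax_dist   = max_dist

# Inside build_coref_features the pool would then be created as:
#   with Pool(processes=processes, maxtasksperchild=maxtpc,
#             initializer=_init_globals,
#             initargs=(gfeaturizer, gmax_dist)) as pool:
#       for fdata in pool.imap_unordered(worker, idx_keys, chunksize=chunksize):
#           ...
# with worker() restored to its original one-argument signature,
# def worker(idx_key), reading gfeaturizer and gmax_dist as globals.

Since maxtasksperchild is set, the initializer reruns whenever the pool replaces a child, so the globals stay populated for the life of the pool.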
bjascob commented 1 year ago

Please read the note about multiprocessing on the main project page under "Project Status" and review the issue referenced there.

BWAAEEEK commented 1 year ago

Even so, does this approach still use multiprocessing?

plandes commented 11 months ago

I've forked this repo and made changes in my fork to fix this issue.

To install:

pip install zensols.amr_coref