QData / TextAttack

TextAttack 🐙 is a Python framework for adversarial attacks, data augmentation, and model training in NLP https://textattack.readthedocs.io/en/master/
MIT License

`flair`-related issues with Word Swap by Inflections #713

Open lindsaydbrin opened 1 year ago

lindsaydbrin commented 1 year ago

Describe the bug

I ran into two flair-related issues while using the Word Swap by Inflections transformation. The first one required a flair update, and the second required a small textattack code change.

I'd like to suggest requiring a newer version of flair, as well as a small code change. Alternatively, if a specific older version of flair is preferred and works with the textattack code as is, specifying that version would also solve the problem.

First issue

On line 220 of textattack/shared/utils/strings.py, textattack calls the predict() method of flair's SequenceTagger and passes the argument force_token_predictions=True. I happened to be using an older version of flair (0.6), in which that predict() method does not have the parameter force_token_predictions, so I got:

predict() got an unexpected keyword argument 'force_token_predictions'

Updating to the current version (0.11) addressed the issue.
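
If supporting both old and new flair versions were preferred over pinning one, a guard along these lines could pass the argument only when the installed predict() accepts it. This is a minimal sketch, not textattack's or flair's code: accepts_kwarg, call_predict, predict_old, and predict_new are hypothetical names, and the two predict_* functions are stand-ins for the old and new signatures so the example runs without flair installed.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params:
        # Positional-only parameters cannot be passed by keyword.
        return params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    # A **kwargs catch-all accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for the older and newer predict() signatures.
def predict_old(sentences, mini_batch_size=32):
    return sentences


def predict_new(sentences, mini_batch_size=32, force_token_predictions=False):
    return sentences


def call_predict(predict, sentences):
    """Call `predict`, adding force_token_predictions only if it is supported."""
    kwargs = {}
    if accepts_kwarg(predict, "force_token_predictions"):
        kwargs["force_token_predictions"] = True
    return predict(sentences, **kwargs)
```

With a real tagger, the bound method (for example tagger.predict) would be passed in place of the stand-ins. Requiring a minimum flair version in the package requirements remains the simpler fix.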

Second issue

Once the above was solved, the inflections word swap did not return any transformations, only the original input. Tracing the problem, I noticed that WordSwapInflections._get_replacement_words() gets the part of speech (POS) from flair and checks it against the keys of WordSwapInflections._enptb_to_universal, which converts Penn Treebank POS tags to universal POS tags. However, flair (v. 0.11) seems to return universal POS tags already, so none of the words' POS tags were ever found among those to transform, and no transformations were made. (See PyCharm screenshot below.)

As a quick solution, I (locally) changed line 57 to reference the dictionary's values (instead of keys), and line 68 to lemminflect_pos = word_part_of_speech directly. A better solution would probably be to replace self._enptb_to_universal with an array with the relevant universal POS tags, and reference that directly.

This screenshot is from mid-debugging a test that calls the inflection perturbation. You can see that word_part_of_speech is "NOUN", which is in the values, not keys, of self._enptb_to_universal. (And, in fact, changing the code to reference the dictionary values addressed the issues and made the test pass.)

[Screenshot, 2023-01-27: PyCharm debugger showing word_part_of_speech = "NOUN"]
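
The lookup change described above could be sketched as a small normalizer that accepts either tag set, so it works whether the tagger returns Penn Treebank or universal tags. This is an illustration under assumptions, not the textattack implementation: PTB_TO_UNIVERSAL is only a subset of the real dictionary, and INFLECTABLE_UPOS and to_universal_pos are hypothetical names.

```python
# Subset of the Penn Treebank -> universal mapping, for illustration only.
PTB_TO_UNIVERSAL = {
    "NN": "NOUN",
    "NNS": "NOUN",
    "VB": "VERB",
    "VBD": "VERB",
    "JJ": "ADJ",
    "JJR": "ADJ",
}

# Hypothetical: the universal tags that are eligible for inflection.
INFLECTABLE_UPOS = frozenset(PTB_TO_UNIVERSAL.values())


def to_universal_pos(tag):
    """Return the universal tag for `tag`, or None if the word should be skipped."""
    if tag in INFLECTABLE_UPOS:
        # The tagger already returned a universal tag such as "NOUN".
        return tag
    # Otherwise treat it as a Penn Treebank tag; unknown tags map to None.
    return PTB_TO_UNIVERSAL.get(tag)
```

A caller would skip any word for which to_universal_pos() returns None and pass the returned tag on as the lemminflect POS otherwise.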

To Reproduce

You can reproduce the first issue by running code or by checking the code directly.

So either:

  1. Install an older version of flair: 0.6
  2. Run the following code:
    from textattack.augmentation import Augmenter
    from textattack.transformations import WordSwapInflections
    augmenter = Augmenter(
        transformation=WordSwapInflections(),
        transformations_per_example=10,
    )
    augmenter.augment("What I cannot create, I do not understand.")
  3. See error.

~ OR ~

  1. Download and look at the source code for flair version 0.6.2, specifically flair/models/sequence_tagger_model.py lines 299-308, which do not include the parameter force_token_predictions.
  2. Look at the source code for flair version 0.11, specifically lines 427-437, which include the parameter force_token_predictions.
  3. Look at textattack's code that calls this method, and observe the problem...

For the second issue, presuming flair is up to date:

  1. Run the same code as above:
    from textattack.augmentation import Augmenter
    from textattack.transformations import WordSwapInflections
    augmenter = Augmenter(
        transformation=WordSwapInflections(),
        transformations_per_example=10,
    )
    augmenter.augment("What I cannot create, I do not understand.")
  2. Note that the transformation returns the original utterance and no transformations.

Expected behavior

With the above code, I expect the result to be several transformations (up to the number requested) with inflectional perturbations.

Note that I was able to get this once I updated flair and changed the two lines of code referenced in the Describe the Bug section above.

System Information (please complete the following information):

Additional context

Let me know if you want me to code this out and do a PR vs. whether someone else can do the fix easily (i.e. has the environment already set up)!

Also, this is what I was using as a test:

from textattack.augmentation import Augmenter
from textattack.transformations import WordSwapInflections


def test_textattack_inflections():
    augmenter_inflections = Augmenter(
        transformation=WordSwapInflections(),
        transformations_per_example=3,
    )
    sentence = "This is a sentence with the word schedule in it. I saw it written down."
    result = augmenter_inflections.augment(sentence)
    assert len(result) == 3, "Did not return correct number of transformed sentences."
    # A chained != only compares neighbors; a set checks all four strings pairwise.
    assert (
        len({sentence, *result}) == 4
    ), "Transformations should all be different from each other and from the original sentence."

jxmorris12 commented 1 year ago

Hi @lindsaydbrin - thank you so much for the detailed error message, this is super explanatory and helpful. The flair versioning has been a constant headache for us, but what you've described seems super straightforward to fix. Hopefully someone can help us.

lindsaydbrin commented 1 year ago

Hi @jxmorris12 - No problem, happy to help! What do you mean that hopefully someone can help us? I think with the textattack code changes I described, it should fix the problem(s).

I could do it, I just don't have the environment set up, so I thought it might be easy for someone (you or otherwise) to make the change quickly. But if that's not easy, I don't mind. (Or am I missing something?)

jxmorris12 commented 1 year ago

oh yes, I meant if someone puts up a pull request! I will have time eventually but likely can't make the changes this week.

code-chendl commented 1 year ago

As for the second issue raised in https://github.com/QData/TextAttack/issues/713#issue-1574652467, I replaced self._enptb_to_universal in the textattack/transformations/word_swaps/word_swap_inflections.py file with an extended dictionary that adds entries for the newer universal POS tags, as in the following code.

        self._enptb_to_universal = {
            # ----- dictionary with new tags -----
            "PUNCT": ".",
            "CCONJ": "CONJ",
            "SCONJ": "CONJ",
            "PROPN": "NOUN",
            "PART": "PRT",
            "AUX": "VERB",
            "SYM": "NOUN",
            "INTJ": "X",
            # ----- original dictionary below -----
            "JJRJR": "ADJ",
            "VBN": "VERB",
            "VBP": "VERB",
            "JJ": "ADJ",
            "VBZ": "VERB",
            "VBG": "VERB",
            "NN": "NOUN",
            "VBD": "VERB",
            "NP": "NOUN",
            "NNP": "NOUN",
            "VB": "VERB",
            "NNS": "NOUN",
            "VP": "VERB",
            "TO": "VERB",
            "MD": "VERB",
            "NNPS": "NOUN",
            "JJS": "ADJ",
            "JJR": "ADJ",
            "RB": "ADJ",
        }

(Mapping info: https://github.com/slavpetrov/universal-pos-tags and https://zhuanlan.zhihu.com/p/427520069.)