QData / TextAttack

TextAttack 🐙 is a Python framework for adversarial attacks, data augmentation, and model training in NLP: https://textattack.readthedocs.io/en/master/
MIT License

Compatibility issue w/ Transformers model #722

Open · svenhendrikx opened this issue 1 year ago

svenhendrikx commented 1 year ago

When running the code example shown in the `Attack` class's docstring, the following warning is logged:

textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.

To reproduce, run the following code:

import textattack
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding 
from textattack.search_methods import GreedyWordSwapWIR
from textattack import Attack
import transformers

# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# Construct our four components for `Attack`
goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9)
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="delete")

# Construct the actual attack
attack = Attack(goal_function, constraints, transformation, search_method)

input_text = "I really enjoyed the new movie that came out last month."
label = 1 #Positive
attack_result = attack.attack(input_text, label)

This code is taken from the documentation of the `Attack` class; running it produces the warning shown above.

Expected behavior: the code should run the attack and store the result in the `attack_result` variable, without emitting the compatibility warning.


Additional context: In shared/validators.py, the following code defines the regular expressions used to validate compatibility between models and goal functions:

MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack.models.helpers.lstm_for_classification.*",
        r"^textattack.models.helpers.word_cnn_for_classification.*",
        r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (NonOverlappingOutput, MinimizeBleu,): [
        r"^textattack.models.helpers.t5_for_text_to_text.*",
    ],
}

However, the following regular expression no longer matches current versions of Transformers: r"^transformers.modeling_\w*\.\w*ForSequenceClassification$". It was written for an older release, in which model classes lived in modules such as transformers.modeling_bert.

In recent versions of Transformers, model modules follow the pattern transformers.models.<name>.modeling_<name>, so a matching expression would be: r"transformers.models.\w*.modeling_\w*.\w*ForSequenceClassification"
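The mismatch is easy to verify against the class path printed in the warning (a standalone check with `re`, not part of TextAttack itself):

```python
import re

# Fully qualified class path reported in the warning message
model_path = "transformers.models.bert.modeling_bert.BertForSequenceClassification"

old_pattern = r"^transformers.modeling_\w*\.\w*ForSequenceClassification$"
new_pattern = r"transformers.models.\w*.modeling_\w*.\w*ForSequenceClassification"

# The old pattern expects the pre-4.x layout (transformers.modeling_bert.*),
# so it fails to match the current path and TextAttack logs the warning.
print(re.match(old_pattern, model_path))  # None
print(re.match(new_pattern, model_path))  # <re.Match ...>
```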

Suggested fix:

Add the new pattern to the list of regular expressions, keeping the old one for backward compatibility:

MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack.models.helpers.lstm_for_classification.*",
        r"^textattack.models.helpers.word_cnn_for_classification.*",
        r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
        r"transformers.models.\w*.modeling_\w*.\w*ForSequenceClassification"
    ],
    (NonOverlappingOutput, MinimizeBleu,): [
        r"^textattack.models.helpers.t5_for_text_to_text.*",
    ],
}
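As a side note (not part of the proposed patch): a single anchored pattern could cover both the old and new module layouts, since making the `models.<name>.` segment optional lets pre-4.x paths match as well. A sketch:

```python
import re

# Hypothetical alternative: one anchored pattern for both layouts.
# The "(models\.\w+\.)?" group is optional, so old-style paths still match.
combined = r"^transformers\.(models\.\w+\.)?modeling_\w+\.\w+ForSequenceClassification$"

old_style = "transformers.modeling_bert.BertForSequenceClassification"
new_style = "transformers.models.bert.modeling_bert.BertForSequenceClassification"

print(bool(re.match(combined, old_style)))  # True
print(bool(re.match(combined, new_style)))  # True
```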

This fixed the issue for me, let me know if I missed anything!

I have a branch with this fix already, so let me know if I can help/make a PR.

jxmorris12 commented 1 year ago

@svenhendrikx thanks for opening an issue! The issue is very clear; this is indeed a bug, and our warning message (and these regular expressions) were written for a prior version of transformers. Feel free to open a PR; it would be great to fix this, since it's an annoying warning message!