QData / TextAttack

TextAttack 🐙 is a Python framework for adversarial attacks, data augmentation, and model training in NLP https://textattack.readthedocs.io/en/master/
MIT License

Does this library support targeted classification in its current state? #732

Open · MackBlackburn opened this issue 1 year ago

MackBlackburn commented 1 year ago

Do any of the attacks support targeted classification as a goal function? I don't see it stated anywhere, although untargeted classification is mentioned. Manually switching a recipe's goal function results in an error:

```
AttributeError: 'UntargetedClassification' object has no attribute 'num_queries'
```
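For reference, the two kinds of goal differ mainly in what counts as success. The sketch below is a simplified, hypothetical illustration of that idea on a plain list of class probabilities. The function names are invented for this example and are not TextAttack's API, and TextAttack's own goal-function classes may support options (such as score thresholds or query budgets) that are not modelled here.

```python
from typing import Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def untargeted_success(scores: Sequence[float], ground_truth: int) -> bool:
    # Untargeted: any predicted label other than the original one is a success.
    return argmax(scores) != ground_truth


def targeted_success(scores: Sequence[float], target_class: int) -> bool:
    # Targeted: the prediction must land on one specific, chosen label.
    return argmax(scores) == target_class


def untargeted_score(scores: Sequence[float], ground_truth: int) -> float:
    # The attacker wants this high: push probability away from the original label.
    return 1.0 - scores[ground_truth]


def targeted_score(scores: Sequence[float], target_class: int) -> float:
    # The attacker wants this high: pull probability toward the target label.
    return scores[target_class]


# Hypothetical probabilities for a perturbed input whose original label was 0.
probs = [0.2, 0.7, 0.1]
print(untargeted_success(probs, ground_truth=0))  # True: prediction moved to class 1
print(targeted_success(probs, target_class=2))    # False: class 1 is not the target
```

The same perturbation can therefore succeed as an untargeted attack while still failing as a targeted one, which is why a targeted goal such as `target_class=2` is a strictly harder search.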

Here is some basic code I am using that causes the error:


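```python
import numpy as np
from transformers import pipeline
from textattack import Attacker
from textattack.attack_recipes import TextBuggerLi2018
from textattack.datasets import Dataset
from textattack.goal_functions import TargetedClassification
from textattack.models.wrappers import ModelWrapper


class HuggingFaceSentimentAnalysisPipelineWrapper(ModelWrapper):
    """Adapts a transformers text-classification pipeline to TextAttack's ModelWrapper."""

    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs):
        # With return_all_scores=True the pipeline returns, for each input,
        # a list of {"label": ..., "score": ...} dicts covering every class.
        raw_outputs = self.model(text_inputs)
        label2id = self.model.model.config.label2id
        outputs = []
        for sub_l in raw_outputs:
            # Put each score at its label's index so columns line up with class ids.
            out_scores = [0] * len(label2id)
            for d in sub_l:
                out_scores[label2id[d["label"]]] = d["score"]
            outputs.append(out_scores)
        return np.array(outputs)


pipe = pipeline("text-classification", 'j-hartmann/emotion-english-distilroberta-base', device=0, return_all_scores=True)

model_wrapper = HuggingFaceSentimentAnalysisPipelineWrapper(pipe)
recipe = TextBuggerLi2018.build(model_wrapper)
recipe.transformation.language = "eng"
recipe.constraints = []
# Swap in a targeted goal function after the recipe has already been built.
recipe.goal_function = TargetedClassification(model_wrapper, target_class=2)
# Setting num_queries on the new goal function does not avoid the error, which names UntargetedClassification.
recipe.goal_function.num_queries = 10

dataset = Dataset([('Im angry', 0)])
attacker = Attacker(recipe, dataset)
results = attacker.attack_dataset()  # raises the AttributeError shown above
```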
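The traceback names `UntargetedClassification`, not the `TargetedClassification` that was assigned, which suggests the original goal function is still being called somewhere. One plausible cause is a stale reference: if TextAttack's `Attack` hands the goal function's bound methods to the search method when the attack is constructed (an assumption about its internals, not something confirmed in this issue), then reassigning `recipe.goal_function` afterwards only initialises the new object, while the search keeps calling the old one. The toy classes below are hypothetical stand-ins that reproduce that pattern in plain Python; none of them is TextAttack code.

```python
class ToyGoalFunction:
    """Hypothetical goal function whose per-example state is set up lazily."""

    def init_attack_example(self) -> None:
        # Per-example state is created here, not in __init__.
        self.num_queries = 0

    def get_results(self):
        # Fails if init_attack_example() was never called on *this* object.
        self.num_queries += 1
        return type(self).__name__, self.num_queries


class UntargetedClassification(ToyGoalFunction):
    """Toy stand-in named after the recipe's original goal function."""


class TargetedClassification(ToyGoalFunction):
    """Toy stand-in named after the goal function swapped in by the user."""


class ToySearchMethod:
    """Hypothetical search method that only knows the callback it was handed."""

    get_goal_results = None


class ToyAttack:
    """Hypothetical attack that wires its components together once, in __init__."""

    def __init__(self, goal_function, search_method):
        self.goal_function = goal_function
        self.search_method = search_method
        # A bound method is copied here; reassigning self.goal_function later
        # does not update this reference.
        self.search_method.get_goal_results = self.goal_function.get_results

    def attack(self):
        # Only the goal function currently stored on the attack is initialised...
        self.goal_function.init_attack_example()
        # ...but the search still calls the one captured at construction time.
        return self.search_method.get_goal_results()


# Reassigning after construction reproduces the same kind of error.
swapped = ToyAttack(UntargetedClassification(), ToySearchMethod())
swapped.goal_function = TargetedClassification()
try:
    swapped.attack()
except AttributeError as err:
    print(err)  # 'UntargetedClassification' object has no attribute 'num_queries'

# Passing the targeted goal function at construction time works.
built_targeted = ToyAttack(TargetedClassification(), ToySearchMethod())
print(built_targeted.attack())  # ('TargetedClassification', 1)
```

If that assumption about TextAttack holds, an untested workaround would be to construct a fresh attack with the targeted goal function from the start, for example `textattack.Attack(TargetedClassification(model_wrapper, target_class=2), [], recipe.transformation, recipe.search_method)`, rather than reassigning `recipe.goal_function` on an already-built recipe.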