jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License
693 stars 39 forks source link

sigmoid support #63

Closed WandernForte closed 1 year ago

WandernForte commented 1 year ago

i'm trying to modify foolbox like following,

from typing import TypeVar, Any
from abc import ABC, abstractmethod

import torch
import eagerpy as ep
from foolbox.criteria import Criterion

class TargetedMisclassificationML(Criterion):
    """Considers those perturbed inputs adversarial whose predicted class
    matches the target classes. Multi-Label

    Args:
        target_classes: Tensor with target classes ``(batch,)``.
    """

    def __init__(self, target_classes: Any):
        super().__init__()
        self.target_classes: ep.Tensor = ep.astensor(target_classes)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.target_classes!r})"

    def __call__(self, perturbed: T, outputs: T) -> T:
        outputs_, restore_type = ep.astensor_(outputs)
        del perturbed, outputs
        print(type(outputs_), type(outputs))
        # classes = outputs_.argmax(axis=-1)
        classes = torch.tensor(torch.round(torch.sigmoid(outputs_)).detach().numpy().tolist())  # ERR
        assert classes.shape == self.target_classes.shape
        is_adv = classes == self.target_classes
        return restore_type(is_adv)

will u support sigmoid function in the future? and how to implement sigmoid with eagerpy now?

WandernForte commented 1 year ago

i've solve it by self.

class TargetedMisclassificationML(Criterion):
    """Considers those perturbed inputs adversarial whose predicted class
    matches the target classes. Multi-Label

    Args:
        target_classes: Tensor with target classes ``(batch,)``.
    """

    def __init__(self, target_classes: Any):
        super().__init__()
        self.target_classes: ep.Tensor = ep.astensor(target_classes)
        _, self.restore_type = ep.astensor_(target_classes)
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.target_classes!r})"

    def __call__(self, perturbed: T, outputs: T) -> T:
        outputs_, restore_type = ep.astensor_(outputs)
        del perturbed, outputs
        classes = torch.tensor(torch.round(torch.sigmoid(self.restore_type(outputs_).to(device))).detach().numpy().tolist()).to(device)
        classes = ep.astensor(classes)
        assert classes.shape == self.target_classes.shape
        is_adv = classes == self.target_classes
        return restore_type(is_adv)