Trusted-AI / adversarial-robustness-toolbox

Adversarial Robustness Toolbox (ART) - Python Library for Machine Learning Security - Evasion, Poisoning, Extraction, Inference - Red and Blue Teams
https://adversarial-robustness-toolbox.readthedocs.io/en/latest/
MIT License

PyTorchEstimator with non-PyTorch preprocessor does not compute gradient correctly #1155

Closed dxoigmn closed 3 years ago

dxoigmn commented 3 years ago

Describe the bug

In PyTorchEstimator, _apply_preprocessing_gradient does not appear to work correctly when a non-PyTorch preprocessor is used with a PyTorch estimator. In particular, when computing the gradient, the estimator calls each preprocessor's estimate_gradient but does not store the intermediate representations needed to feed these calls. Instead, it passes the original input (x in the snippet below) to every preprocessor, which I believe is wrong: https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/24a616befec9d35ba053d16d4a6ac6cfc573039c/art/estimators/pytorch.py#L266-L273

In many cases a preprocessor returns an output with the same shape as its input, so the bug silently goes unnoticed. We found it because our preprocessor changes the shape of the input (e.g., it collapses an RGB image or video into a grayscale image or video), which exposes the bug loudly. Here's the (truncated) error with some custom debugging output:

preprocessing_operations = [<oscar.defences.preprocessor.Preprocessor object at 0x7f3e78a04d90>, StandardisationMeanStdPyTorch(mean=0.0, std=1.0, apply_fit=True, apply_predict=True, device=cuda:0)]
x = (1, 220, 240, 320, 3)
gradients = (1, 220, 240, 320, 1)
Traceback (most recent call last):
  File "~/Projects/GARD/lib/armory/armory/scenarios/base.py", line 84, in evaluate
    config, num_eval_batches, skip_benign, skip_attack, skip_misclassified
  File "~/Projects/GARD/oscar/scenarios/video_ucf101_scenario_science.py", line 236, in _evaluate
    x_adv = attack.generate(x=x_adv, **generate_kwargs) # Begin attack search from existing adversarial example
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/attack.py", line 75, in replacement_function
    return fdict[func_name](self, *args, **kwargs)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/evasion/projected_gradient_descent/projected_gradient_descent.py", line 181, in generate
    return self._attack.generate(x=x, y=y, **kwargs)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/attack.py", line 75, in replacement_function
    return fdict[func_name](self, *args, **kwargs)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_numpy.py", line 327, in generate
    self.num_random_init > 0 and i_max_iter == 0,
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/evasion/fast_gradient.py", line 468, in _compute
    perturbation = self._compute_perturbation(batch, batch_labels, mask_batch)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/attacks/evasion/fast_gradient.py", line 369, in _compute_perturbation
    grad = self.estimator.loss_gradient(batch, batch_labels) * (1 - 2 * int(self.targeted))
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/estimators/classification/pytorch.py", line 743, in loss_gradient
    grads = self._apply_preprocessing_gradient(x, grads)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/estimators/pytorch.py", line 277, in _apply_preprocessing_gradient
    gradients = preprocess.estimate_gradient(x, gradients)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/defences/preprocessor/preprocessor.py", line 227, in estimate_gradient
    x_grad[:, i, ...] = get_gradient(x=x, grad=grad[:, i, ...])
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/art/defences/preprocessor/preprocessor.py", line 207, in get_gradient
    x_prime.backward(grad)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/torch/autograd/__init__.py", line 126, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_)
  File "~/Projects/GARD/.venv/lib/python3.7/site-packages/torch/autograd/__init__.py", line 37, in _make_grads
    + str(out.shape) + ".")
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 240, 320, 1]) and output[0] has a shape of torch.Size([1, 220, 240, 320, 3]).

Our preprocessor oscar.defences.preprocessor.Preprocessor takes input [1, 220, 240, 320, 3] and turns it into output [1, 220, 240, 320, 1].
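To make the shape change concrete, here is a minimal NumPy sketch (not ART or oscar code; to_grayscale and to_grayscale_gradient are hypothetical names) of a preprocessor that collapses RGB to grayscale by averaging the channel axis, together with its backward pass. Small frame sizes stand in for the 220 frames of 240x320 in the report. The gradient arriving from the model has the output's shape (..., 1), while the gradient handed back must have the input's shape (..., 3).

```python
import numpy as np


def to_grayscale(x: np.ndarray) -> np.ndarray:
    # Hypothetical shape-changing preprocessor: average over the channel axis.
    return x.mean(axis=-1, keepdims=True)


def to_grayscale_gradient(x: np.ndarray, grad: np.ndarray) -> np.ndarray:
    # Backward pass of the channel mean: each of the C input channels
    # contributed 1/C to the output, so each receives grad / C, broadcast
    # back from shape (..., 1) to the input shape (..., C).
    n_channels = x.shape[-1]
    return np.broadcast_to(grad / n_channels, x.shape).copy()


x = np.random.rand(1, 4, 8, 8, 3)          # (batch, frames, height, width, RGB)
x_gray = to_grayscale(x)                    # forward output: (1, 4, 8, 8, 1)
upstream = np.ones_like(x_gray)             # gradient arriving from the model
x_grad = to_grayscale_gradient(x, upstream)

print(x_gray.shape, upstream.shape, x_grad.shape)
# (1, 4, 8, 8, 1) (1, 4, 8, 8, 1) (1, 4, 8, 8, 3)
```

Any preprocessor that runs after this one in the chain therefore needs x_gray, not x, when its own gradient is estimated.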

To Reproduce

I don't have a small test case that exhibits this problem, but I imagine one could create a non-PyTorch preprocessor that collapses an image into grayscale, which would expose the issue and fail loudly.


Screenshots N/A


beat-buesser commented 3 years ago

Hi @dxoigmn Thank you very much for using ART and raising this issue!

Your observation and description are correct. We are aware of this situation: with the introduction of PyTorch-specific preprocessing in ART 1.4, we added a restriction a few lines earlier that allows either a single NumPy-based preprocessor, or a single NumPy-based preprocessor followed by the mean/std normalisation preprocessor. Under that assumption, passing the original input x should be sufficient: https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/24a616befec9d35ba053d16d4a6ac6cfc573039c/art/estimators/pytorch.py#L262-L265
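Why the original x can be enough for the mean/std step follows from its derivative: for (x - mean) / std the vector-Jacobian product is grad / std, which does not depend on the values of x at all, only on x having the same shape as grad. The minimal NumPy sketch below (hypothetical helper names, not ART's implementation) checks this with central finite differences at two unrelated inputs.

```python
import numpy as np


def standardise(x, mean, std):
    return (x - mean) / std


def standardise_gradient(grad, std):
    # d/dx [(x - mean) / std] = 1 / std elementwise, so x never appears.
    return grad / std


def numeric_vjp(f, x, grad, eps=1e-6):
    # Central finite differences of sum(f(x) * grad) with respect to each x[idx].
    out = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        out[idx] = (np.sum(f(x_plus) * grad) - np.sum(f(x_minus) * grad)) / (2 * eps)
    return out


mean, std = 0.5, 0.25
rng = np.random.default_rng(0)
grad = rng.random((2, 3))
x_a = rng.random((2, 3))   # two unrelated inputs of the same shape
x_b = rng.random((2, 3))


def f(x):
    return standardise(x, mean, std)


analytic = standardise_gradient(grad, std)
print(np.allclose(numeric_vjp(f, x_a, grad), analytic),
      np.allclose(numeric_vjp(f, x_b, grad), analytic))  # True True
```

A preprocessor whose gradient does depend on its input would not have this property, which is presumably why the restriction only admits normalisation after the NumPy-based preprocessor.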

Because you are reaching line 266, I assume that you are using only one NumPy-based preprocessor, which should be OK. One thing I noticed in your stack trace is: grad_output[0] has a shape of torch.Size([1, 240, 320, 1]) and output[0] has a shape of torch.Size([1, 220, 240, 320, 3]). Is it correct that grad_output does not have a dimension for the frames (i.e., of size 220)?

beat-buesser commented 3 years ago

@dxoigmn Btw, I have added support for multiple layers in the FeatureAdversariesPyTorch and FeatureAdversariesTensorFlowV2 attacks in #1156 in case you want to try it out; they'll be part of ART 1.7.0. Let me know what you think!

dxoigmn commented 3 years ago

@beat-buesser: Yes, our preprocessor uses the NumPy interface (it is actually implemented in PyTorch, which we need in order to manage GPU memory).

If I unroll that loop, the computation that happens is essentially:

x = (1, 220, 240, 320, 3)
gradients = (1, 220, 240, 320, 1)

gradients_1 = StandardisationMeanStdPyTorch.estimate_gradient(x, gradients)
gradients_2 = oscar.defences.preprocessor.Preprocessor.estimate_gradient(x, gradients_1)

Notice that it is impossible to compute gradients_1 using StandardisationMeanStdPyTorch, since StandardisationMeanStdPyTorch does not change the shape, yet its inputs have different shapes! What we really want to compute is:

x_1 = oscar.defences.preprocessor.Preprocessor.estimate_forward(x) # Or just forward

gradients_1 = StandardisationMeanStdPyTorch.estimate_gradient(x_1, gradients)
gradients_2 = oscar.defences.preprocessor.Preprocessor.estimate_gradient(x, gradients_1)

Does that make sense?
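The corrected ordering in the unrolled pseudocode generalises to any chain: run the forward pass once, keep each preprocessor's input, then walk the chain in reverse and give each backward call the input that preprocessor actually saw. Below is a minimal NumPy sketch of that idea, assuming each stage is a hypothetical (forward, gradient) pair of plain functions rather than ART's Preprocessor classes.

```python
import numpy as np


def gray_forward(x):
    return x.mean(axis=-1, keepdims=True)


def gray_gradient(x, grad):
    return np.broadcast_to(grad / x.shape[-1], x.shape).copy()


def standardise_forward(x, mean=0.0, std=1.0):
    return (x - mean) / std


def standardise_gradient(x, grad, mean=0.0, std=1.0):
    # Explicit check standing in for the shape error autograd raised in the report.
    if x.shape != grad.shape:
        raise ValueError(f"shape mismatch: x {x.shape} vs grad {grad.shape}")
    return grad / std


def chain_gradient(stages, x, grad):
    """Back-propagate grad through a list of (forward, gradient) stages."""
    inputs = []
    for forward, _ in stages:        # forward pass: remember what each stage saw
        inputs.append(x)
        x = forward(x)
    for (_, gradient), x_in in zip(reversed(stages), reversed(inputs)):
        grad = gradient(x_in, grad)  # backward pass: reuse the stored input
    return grad


stages = [(gray_forward, gray_gradient), (standardise_forward, standardise_gradient)]
x = np.random.rand(1, 4, 8, 8, 3)
grad_from_model = np.ones((1, 4, 8, 8, 1))
print(chain_gradient(stages, x, grad_from_model).shape)  # (1, 4, 8, 8, 3)
```

Calling standardise_gradient(x, grad_from_model) with the original x, as the first unrolled version does, raises the ValueError instead, which is the NumPy analogue of the RuntimeError in the report.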

Another way to show that this part of the code is most likely incorrect is to check whether the gradients from a PyTorchEstimator + PreprocessorPyTorch are the same when all_framework_preprocessing is True and when it is False (i.e., by manually setting that variable's value in the PyTorch estimator). I'll try to create this minimal test case to verify.
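The same kind of consistency check can be sketched without ART by comparing a per-stage gradient loop against finite differences of the whole pipeline. This is a hypothetical illustration, not ART's test suite: both stages (tanh, then elementwise squaring) are shape-preserving and were chosen only because the second stage's gradient depends on its input, so feeding it the original x yields a wrong gradient even though no shape error is raised.

```python
import numpy as np


# Two hypothetical shape-preserving stages: tanh, then elementwise square.
def pipeline(x):
    return np.tanh(x) ** 2


def tanh_gradient(x, grad):
    return grad * (1.0 - np.tanh(x) ** 2)


def square_gradient(x, grad):
    return grad * 2.0 * x  # depends on the value this stage actually received


def loop_gradient(x, grad, store_intermediates):
    # Correct: stage 2 gets tanh(x). Buggy: stage 2 gets the original x.
    x_stage2 = np.tanh(x) if store_intermediates else x
    return tanh_gradient(x, square_gradient(x_stage2, grad))


def finite_difference(f, x, grad, eps=1e-6):
    # Reference gradient of sum(f(x) * grad), computed through the whole pipeline.
    out = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        out[idx] = (np.sum(f(x_plus) * grad) - np.sum(f(x_minus) * grad)) / (2 * eps)
    return out


x = np.array([[0.5, -1.0, 2.0]])
grad = np.ones_like(x)
reference = finite_difference(pipeline, x, grad)

print(np.allclose(loop_gradient(x, grad, True), reference, atol=1e-6))   # True
print(np.allclose(loop_gradient(x, grad, False), reference, atol=1e-6))  # False
```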

Also, thank you for the feature adversaries lead. I will take a look at it.

dxoigmn commented 3 years ago

Okay, I think I have a minimal example that breaks:

import torch

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
from art.estimators.classification.pytorch import PyTorchClassifier

class ReducingPreprocessor(PreprocessorPyTorch):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._device = 'cpu'

    def forward(self, x, y=None):
        x = torch.mean(x, dim=-1, keepdim=True)

        return x, y

class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.linear = torch.nn.Linear(32*32, 10)

    def forward(self, x, y=None):
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)

        return x

preprocessor = ReducingPreprocessor()
model = LinearModel()

classifier = PyTorchClassifier(model=model,
                               input_shape=(None, 32, 32, 3),
                               nb_classes=10,
                               loss=torch.nn.CrossEntropyLoss(),
                               preprocessing_defences=[preprocessor])

x = torch.rand(1, 32, 32, 3)
y = torch.nn.functional.one_hot(torch.tensor([1]), 10)

classifier.loss_gradient(x, y) # works!

classifier.all_framework_preprocessing = False
classifier.loss_gradient(x.numpy(), y.numpy()) # doesn't work!

beat-buesser commented 3 years ago

Hi @dxoigmn, thank you very much for the example; it is very helpful! This is an unexpected use of these tools, but I think I now understand what you are doing. I have a few questions for my understanding:

import torch
import numpy as np

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
from art.estimators.classification.pytorch import PyTorchClassifier

class ReducingPreprocessor(PreprocessorPyTorch):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._device = 'cpu'

    def forward(self, x, y=None):
        x = torch.mean(x, dim=-1, keepdim=True)

        return x, y

    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
        import torch  # lgtm [py/repeated-import]

        x = torch.tensor(x, device=self.device, requires_grad=True)
        grad = torch.tensor(grad, device=self.device)

        x_prime = self.estimate_forward(x)
        x_prime.backward(grad)
        x_grad = x.grad.detach().cpu().numpy()

        return x_grad

class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.linear = torch.nn.Linear(32*32, 10)

    def forward(self, x, y=None):
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)

        return x

preprocessor = ReducingPreprocessor()
model = LinearModel()

classifier = PyTorchClassifier(model=model,
                               input_shape=(32, 32, 3),
                               nb_classes=10,
                               loss=torch.nn.CrossEntropyLoss(),
                               preprocessing=None,
                               preprocessing_defences=[preprocessor])

x = torch.rand(1, 32, 32, 3)
y = torch.nn.functional.one_hot(torch.tensor([1]), 10)

classifier.loss_gradient(x, y) # works!

classifier.all_framework_preprocessing = False
classifier.loss_gradient(x.numpy(), y.numpy()) # now it works too, but experimental!

dxoigmn commented 3 years ago


  • Are you setting classifier.all_framework_preprocessing = False for the purpose of the example or also in your application? This attribute was not intended to be modified by the user and leaving it set to True in this case would solve the issue.

I just set that flag for purposes of the example. I could have triggered that condition by changing the preprocessor to use the numpy interface, but I'm lazy :)

  • Do you need normalisation with StandardisationMeanStd* before the sample gets to the model or could it be integrated in the preprocessing step? You can avoid StandardisationMeanStd* being added automatically if you set preprocessing=None in PyTorchClassifier, that way only ReducingPreprocessor will be applied.

No normalization is needed. And yes, to solve this issue I set preprocessing=None and implemented estimate_gradient as you show.

In theory, though, I don't see why it shouldn't work with preprocessing=(0., 1.), since that normalization is essentially a no-op. But the code, as written, makes an additional assumption: that preprocessors always return the same shape as their input. It's a fine assumption, but I don't think it is enforced or documented anywhere, and it probably should be.

Maybe my title is a misnomer in that it does compute the gradient correctly but will not work with certain types of preprocessors?
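One way the shape-preservation assumption could be enforced is to run each preprocessor once on a dummy batch and reject any whose output shape differs from its input shape. The helper below, check_shape_preserving, is hypothetical (it is not part of ART) and works on plain callables returning (x, y), mirroring the forward(x, y) signature used in this thread's examples.

```python
from typing import Callable, Optional, Sequence, Tuple

import numpy as np

# Signature mirroring forward(x, y) in this thread's examples: returns (x, y).
PreprocessFn = Callable[[np.ndarray, Optional[np.ndarray]],
                        Tuple[np.ndarray, Optional[np.ndarray]]]


def check_shape_preserving(preprocessors: Sequence[PreprocessFn],
                           sample: np.ndarray) -> None:
    """Raise ValueError if any preprocessor in the chain changes the shape of x."""
    x = sample
    for i, preprocess in enumerate(preprocessors):
        x_out, _ = preprocess(x, None)
        if x_out.shape != x.shape:
            raise ValueError(
                f"preprocessor {i} maps shape {x.shape} to {x_out.shape}; "
                "gradient estimation assumes shape-preserving preprocessors"
            )
        x = x_out


def identity_standardise(x, y=None):
    return (x - 0.0) / 1.0, y  # preprocessing=(0., 1.) is a no-op


def reduce_channels(x, y=None):
    return x.mean(axis=-1, keepdims=True), y  # like ReducingPreprocessor


sample = np.zeros((1, 32, 32, 3), dtype=np.float32)
check_shape_preserving([identity_standardise], sample)  # passes silently
try:
    check_shape_preserving([reduce_channels, identity_standardise], sample)
except ValueError as err:
    print(err)
```

A check like this only documents the current limitation; it would reject legitimate shape-changing preprocessors rather than support them.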

beat-buesser commented 3 years ago

@dxoigmn I agree with you that the implicit assumptions are not documented as well as they could be; we'll try to improve that. I think it calculates the gradients correctly; it just assumes that if the shapes differ, the gradients must be class gradients with an additional dimension to loop over. We'll have to think of a solution that distinguishes a case like yours from class gradients.
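The effect of this heuristic can be reproduced with shapes alone. Judging from the traceback line x_grad[:, i, ...] = get_gradient(x=x, grad=grad[:, i, ...]), when grad and x differ in shape the estimator treats axis 1 of grad as a class axis and loops over it. The NumPy sketch below models only that slicing, under this reading of the traceback; looks_like_class_gradients is a hypothetical name, not an ART function.

```python
import numpy as np


def looks_like_class_gradients(x_shape, grad_shape):
    # Assumed reading of the heuristic: differing shapes => extra class axis at index 1.
    return tuple(x_shape) != tuple(grad_shape)


# Zero-stride views, so the full-size shapes from the report cost no memory.
x = np.broadcast_to(np.float32(0.0), (1, 220, 240, 320, 3))
grad = np.broadcast_to(np.float32(0.0), (1, 220, 240, 320, 1))

if looks_like_class_gradients(x.shape, grad.shape):
    n_classes = grad.shape[1]      # the 220 frames are mistaken for 220 classes
    per_class = grad[:, 0, ...]    # what each backward call would receive
    print(n_classes, per_class.shape)  # 220 (1, 240, 320, 1)
```

The sliced shape matches the grad_output[0] of torch.Size([1, 240, 320, 1]) in the reported RuntimeError, which explains why the frame dimension was missing.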