Closed: @dxoigmn closed this issue 3 years ago.
Hi @dxoigmn Thank you very much for using ART and raising this issue!

Your observation and description are correct. We are aware of the situation and have, with the introduction of PyTorch-specific preprocessing in ART 1.4, added a restriction a few lines earlier that allows only a single Numpy-based preprocessor, or a single Numpy-based preprocessor followed by the Mean/Std-Normalisation preprocessor, for which passing the original input `x` should be sufficient:

https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/24a616befec9d35ba053d16d4a6ac6cfc573039c/art/estimators/pytorch.py#L262-L265
Because you are reaching line 266, I assume that you are using only one Numpy-based preprocessor, which should be ok. Something I have noticed in your stack trace is that `grad_output[0]` has a shape of `torch.Size([1, 240, 320, 1])` while `output[0]` has a shape of `torch.Size([1, 220, 240, 320, 3])`. Is it correct that `grad_output` does not have a dimension for the frames (e.g. of size 220)?
@dxoigmn Btw, I have added support for multiple layers in the `FeatureAdversariesPyTorch` and `FeatureAdversariesTensorFlowV2` attacks in #1156 in case you want to try it out; they'll be part of ART 1.7.0. Let me know what you think!
@beat-buesser: Yes, our preprocessor uses the numpy interface (it is actually a PyTorch implementation, which we need in order to manage GPU memory).

If I unroll that loop, the computation that happens is essentially:
```python
x = (1, 220, 240, 320, 3)          # shape of the original input
gradients = (1, 220, 240, 320, 1)  # shape of the incoming gradients

gradients_1 = StandardisationMeanStdPyTorch.estimate_gradient(x, gradients)
gradients_2 = oscar.defences.preprocessor.Preprocessor.estimate_gradient(x, gradients_1)
```
Notice that it is impossible to compute `gradients_1` using `StandardisationMeanStdPyTorch`, since `StandardisationMeanStdPyTorch` does not change the shape, yet its inputs have different shapes! What we really want to compute is:
```python
x_1 = oscar.defences.preprocessor.Preprocessor.estimate_forward(x)  # or just forward
gradients_1 = StandardisationMeanStdPyTorch.estimate_gradient(x_1, gradients)
gradients_2 = oscar.defences.preprocessor.Preprocessor.estimate_gradient(x, gradients_1)
```
Does that make sense?
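To make the chaining concrete, here is a minimal numpy sketch of back-propagating through both preprocessors via the intermediate representation `x_1`. This is not ART's code: the helper names (`preproc_forward`, `preproc_estimate_gradient`, `std_estimate_gradient`) are hypothetical stand-ins, with the analytic vector-Jacobian products written out by hand.

```python
import numpy as np

def preproc_forward(x):
    # hypothetical shape-changing preprocessor: collapse the RGB channels
    return x.mean(axis=-1, keepdims=True)

def preproc_estimate_gradient(x, grad):
    # VJP of the channel mean: spread grad evenly over the channels of x
    return np.repeat(grad, x.shape[-1], axis=-1) / x.shape[-1]

def std_estimate_gradient(x, grad, std=1.0):
    # VJP of the standardisation (x - mean) / std
    return grad / std

x = np.random.rand(1, 4, 8, 8, 3)     # original input (small shapes for the demo)
grad = np.random.rand(1, 4, 8, 8, 1)  # loss gradient w.r.t. the model input

x_1 = preproc_forward(x)                       # intermediate representation
grad_1 = std_estimate_gradient(x_1, grad)      # shapes agree: both (1, 4, 8, 8, 1)
grad_2 = preproc_estimate_gradient(x, grad_1)  # gradient w.r.t. the original x
assert grad_2.shape == x.shape
```

Passing the original `x` instead of `x_1` into the standardisation step would mix tensors of incompatible shapes, which is exactly the failure described above.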
Another way to show that this part of the code is most likely incorrect is to check whether the gradients from a `PyTorchEstimator` + `PreprocessorPyTorch` are the same when `all_framework_preprocessing` is `True` and when it is `False` (i.e., by manually setting that variable's value in the PyTorch estimator). I'll try to create this minimal test case to verify.
Also, thank you for the feature adversaries lead. I will take a look at it.
Okay, I think I have a minimal example that breaks:
```python
import torch

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
from art.estimators.classification.pytorch import PyTorchClassifier


class ReducingPreprocessor(PreprocessorPyTorch):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._device = 'cpu'

    def forward(self, x, y=None):
        x = torch.mean(x, dim=-1, keepdim=True)
        return x, y


class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32*32, 10)

    def forward(self, x, y=None):
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


preprocessor = ReducingPreprocessor()
model = LinearModel()
classifier = PyTorchClassifier(model=model,
                               input_shape=(None, 32, 32, 3),
                               nb_classes=10,
                               loss=torch.nn.CrossEntropyLoss(),
                               preprocessing_defences=[preprocessor])

x = torch.rand(1, 32, 32, 3)
y = torch.nn.functional.one_hot(torch.tensor([1]), 10)

classifier.loss_gradient(x, y)  # works!

classifier.all_framework_preprocessing = False
classifier.loss_gradient(x.numpy(), y.numpy())  # doesn't work!
```
Hi @dxoigmn Thank you very much for the example, this is very helpful! This is an unexpected use of these tools, but I think I understand now what you are doing, and I have a few questions for my understanding:

- Are you setting `classifier.all_framework_preprocessing = False` for the purpose of the example or also in your application? This attribute was not intended to be modified by the user, and leaving it set to `True` in this case would solve the issue.
- This is not related to the issue, but `input_shape` in `PyTorchClassifier` should be `input_shape=(32, 32, 3)`.
- Do you need normalisation with `StandardisationMeanStd*` before the sample gets to the model, or could it be integrated into the preprocessing step? You can avoid `StandardisationMeanStd*` being added automatically if you set `preprocessing=None` in `PyTorchClassifier`; that way only `ReducingPreprocessor` will be applied.

With `preprocessing=None` set, you can now make it work by overwriting `PreprocessorPyTorch.estimate_gradient`, because `estimate_gradient` assumes that if the shapes of input and gradients don't match it has to treat the gradients as class and not loss gradients, which is not the case in this example. The following example seems to work:
```python
import torch
import numpy as np

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
from art.estimators.classification.pytorch import PyTorchClassifier


class ReducingPreprocessor(PreprocessorPyTorch):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._device = 'cpu'

    def forward(self, x, y=None):
        x = torch.mean(x, dim=-1, keepdim=True)
        return x, y

    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
        import torch  # lgtm [py/repeated-import]

        x = torch.tensor(x, device=self.device, requires_grad=True)
        grad = torch.tensor(grad, device=self.device)

        x_prime = self.estimate_forward(x)
        x_prime.backward(grad)
        x_grad = x.grad.detach().cpu().numpy()

        return x_grad


class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32*32, 10)

    def forward(self, x, y=None):
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


preprocessor = ReducingPreprocessor()
model = LinearModel()
classifier = PyTorchClassifier(model=model,
                               input_shape=(32, 32, 3),
                               nb_classes=10,
                               loss=torch.nn.CrossEntropyLoss(),
                               preprocessing=None,
                               preprocessing_defences=[preprocessor])

x = torch.rand(1, 32, 32, 3)
y = torch.nn.functional.one_hot(torch.tensor([1]), 10)

classifier.loss_gradient(x, y)  # works!

classifier.all_framework_preprocessing = False
classifier.loss_gradient(x.numpy(), y.numpy())  # now it works too, but experimental!
```
> Are you setting `classifier.all_framework_preprocessing = False` for the purpose of the example or also in your application? This attribute was not intended to be modified by the user and leaving it set to `True` in this case would solve the issue.
I just set that flag for the purposes of the example. I could have triggered that condition by changing the preprocessor to use the numpy interface, but I'm lazy :)
> Do you need normalisation with `StandardisationMeanStd*` before the sample gets to the model or could it be integrated in the preprocessing step? You can avoid `StandardisationMeanStd*` being added automatically if you set `preprocessing=None` in `PyTorchClassifier`, that way only `ReducingPreprocessor` will be applied.
No normalization is needed. And, yes, to solve this issue I set `preprocessing=None` and implemented `estimate_gradient` as you show.

In theory, though, I don't see why it shouldn't work with `preprocessing=(0., 1.)`, since that normalization is essentially a no-op. But the code, as written, makes the additional assumption that preprocessors always return the same shape as the input. It's a fine assumption, but I don't think it is enforced or documented anywhere, and it probably should be.
Maybe my title is a misnomer, in that it does compute the gradient correctly but will not work with certain types of preprocessors?
@dxoigmn I agree with you that the implicit assumptions are not documented as well as they could be; we'll try to improve that. I think it calculates the gradients correctly; it just assumes that if the shapes are different then they must be class gradients with an additional dimension to loop over. We'll have to think of a solution for distinguishing a case like yours from class gradients.
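The ambiguity discussed above can be sketched as follows. This is a simplified illustration, not ART's actual implementation: a gradient with one extra leading dimension per class is indistinguishable by shape alone from the output of a shape-changing preprocessor.

```python
import numpy as np

def dispatch(x: np.ndarray, grad: np.ndarray) -> str:
    """Classify incoming gradients by shape (illustrative heuristic only)."""
    if grad.shape == x.shape:
        return "loss gradient"
    if grad.shape[0] == x.shape[0] and grad.shape[2:] == x.shape[1:]:
        # looks like (batch, nb_classes, *input_shape)
        return "class gradients"
    return "ambiguous"  # e.g. a preprocessor that changed the input shape

x = np.zeros((1, 32, 32, 3))
print(dispatch(x, np.zeros((1, 32, 32, 3))))      # loss gradient
print(dispatch(x, np.zeros((1, 10, 32, 32, 3))))  # class gradients
print(dispatch(x, np.zeros((1, 32, 32, 1))))      # ambiguous
```

The third case is exactly the shape-changing preprocessor from this issue: the shapes differ, but not because of a class dimension.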
**Describe the bug** In `PyTorchEstimator`, `_apply_preprocessing_gradient` does not appear to work correctly when a non-PyTorch preprocessor is used with a PyTorch estimator. In particular, when computing the gradient, the estimator uses each preprocessor's `estimate_gradient`, but it does not store intermediate representations to feed these calls. Rather, for each preprocessor, it always passes the original input (`x` in the snippet below), which I believe is wrong:

https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/24a616befec9d35ba053d16d4a6ac6cfc573039c/art/estimators/pytorch.py#L266-L273

In many cases a preprocessor returns the same shape as the input, so this silently appears to work. We found this because our preprocessor changes the shape of the input (e.g., a preprocessor that collapses an RGB image or video into a grayscale image or video), and that loudly exposes the bug. Here's the (truncated) error with some custom debugging output:

Our preprocessor `oscar.defences.preprocessor.Preprocessor` takes input `[1, 220, 240, 320, 3]` and turns it into output `[1, 220, 240, 320, 1]`.

**To Reproduce** I don't have a small test case to exhibit this problem, but I imagine one could create a non-pytorch preprocessor that collapses an image into grayscale input, which would expose the issue and fail loudly.
**Expected behavior** A clear and concise description of what you expected to happen.

**Screenshots** N/A

**System information (please complete the following information):**