Harry24k / adversarial-attacks-pytorch

PyTorch implementation of adversarial attacks [torchattacks].
https://adversarial-attacks-pytorch.readthedocs.io/en/latest/index.html
MIT License

[QUESTION] Can it be used for segmentation tasks? #141

MarlonGarcia closed this issue 8 months ago

MarlonGarcia commented 1 year ago

Can this library be used for segmentation tasks?

If so, what should the shape and content of the label be when the attack is called, as below?

Will it be (N, H, W)?

atk = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=10)
image_adv = atk(image, label)
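Judging by the CIFAR10 example later in this thread, torchattacks' attacks appear to expect classification-style labels of shape (N,). One possible workaround, sketched here under stated assumptions, is a hand-written PGD loop that is not part of torchattacks: it takes per-pixel class-index labels of shape (N, H, W) and assumes the model returns logits of shape (N, C, H, W). The function name `pgd_segmentation` is hypothetical, and the `nn.Conv2d` "model" in the usage lines is only a stand-in for a real segmentation network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pgd_segmentation(model, images, masks, eps=8/255, alpha=2/255, steps=10):
    """Hypothetical helper (not a torchattacks API): untargeted L-inf PGD.

    images: float tensor (N, 3, H, W) with values in [0, 1]
    masks:  long tensor (N, H, W) of per-pixel class indices
    Assumes model(images) returns logits of shape (N, C, H, W).
    """
    model.eval()
    images = images.clone().detach()
    masks = masks.clone().detach()
    # Random start inside the eps-ball, as in standard PGD.
    adv = images + torch.empty_like(images).uniform_(-eps, eps)
    adv = torch.clamp(adv, 0, 1).detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        logits = model(adv)
        # cross_entropy accepts (N, C, H, W) logits with (N, H, W) targets
        # and averages the per-pixel losses.
        loss = F.cross_entropy(logits, masks)
        grad, = torch.autograd.grad(loss, adv)
        # Step up the loss, then project back into the eps-ball and [0, 1].
        adv = adv.detach() + alpha * grad.sign()
        adv = torch.min(torch.max(adv, images - eps), images + eps)
        adv = torch.clamp(adv, 0, 1).detach()
    return adv


# Usage with a toy stand-in model producing 4-class logits per pixel.
torch.manual_seed(0)
model = nn.Conv2d(3, 4, kernel_size=3, padding=1)
images = torch.rand(2, 3, 8, 8)
masks = torch.randint(0, 4, (2, 8, 8))
adv = pgd_segmentation(model, images, masks)
print(adv.shape)
```

The loss is simply the mean per-pixel cross-entropy, so the attack degrades the whole mask at once; a targeted variant would instead step down the loss toward a chosen target mask.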

rikonaka commented 1 year ago

Hi @MarlonGarcia, please describe your input shape and explain what you are trying to achieve. You have given us too little information for a specific reply.

MarlonGarcia commented 1 year ago

Thank you, @rikonaka, I'd appreciate your help. I am trying to run adversarial attacks against a segmentation model. My label has shape (N, C, H, W), where N is the batch size, C is the number of label channels (the label is one-hot encoded, so at each pixel exactly one channel is 1 and the others are 0), and H and W are the height and width of the image. First, I set the attack to random-target mode:

atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=10)
atk.set_mode_targeted_random()

Then I think I have to pass the true label so that the atk instance can compute the correct adversarial example, right? So I used:

for data in loader:
    image, label = data
    image_adv = atk(image, label)

But this raises the following error:

  File "C:\Users\marlo\anaconda3\envs\environment\lib\site-packages\torchattacks\attack.py", line 434, in get_random_target_label
    l.remove(labels[counter])

RuntimeError: Boolean value of Tensor with more than one value is ambiguous
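The error can be reproduced without torchattacks. Assuming, as the traceback suggests, that `get_random_target_label` removes each sample's label from a plain Python list of class indices, `list.remove` compares every list element with that label using `==`. For a class-index label this yields a single boolean, but for a one-hot segmentation label of shape (C, H, W) it yields a whole boolean tensor, which Python cannot reduce to True or False. The list `classes` and the tensor sizes below are illustrative assumptions.

```python
import torch

# Assumption: like `l` in the traceback, a plain Python list of class indices.
classes = list(range(4))

# Class-index label for one sample: a 0-d tensor, so each `==` inside
# list.remove yields a single boolean and the call succeeds.
classes.remove(torch.tensor(2))
print(classes)  # [0, 1, 3]

# One-hot segmentation label for one sample, shape (C, H, W): each `==`
# yields a multi-element boolean tensor, and converting it to bool fails.
one_hot_sample = torch.zeros(4, 8, 8)
caught = None
try:
    classes.remove(one_hot_sample)
except RuntimeError as err:
    caught = err
print(caught)
```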

rikonaka commented 1 year ago

Hi @MarlonGarcia, that's a very strange error; I can't reproduce it at all.

Can you show the shapes of your image and label variables?

My guess is that you are using one-hot encoding for your label; you may want to try converting the label to the standard class-index form (one integer class per sample, as in CIFAR10).

for data in loader:
    image, label = data
    print(image.shape) # should be [n, 3, 32, 32] in CIFAR10
    print(label.shape) # should be [n] in CIFAR10
    # image_adv = atk(image, label)
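Following the suggestion above, a one-hot label of shape (N, C, H, W) can be converted to class-index form by taking `argmax` over the channel dimension, which gives one integer class per pixel and shape (N, H, W). This is a minimal sketch with made-up sizes; note that the result is still a per-pixel mask rather than the (N,) shape printed for CIFAR10, so it may not fit torchattacks' classification-oriented target selection without further changes.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: N=2 images, C=3 classes, H=W=4, one-hot encoded.
N, C, H, W = 2, 3, 4, 4
index_mask = torch.randint(0, C, (N, H, W))
one_hot = F.one_hot(index_mask, num_classes=C)  # (N, H, W, C)
one_hot = one_hot.permute(0, 3, 1, 2).float()   # (N, C, H, W)

# Recover the class-index form: exactly one channel is 1 at each pixel,
# so argmax over dim=1 returns that channel's index.
label = one_hot.argmax(dim=1)  # (N, H, W), dtype torch.int64
print(label.shape)  # torch.Size([2, 4, 4])
```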

image