xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0
806 stars 79 forks source link

How to generate frequency-based noise #44

Closed JFz0419 closed 2 months ago

JFz0419 commented 2 months ago

How to generate frequency-based noise

xxxnell commented 2 months ago

Hi @JFz0419,

Thank you for reaching out. One approach is to mask random noise in frequency space so that only a specific frequency band remains. To achieve this, we apply the FFT to the random noise, keep only the desired frequency band, and then apply the inverse FFT. The following figure shows an example of frequency-based noise.

image

The following is example pseudocode that does this.

import torch.nn.functional as F

class Random:
    """Perturb a batch with additive Gaussian (or sign-of-Gaussian) noise.

    Args:
        model: Accepted but not used by this class.
        eps: Scale factor applied to the noise before it is added.
        sign: If true, the noise is replaced by its element-wise sign, so
            every element is shifted by ``eps`` in magnitude.
        gpu: If true, the input is moved to CUDA before noise is added.
    """

    def __init__(self, model=None, *, eps=0.7, sign=True, gpu=True):
        super().__init__()
        self.eps, self.gpu, self.sign = eps, gpu, sign

    def __call__(self, xs):
        """Return a detached, noisy copy of the 4-D batch ``xs``."""
        # Work on a detached copy so the caller's tensor is left untouched.
        xs = xs.detach().clone()
        if self.gpu:
            xs = xs.cuda()

        batch, channels, height, width = xs.shape
        # Noise is drawn on the CPU and then moved to wherever `xs` lives.
        noise = torch.randn([batch, channels, height, width]).to(xs.device)
        if self.sign:
            noise = noise.sign()

        return (xs + self.eps * noise).detach()

class FreqAttack:
    """Restrict the perturbation produced by another attack to a frequency band.

    The perturbation ``attack(xs) - xs`` is transformed with a 2-D FFT, every
    coefficient outside a centered box-shaped band is zeroed, and the result
    is transformed back and added to ``xs``.

    Args:
        attack: Callable mapping a batch ``xs`` to a perturbed batch of the
            same shape.
        f: Center of the band as an angular frequency; ``math.pi`` maps to
            the edge of the spectrum.
        s: Half-width of the band, in the same unit as ``f``.
    """

    def __init__(self, attack, *, f, s=0.2):
        super().__init__()
        self.attack = attack
        self.f = f
        self.s = s

    def __call__(self, xs):
        """Return ``xs`` plus the band-limited part of the wrapped attack's noise."""
        xs_adv = self.attack(xs)
        # The wrapped attack may have moved its copy of the data to another
        # device; align `xs` so the subtraction below cannot fail on that.
        xs = xs.to(xs_adv.device)
        # The box band is not perfectly symmetric around DC, so the inverse
        # FFT can carry a small imaginary part; only the real part is kept.
        noise = self._fourier_mask(xs_adv - xs, self.f, self.s).real

        return xs + noise

    def _fourier_mask(self, x, f, s):
        """Zero all FFT coefficients of ``x`` outside the band ``[f - s, f + s]``.

        The last two dimensions of ``x`` are treated as the spatial ones; they
        may be non-square and of odd length. Returns a complex tensor with
        the shape of ``x``.
        """
        h, w = x.shape[-2:]

        # A. FFT, with the zero frequency moved to the center
        spectrum = self._shift(torch.fft.fft2(x))

        # B. Mask: outer box minus inner box leaves a band. The mask has shape
        # (h, w) and broadcasts over all leading (batch, channel) dimensions.
        band = self._band_mask(f + s, h, w) - self._band_mask(f - s, h, w)
        band = band.to(device=spectrum.device, dtype=spectrum.real.dtype)
        # Scaling a complex number by a real mask scales its magnitude and
        # keeps its phase, so no explicit abs/angle decomposition is needed.
        spectrum = spectrum * band

        # C. Inverse FFT
        return torch.fft.ifft2(self._unshift(spectrum))

    def _shift(self, x):
        """Move the zero-frequency component to the center of the last two dims."""
        return torch.fft.fftshift(x, dim=(-2, -1))

    def _unshift(self, x):
        """Undo ``_shift``; differs from ``_shift`` when a dimension is odd."""
        return torch.fft.ifftshift(x, dim=(-2, -1))

    def _band_mask(self, cutoff, h, w):
        """Return an ``(h, w)`` mask that is 1 inside the centered box for ``cutoff``."""
        rows = self._center_window(self._cutoff_width(cutoff, h), h)
        cols = self._center_window(self._cutoff_width(cutoff, w), w)
        return rows[:, None] * cols[None, :]

    @staticmethod
    def _cutoff_width(cutoff, size):
        """Convert an angular cutoff into an even number of bins along one axis."""
        return int((cutoff * size) / (2 * math.pi)) * 2

    @staticmethod
    def _center_window(width, size):
        """Return a 1-D tensor of length ``size`` with ``width`` centered ones.

        ``width`` is clamped to ``[0, size]``. When ``size - width`` is odd
        the extra zero goes on the right, so the result always has exactly
        ``size`` elements.
        """
        width = min(max(width, 0), size)
        start = (size - width) // 2
        window = torch.zeros(size)
        window[start:start + width] = 1.0
        return window

    def _center_mask(self, w1, w2, channels=3):
        """Return a ``[1, channels, w2, w2]`` mask with a centered ``w1 x w1`` box of ones.

        Kept for callers of the earlier square-only helper; ``_fourier_mask``
        now builds its mask through ``_band_mask``.
        """
        window = self._center_window(w1, w2)
        mask = window[:, None] * window[None, :]
        return mask.expand(1, channels, w2, w2).clone()
# This cell builds on https://github.com/facebookresearch/mae
import requests
import torch
import numpy as np

from PIL import Image
from einops import rearrange, reduce, repeat
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Per-channel statistics used to normalize the image below.
imagenet_mean = np.array(IMAGENET_DEFAULT_MEAN)
imagenet_std = np.array(IMAGENET_DEFAULT_STD)

# load a sample ImageNet-1K image -- use the full val dataset for precise results
urls = [
    "https://user-images.githubusercontent.com/930317/158025258-e9a5a454-99de-4d22-bc93-b217cdf06abb.jpeg",
]

xs = []
for url in urls:
    # The timeout keeps a stalled download from hanging the cell, and the
    # `with` block releases the streamed connection once the image is decoded.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()  # fail here rather than on a confusing decode error
        # convert() decodes the pixels while the stream is still open and
        # guarantees three channels for the shape check below.
        image = Image.open(response.raw).convert("RGB")
    xs.append(np.array(image.resize((224, 224))) / 255.)
xs = np.stack(xs)

assert xs.shape[1:] == (224, 224, 3)

# normalize by ImageNet mean and std, then switch to channels-first float32
xs = (xs - imagenet_mean) / imagenet_std
xs = rearrange(torch.tensor(xs, dtype=torch.float32), "b h w c -> b c h w")
import math
import torch
import matplotlib.pyplot as plt
from einops import rearrange, reduce, repeat

# Show the clean image, two band-limited noise patterns, and a noisy image.
fig, axes = plt.subplots(1, 4, figsize=(18, 4), dpi=200)

attack_base = Random(eps=0.7, gpu=False)
attack1 = FreqAttack(attack_base, f=0.1 * math.pi, s=0.05 * math.pi)
attack2 = FreqAttack(attack_base, f=0.6 * math.pi, s=0.05 * math.pi)

xs1 = attack1(xs)
xs2 = attack2(xs)

panels = [
    xs,        # clean image
    xs1 - xs,  # frequency based random noise (f = 0.1 pi)
    xs2 - xs,  # frequency based random noise (f = 0.6 pi)
    xs2,       # image with noise (f = 0.6 pi)
]
# Raw strings: "\p" is not a valid escape sequence in a normal string literal.
titles = [
    "Clean image",
    r"Random noise of $f = 0.1 \pi$",
    r"Random noise of $f = 0.6 \pi$",
    r"Image with noise of $f = 0.6 \pi$",
]
for ax, panel, title in zip(axes, panels, titles):
    # imshow expects channels-last data; show the first item of the batch
    ax.imshow(rearrange(panel, "b c h w -> b h w c")[0])
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

Please let me know, or leave a message in my inbox, if you need additional code snippets to measure robustness against frequency-based noise.