Closed JFz0419 closed 2 months ago
Hi @JFz0419,
Thank you for reaching out. One approach is to mask random noise in frequency space so that only a specific frequency band remains. To achieve this, we apply the FFT to the random noise, isolate the desired frequency band, and then apply the inverse FFT. The following figure shows an example of frequency-based noise.
The following example code shows how to do this.
import torch.nn.functional as F
class Random:
    """Additive random-noise "attack" used as a simple perturbation baseline.

    Calling the instance returns ``xs + eps * noise`` where ``noise`` is
    standard-normal noise, optionally reduced to its sign (so every element
    is perturbed by exactly ``+eps`` or ``-eps``, ignoring exact zeros).

    Args:
        model: Unused; accepted so the constructor can be called like an
            attack that needs a model.
        eps: Scale applied to the noise.
        sign: If true, use ``sign(noise)`` instead of the raw Gaussian noise.
        gpu: If true, the input is moved to CUDA before the noise is added
            (this requires a CUDA device).
    """

    def __init__(self, model=None, *, eps=0.7, sign=True, gpu=True):
        super().__init__()
        self.eps = eps
        self.gpu = gpu
        self.sign = sign

    def __call__(self, xs):
        """Return a detached, perturbed copy of ``xs``; ``xs`` is not modified.

        Works for tensors of any shape. Floating-point inputs keep their
        dtype; other inputs are promoted by the addition as usual.
        """
        xs = xs.clone().detach()
        if self.gpu:
            xs = xs.cuda()

        # Sample in float32 on the CPU first so that seeded runs give the
        # same noise regardless of the target device or dtype.
        noise = torch.randn(xs.shape)
        if self.sign:
            noise = noise.sign()

        # Match the input's floating dtype so half/bfloat16 inputs are not
        # silently promoted to float32 by the addition below.
        dtype = xs.dtype if xs.is_floating_point() else noise.dtype
        noise = noise.to(device=xs.device, dtype=dtype)

        return (xs + self.eps * noise).detach()
class FreqAttack:
    """Restrict another attack's perturbation to a band of spatial frequencies.

    The wrapped ``attack`` is called first; its perturbation
    ``attack(xs) - xs`` is transformed with a 2-D FFT over the last two
    dimensions, every coefficient outside the band ``[f - s, f + s]`` is
    zeroed, and the result is transformed back and added to ``xs``.

    Frequencies are angular, in radians per sample: ``f = pi`` corresponds to
    a box covering the whole spectrum along an axis. The band is the
    difference of two centered boxes, i.e. a rectangular ring around the
    zero-frequency coefficient.

    Args:
        attack: Callable mapping a tensor ``xs`` to a perturbed tensor of the
            same shape.
        f: Center of the frequency band.
        s: Half-width of the frequency band.
    """

    def __init__(self, attack, *, f, s=0.2):
        super().__init__()
        self.attack = attack
        self.f = f
        self.s = s

    def __call__(self, xs):
        """Return ``xs`` plus the band-limited perturbation of the wrapped attack."""
        xs_adv = self.attack(xs)
        # The wrapped attack may return its result on another device (for
        # example after calling ``.cuda()``); align ``xs`` before subtracting.
        xs = xs.to(xs_adv.device)
        perturbation = self._fourier_mask(xs_adv - xs, self.f, self.s).real
        return xs + perturbation

    def _fourier_mask(self, x, f, s):
        """Zero every Fourier coefficient of ``x`` outside the band ``f +/- s``.

        ``x`` has its spatial dimensions last (``..., h, w``). The returned
        tensor is complex and has the same shape as ``x``.

        NOTE: each box covers the centered frequency indices ``-n/2 .. n/2-1``,
        so the mask is not exactly symmetric about zero frequency and the
        result can carry a small imaginary part; ``__call__`` keeps only the
        real part.
        """
        h, w = x.shape[-2:]

        # A. FFT, with the zero-frequency coefficient moved to the center.
        spectrum = self._shift(torch.fft.fft2(x))

        # B. Mask: outer box minus inner box leaves a ring of ones. It has
        # shape (h, w) and broadcasts over all leading dimensions, so any
        # batch size and channel count is supported.
        band = self._centered_box(f + s, h, w) - self._centered_box(f - s, h, w)
        # Scaling a complex coefficient by a real mask keeps its phase, which
        # is what masking the magnitude and re-attaching the angle achieves.
        spectrum = spectrum * band.to(spectrum.device)

        # C. Inverse FFT, after undoing the centering.
        return torch.fft.ifft2(self._unshift(spectrum))

    def _shift(self, x):
        """Move the zero-frequency coefficient to the center of the last two dims."""
        return torch.fft.fftshift(x, dim=(-2, -1))

    def _unshift(self, x):
        """Undo ``_shift``; differs from ``_shift`` itself when a size is odd."""
        return torch.fft.ifftshift(x, dim=(-2, -1))

    def _centered_box(self, cutoff, h, w):
        """Return an ``(h, w)`` mask of ones in a centered box for ``cutoff``.

        Along an axis of length ``n`` the box is
        ``int(cutoff * n / (2 * pi)) * 2`` entries wide, clamped to ``[0, n]``.
        """
        rows = self._center_window(int((cutoff * h) / (2 * math.pi)) * 2, h)
        cols = self._center_window(int((cutoff * w) / (2 * math.pi)) * 2, w)
        return rows[:, None] * cols[None, :]

    def _center_window(self, w1, w2):
        """Return a length-``w2`` vector with ``w1`` centered ones, zeros elsewhere.

        ``w1`` is clamped to ``[0, w2]``. When ``w2 - w1`` is odd the extra
        zero goes at the end, so the length is always exactly ``w2``.
        """
        w1 = max(0, min(w1, w2))
        window = torch.zeros(w2)
        start = (w2 - w1) // 2
        window[start:start + w1] = 1.0
        return window

    def _center_mask(self, w1, w2, channels=3):
        """Return a ``[1, channels, w2, w2]`` mask with a centered ``w1 x w1`` box of ones.

        Not used by ``_fourier_mask`` any more; retained so existing callers
        of this helper keep working.
        """
        window = self._center_window(w1, w2)
        square = window[:, None] * window[None, :]
        return square.expand(1, channels, w2, w2).clone()
# This cell builds on https://github.com/facebookresearch/mae
import requests
import torch
import numpy as np
from PIL import Image
from einops import rearrange, reduce, repeat
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

imagenet_mean = np.array(IMAGENET_DEFAULT_MEAN)
imagenet_std = np.array(IMAGENET_DEFAULT_STD)

# load a sample ImageNet-1K image -- use the full val dataset for precise results
urls = [
    "https://user-images.githubusercontent.com/930317/158025258-e9a5a454-99de-4d22-bc93-b217cdf06abb.jpeg",
]

images = []
for url in urls:
    # Fail fast on a hung connection or an HTTP error page instead of
    # handing a non-image body to PIL.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    # Force three channels so grayscale or RGBA files still satisfy the
    # (224, 224, 3) shape check below.
    image = Image.open(response.raw).convert("RGB")
    images.append(image.resize((224, 224)))

# scale pixel values to [0, 1] and stack into a (batch, height, width, channel) array
xs = np.stack([np.array(image) / 255. for image in images])
assert xs.shape[1:] == (224, 224, 3)

# normalize by ImageNet mean and std
xs = (xs - imagenet_mean) / imagenet_std

# convert to a channels-first float32 tensor
xs = rearrange(torch.tensor(xs, dtype=torch.float32), "b h w c -> b c h w")
import math
import torch
import matplotlib.pyplot as plt
from einops import rearrange, reduce, repeat

fig, axes = plt.subplots(1, 4, figsize=(18, 4), dpi=200)

# Sign-noise baseline on the CPU, band-limited around two different frequencies.
attack_base = Random(eps=0.7, gpu=False)
attack1 = FreqAttack(attack_base, f=0.1 * math.pi, s=0.05 * math.pi)
attack2 = FreqAttack(attack_base, f=0.6 * math.pi, s=0.05 * math.pi)

xs1 = attack1(xs)
xs2 = attack2(xs)

# imshow expects channels-last images, so move the channel axis to the end.
axes[0].imshow(rearrange(xs, "b c h w -> b h w c")[0])  # clean image
axes[1].imshow(rearrange(xs1 - xs, "b c h w -> b h w c")[0])  # frequency based random noise (f = 0.1 pi)
axes[2].imshow(rearrange(xs2 - xs, "b c h w -> b h w c")[0])  # frequency based random noise (f = 0.6 pi)
axes[3].imshow(rearrange(xs2, "b c h w -> b h w c")[0])  # image with noise (f = 0.6 pi)

# Raw strings: "\p" is not a valid escape sequence in a normal string literal.
titles = [
    "Clean image",
    r"Random noise of $f = 0.1 \pi$",
    r"Random noise of $f = 0.6 \pi$",
    r"Image with noise of $f = 0.6 \pi$",
]
for ax, title in zip(axes, titles):
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()
Please let me know, or leave a message in my inbox, if you need additional code snippets to measure robustness against frequency-based noise.
How to generate frequency-based noise