Closed. Yuuuumie closed this 1 year ago.
Hi @Yuuuumie, we'll get this integrated into the training script soon, but here's the main function to use for now. The algorithm is fairly simple: we look at how strongly the network's top prediction responds to each input pixel, i.e. the gradient of the top score with respect to the input. We repeat this over noisy copies of the input because averaging makes the map more robust. The noise can be seen as a mild adversarial perturbation, which works in our favor here and is in keeping with how we train, since we use heavy augmentations.
    import torch
    from torchvision.transforms.functional import to_pil_image

    # These will need to change depending on the dataset you use. Best practice is to calculate them yourself,
    # since online resources can sometimes provide bad data (often you'll see ImageNet results and not your
    # specific dataset).
    IMAGENET_DEFAULT_MEAN = torch.tensor([0.485, 0.456, 0.406])
    IMAGENET_DEFAULT_STD = torch.tensor([0.229, 0.224, 0.225])


    def batch_salient(model,
                      imgs,
                      mean=IMAGENET_DEFAULT_MEAN,  # mean/std are not used in this function; they are
                      std=IMAGENET_DEFAULT_STD,    # only needed to un-normalize images for an overlay
                      rounds=100,
                      noise_std=0.1,
                      noise_mean=0,
                      ):
        imgs.requires_grad_()
        # Initialize once, outside the loop, so the map accumulates across rounds
        salient = None
        for i in range(rounds + 1):
            noise = torch.randn_like(imgs) * noise_std + noise_mean
            # Clear the gradient from the previous round; otherwise imgs.grad keeps accumulating
            imgs.grad = None
            # Book-keeping so we maintain the root image for super-imposing the salient onto the original image
            if i == 0:
                preds = model(imgs)
                preds_orig = preds.detach().clone()
            else:
                preds = model(imgs + noise)
            # Back-propagate the top class score of every image in the batch
            scores, indices = torch.max(preds, dim=1)
            scores.backward(torch.ones_like(scores))
            # Collapse the channel dimension by keeping the largest gradient per pixel
            round_salient = torch.max(imgs.grad, dim=1)[0]
            if salient is None:
                salient = round_salient
            else:
                salient += round_salient
        # This next line is optional and just normalizes the result (one clean pass plus `rounds` noisy ones).
        salient /= rounds + 1
        salient.relu_()
        salients = [to_pil_image(s.cpu().squeeze(0)).convert("RGB") for s in salient]
        return preds_orig, salients
I'll close this issue once we add the code to the repo, but I hope this is useful for now.
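On the note in the snippet about calculating the normalization statistics yourself: a minimal sketch of one way to do that is below. It assumes batches arrive as `(images, ...)` with images shaped `(N, C, H, W)` and already scaled to `[0, 1]`. The name `channel_mean_std` and the random stand-in data are hypothetical and not part of the repo.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader):
    """Per-channel mean and std over every pixel of every image in `loader`.

    Assumes each batch is (images, ...) with images shaped (N, C, H, W).
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for images, *_ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        s = images.sum(dim=(0, 2, 3))
        sq = (images ** 2).sum(dim=(0, 2, 3))
        channel_sum = s if channel_sum is None else channel_sum + s
        channel_sq_sum = sq if channel_sq_sum is None else channel_sq_sum + sq
        n_pixels += images.shape[0] * images.shape[2] * images.shape[3]
    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding
    std = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean.float(), std.float()


# Hypothetical stand-in data: 64 random RGB images of size 32x32.
fake_images = torch.rand(64, 3, 32, 32)
loader = DataLoader(TensorDataset(fake_images), batch_size=16)
mean, std = channel_mean_std(loader)
```

The resulting `mean` and `std` can be passed in place of `IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD`. Run it on the training split only, before any normalization transform is applied.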
Thank you for your reply! I will try it.
Can you share the method or code you used to draw the saliency map? Thanks a lot!
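For the drawing step itself, one possible approach (a sketch only, not necessarily how the repo's figures were made) is to un-normalize the image with the dataset mean and std, rescale the saliency map to `[0, 1]`, and alpha-blend it over the image. This assumes you have the raw `(H, W)` saliency tensor for one image, i.e. before the `to_pil_image` conversion. The function name `overlay_saliency`, the red heat colouring, and the `alpha` parameter are all illustrative choices.

```python
import torch

IMAGENET_DEFAULT_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_DEFAULT_STD = torch.tensor([0.229, 0.224, 0.225])


def overlay_saliency(img, saliency, mean=IMAGENET_DEFAULT_MEAN,
                     std=IMAGENET_DEFAULT_STD, alpha=0.5):
    """Blend a (H, W) saliency map over a normalized (3, H, W) image.

    Returns a float tensor of shape (3, H, W) with values in [0, 1].
    """
    # Undo the dataset normalization so the image is back in [0, 1].
    rgb = img.detach().cpu() * std.view(3, 1, 1) + mean.view(3, 1, 1)
    rgb = rgb.clamp(0, 1)

    # Rescale the saliency map to [0, 1]; raw gradients are usually tiny.
    sal = saliency.detach().cpu().float()
    sal = sal - sal.min()
    sal = sal / sal.max().clamp(min=1e-12)

    # Paint saliency as red heat and blend it in, pixel by pixel:
    # more salient pixels get more heat, non-salient pixels stay untouched.
    heat = torch.zeros_like(rgb)
    heat[0] = sal
    weight = alpha * sal  # (H, W), broadcasts over the channel dimension
    return (1 - weight) * rgb + weight * heat


# Hypothetical inputs: one normalized image and its saliency map.
img = torch.randn(3, 8, 8)
saliency = torch.rand(8, 8)
blended = overlay_saliency(img, saliency)
```

The blended tensor is a `(3, H, W)` float image in `[0, 1]`, so it can be handed to `to_pil_image` from the snippet above for saving or display.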