albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

Fix `Perturbate` on RGB images #306

Closed: adrhill closed this pull request 1 year ago

adrhill commented 1 year ago

Closes #299.

Perturbation functions are now applied separately to each color channel.
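The following minimal NumPy sketch illustrates what applying a perturbation per color channel means. It is not the code from this PR: `perturb_region_per_channel` is a hypothetical helper. For `"mean"` it averages over the spatial axes only, so each channel keeps its own mean. Pooling over all channels at once would instead turn a pure-red region gray. The `"gaussian"` branch assumes additive i.i.d. noise with an arbitrary `noise_std`, and the `[0, 1]` clipping range is also an assumption. innvestigate's actual parameters may differ.

```python
import numpy as np


def perturb_region_per_channel(region, method="mean", rng=None, noise_std=0.1,
                               value_range=(0.0, 1.0)):
    """Perturb one channels_last (H, W, C) region, treating each channel separately.

    Hypothetical helper for illustration only; not part of innvestigate.
    """
    if method == "mean":
        # Average over the spatial axes (0, 1) only: one mean per color channel.
        channel_means = region.mean(axis=(0, 1), keepdims=True)
        out = np.broadcast_to(channel_means, region.shape).copy()
    elif method == "zeros":
        out = np.zeros_like(region)
    elif method == "gaussian":
        # Assumption: independent Gaussian noise per pixel and per channel.
        rng = np.random.default_rng() if rng is None else rng
        out = region + rng.normal(0.0, noise_std, size=region.shape)
    else:
        raise ValueError(f"unknown method: {method!r}")
    # Keep values inside the valid input range (assumed to be [0, 1] here).
    return np.clip(out, *value_range)


# A 2x2 pure-red patch with values in [0, 1].
red = np.zeros((2, 2, 3))
red[..., 0] = 1.0

pooled = np.full_like(red, red.mean())         # one mean over all channels
per_channel = perturb_region_per_channel(red)  # one mean per channel
print(pooled[0, 0])       # [0.33333333 0.33333333 0.33333333] -> gray
print(per_channel[0, 0])  # [1. 0. 0.] -> still red
```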

| `perturbation_function` | Image |
|---|---|
| `"mean"` | peppers_mean |
| `"zeros"` | peppers_zeros |
| `"gaussian"` | peppers_gaussian |

These images were created with `Perturbation.perturbate_on_batch` using the following script:

```python
import numpy as np
from PIL import Image
import tensorflow as tf

# iNNvestigate requires eager execution to be disabled
tf.compat.v1.disable_eager_execution()

from innvestigate.tools.perturbate import Perturbation

# Load the image as 3-channel RGB and resize it with nearest-neighbor resampling
im = Image.open("peppers.tiff").convert("RGB")
im = im.resize((224, 224), resample=Image.NEAREST)

# The conversion helpers below assume channels_last (batch, height, width, channels)
assert tf.keras.backend.image_data_format() == "channels_last"


def im2array(im):
    """Convert image to array in channels_last format."""
    x = np.array(im) / 255  # scale pixel values to [0, 1]
    # x = np.moveaxis(x, 2, 0)  # for channels_first
    return x[np.newaxis, ...]  # add batch dimension


def array2im(a):
    """Convert array in channels_last format to image."""
    # Take the first image of the batch and map [0, 1] back to 8-bit values
    a = (np.clip(a[0], 0, 1) * 255).astype(np.uint8)
    # a = np.moveaxis(a, 0, 2)  # for channels_first
    return Image.fromarray(a)


x = im2array(im)

# Random analysis (stand-in for a real explanation heatmap):
a = np.random.rand(*x.shape)

# Create innvestigate's Perturbation
num_perturbed_regions = 8
perturbation_function = "mean"
region_shape = (32, 32)

p = Perturbation(
    perturbation_function,
    num_perturbed_regions=num_perturbed_regions,
    region_shape=region_shape,
    value_range=(0, 1),
)

# Perturb the image and save the result
x_perturbated = p.perturbate_on_batch(x, a)
im_perturbated = array2im(x_perturbated)
im_perturbated.save(f"peppers_{perturbation_function}.png")
```
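The script leaves region selection to `Perturbation`. A rough mental model of region perturbation is this: tile the image into non-overlapping `region_shape` patches, rank them by aggregated relevance, and perturb the top `num_perturbed_regions`. The sketch below follows that model but is not innvestigate's implementation. `top_regions` and `perturb_top_regions` are hypothetical helpers. Summing relevance over channels and averaging within each tile are assumptions, and the library's reduction and aggregation choices may differ. The perturbation applied is the per-channel mean. With a 224×224 image and 32×32 regions there are 7 × 7 = 49 tiles, of which 8 are perturbed.

```python
import numpy as np


def top_regions(analysis, region_shape, num_regions):
    """Return (row, col) indices of the most relevant non-overlapping regions.

    Hypothetical helper. Assumes a channels_last (H, W, C) heatmap whose H and W
    are divisible by region_shape.
    """
    h, w, _ = analysis.shape
    rh, rw = region_shape
    per_pixel = analysis.sum(axis=-1)  # aggregate relevance over color channels
    # Reshape to (H/rh, rh, W/rw, rw) and average inside each region
    scores = per_pixel.reshape(h // rh, rh, w // rw, rw).mean(axis=(1, 3))
    # Flatten, sort descending, and keep the best num_regions entries
    best = np.argsort(scores, axis=None)[::-1][:num_regions]
    return [tuple(int(i) for i in np.unravel_index(k, scores.shape)) for k in best]


def perturb_top_regions(x, analysis, region_shape, num_regions):
    """Replace the most relevant regions of one image with their per-channel mean."""
    x_out = x.copy()
    rh, rw = region_shape
    for r, c in top_regions(analysis, region_shape, num_regions):
        region = x_out[r * rh:(r + 1) * rh, c * rw:(c + 1) * rw, :]  # a view into x_out
        region[...] = region.mean(axis=(0, 1), keepdims=True)
    return x_out


rng = np.random.default_rng(0)
x = rng.random((224, 224, 3))  # one image, channels_last, values in [0, 1]
a = rng.random((224, 224, 3))  # random stand-in for an analysis
x_perturbed = perturb_top_regions(x, a, region_shape=(32, 32), num_regions=8)
```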