understandable-machine-intelligence-lab / Quantus

Quantus is an eXplainable AI toolkit for responsible evaluation of neural network explanations
https://quantus.readthedocs.io/

Adding "true" batch implementation to metrics #350

Open davor10105 opened 1 week ago

davor10105 commented 1 week ago

Hey @annahedstroem, thank you for creating such an excellent library. I've been using it extensively in my recent research, and it has been incredibly helpful.

While working with the library, I noticed that although most metric implementations support batch processing via the evaluate_batch method, they do not seem to be optimized for true batch processing. Instead, they iterate over each element in the batch, concatenating the final scores at the end.
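
To make the contrast concrete, here is a minimal toy sketch of the two patterns (this is not Quantus code; the stand-in "model" is just a fixed random projection, and the "score" is simply the output for the target class):

# Toy contrast: per-sample loop vs. true batch evaluation (illustrative only)
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(12, 3))  # fixed random projection acting as a stand-in model

def toy_model_predict(x):
    # Softmax over a linear projection of the flattened input.
    logits = x.reshape(x.shape[0], -1) @ W
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def score_per_sample(x_batch, y_batch):
    # Current pattern: one model call per instance, scores concatenated at the end.
    return np.array(
        [toy_model_predict(x[None])[0, y] for x, y in zip(x_batch, y_batch)]
    )

def score_true_batch(x_batch, y_batch):
    # Proposed pattern: a single model call for the whole batch.
    preds = toy_model_predict(x_batch)
    return preds[np.arange(x_batch.shape[0]), y_batch]

x_batch = rng.normal(size=(4, 2, 2, 3))
y_batch = np.array([0, 1, 2, 1])
assert np.allclose(score_per_sample(x_batch, y_batch), score_true_batch(x_batch, y_batch))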

I would love to contribute to your library by implementing true batch processing for these metrics. This optimization would significantly reduce evaluation times. Would you be open to this contribution?

Below, you will find code snippets showcasing my batch implementation of the evaluate_batch methods for the PixelFlipping and MaxSensitivity metrics. Additionally, I have included helper functions required for calculating these specific metrics. This implementation is a preliminary proof of concept, and with your approval, I would like to properly integrate these changes into the library.

Following the code snippets, I have included two plots that illustrate the performance gains achieved with these modified implementations. These results are based on 30 repetitions. As shown in the figures, the true batch version for PixelFlipping is 18 times faster than the current implementation, and for MaxSensitivity, the performance gain is 2.3 times.

# Batch PixelFlipping
def evaluate_batch(
        self,
        model: ModelInterface,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        a_batch: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        a = a_batch.reshape(a_batch.shape[0], -1)

        # Get indices of sorted attributions (descending).
        a_indices = np.argsort(-a, axis=1)

        # Prepare the array of per-step predictions and a working copy of the input.
        n_perturbations = math.ceil(a_indices.shape[-1] / self.features_in_step)
        preds = np.full((a_batch.shape[0], n_perturbations), np.nan)
        x_perturbed = x_batch.copy()

        for perturbation_step_index in range(n_perturbations):
            # Perturb input by the next block of `features_in_step` attribution indices.
            start = perturbation_step_index * self.features_in_step
            end = start + self.features_in_step
            perturb_indices = a_indices[:, start:end]
            x_perturbed = batch_perturb_baseline(
                x_perturbed,
                perturb_indices=perturb_indices,
            )

            for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed):
                warn.warn_perturbation_caused_no_change(
                    x=x_instance,
                    x_perturbed=x_instance_perturbed,
                )

            # Predict on perturbed input x.
            predictions = model.predict(x_perturbed)
            y_pred_perturb = predictions[np.arange(x_perturbed.shape[0]), y_batch]

            # Save predictions
            preds[:, perturbation_step_index] = y_pred_perturb

        if self.return_auc_per_sample:
            # Integrate over the perturbation axis to obtain one AUC value per sample.
            return np.trapz(preds)

        return preds
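
A side note on the vectorised gather above: predictions[np.arange(x_perturbed.shape[0]), y_batch] picks each sample's score for its own target class in a single indexing operation. A standalone numpy illustration:

# Per-sample class-score gather (plain numpy behaviour)
import numpy as np

predictions = np.array([[0.1, 0.7, 0.2],
                        [0.5, 0.3, 0.2]])  # (batch_size, n_classes)
y_batch = np.array([1, 0])                 # target class of each sample
scores = predictions[np.arange(predictions.shape[0]), y_batch]
print(scores)  # [0.7 0.5]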

# Batch MaxSensitivity
def evaluate_batch(
        self,
        model: ModelInterface,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        a_batch: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        batch_size = x_batch.shape[0]
        similarities = np.full((batch_size, self.nr_samples), np.nan)

        for step_id in range(self.nr_samples):
            # Perturb every feature of every input with uniform noise.
            batch_perturb_indices = np.tile(
                np.arange(0, x_batch[0].size), (batch_size, 1)
            )
            x_perturbed = batch_perturb_uniform(x_batch, self.lower_bound, batch_perturb_indices)

            changed_prediction_indices = self.changed_prediction_indices_func(
                model, x_batch, x_perturbed
            )

            for x_instance, x_instance_perturbed in zip(x_batch, x_perturbed):
                warn.warn_perturbation_caused_no_change(
                    x=x_instance,
                    x_perturbed=x_instance_perturbed,
                )

            # Generate explanation based on perturbed input x.
            a_perturbed = self.explain_batch(model, x_perturbed, y_batch)

            # Calculate metric
            batch_numerator = batch_fro_norm(a_batch - a_perturbed)
            batch_denominator = batch_fro_norm(a_batch)
            sensitivities = batch_numerator / batch_denominator
            similarities[:, step_id] = sensitivities

            # Mask out samples whose prediction changed under the perturbation.
            similarities[changed_prediction_indices, step_id] = np.nan

        return self.max_func(similarities, axis=1)
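
One caveat about the NaN masking above: the final reduction only skips the masked entries if max_func is NaN-aware (e.g. np.nanmax); with plain np.max the NaN propagates into the per-sample result. A small numpy illustration:

# NaN propagation vs. NaN-aware reduction (plain numpy behaviour)
import numpy as np

similarities = np.array([[0.2, np.nan, 0.4]])
print(np.max(similarities, axis=1))     # [nan] -- NaN propagates
print(np.nanmax(similarities, axis=1))  # [0.4] -- NaN is ignored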

# Helper functions
def batch_perturb_uniform(x, lower_bound, perturb_indices: list[list[int]]):
    """Add uniform noise in [-lower_bound, lower_bound] to the given flat indices of each sample."""
    x_shape = x.shape
    x = x.reshape(x.shape[0], -1)
    noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=x.shape)
    mask = np.zeros_like(x)
    for i, perturb_index in enumerate(perturb_indices):
        mask[i, perturb_index] = 1.
    x = x + noise * mask
    x = x.reshape(*x_shape)
    return x

def batch_similarity_difference(a, b):
    return a - b

def batch_fro_norm(x):
    """Compute the Frobenius norm of each sample over all non-batch dimensions."""
    norm = np.linalg.norm(np.reshape(x, (x.shape[0], -1)), axis=1)
    return norm

def batch_perturb_baseline(x, perturb_indices: list[list[int]]):
    """Set the given flat indices of each sample to a zero baseline."""
    x_shape = x.shape
    x = x.reshape(x.shape[0], -1)
    mask = np.ones_like(x)
    for i, perturb_index in enumerate(perturb_indices):
        mask[i, perturb_index] = 0.0
    x = x * mask
    x = x.reshape(*x_shape)
    return x
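
A quick sanity check of the two perturbation helpers on a toy batch (shapes and indices are illustrative only):

# Toy example: perturb the selected flat indices of each sample
import numpy as np

x = np.arange(12, dtype=float).reshape(2, 2, 3)  # (batch, height, width)
perturb_indices = [[0, 1], [4, 5]]               # flat indices per sample

x_baseline = batch_perturb_baseline(x, perturb_indices=perturb_indices)
print(x_baseline.reshape(2, -1))
# sample 0: positions 0 and 1 set to 0; sample 1: positions 4 and 5 set to 0

x_noisy = batch_perturb_uniform(x, lower_bound=0.1, perturb_indices=perturb_indices)
print(np.abs(x_noisy - x).reshape(2, -1).max(axis=1) <= 0.1)  # [ True  True]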

Results for PixelFlipping: [timing plot comparing the current and true batch implementations]

Results for MaxSensitivity: [timing plot comparing the current and true batch implementations]
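
For completeness, the 30-repetition timings above could be reproduced with a simple harness along these lines (a sketch only; metric_call stands in for whichever metric invocation is being benchmarked):

# Hypothetical timing harness for the 30-repetition comparison
import time
import numpy as np

def time_metric(metric_call, n_repeats=30):
    # Return mean and standard deviation of wall-clock time over n_repeats calls.
    durations = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        metric_call()
        durations.append(time.perf_counter() - start)
    durations = np.asarray(durations)
    return durations.mean(), durations.std()

# Usage (with assumed metric objects and data):
# mean_old, _ = time_metric(lambda: old_metric_evaluate(x_batch, y_batch, a_batch))
# mean_new, _ = time_metric(lambda: new_metric_evaluate(x_batch, y_batch, a_batch))
# print(f"speed-up: {mean_old / mean_new:.1f}x")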

Please let me know your thoughts on this suggestion and if it would be alright for me to proceed with reimplementing the metrics.

Kind regards, Davor

davor10105 commented 5 days ago

Any thoughts on this? @annahedstroem @annariasdu