SeldonIO / alibi

Algorithms for explaining machine learning models
https://docs.seldon.io/projects/alibi/en/stable/

[Design] Counterfactual instances via Reinforcement Learning #442

Closed jklaise closed 3 years ago

jklaise commented 3 years ago

This issue is for the design and discussion of implementing the counterfactual method (referred to from here on as CF-RL) from our recent paper "Model-agnostic and Scalable Counterfactual Explanations via Reinforcement Learning".

Some potentially useful links:

RobertSamoilescu commented 3 years ago

A sketch of the CF-RL algorithm:

class CounterfactualRL(Explainer, FitMixin):
    def __init__(self,
                 enc: Union[keras.Sequential, nn.Sequential],
                 dec: Union[keras.Sequential, nn.Sequential],
                 actor: Union[keras.Sequential, nn.Sequential],
                 critic: Union[keras.Sequential, nn.Sequential]):

        # set auto-encoder components
        self.enc = enc
        self.dec = dec

        # set DDPG components
        self.actor = actor
        self.critic = critic

    def fit(self,
            X: np.ndarray,
            y: np.ndarray,
            train_steps: int,
            predict_func: Callable,
            backend_flag: str,
            update_every: int) -> "Explainer":
        """
        Fit the model-agnostic counterfactual generator.
        TODO: add other DDPG parameters
        TODO: should those parameters be passed in fit or in the constructor?

        Parameters
        ----------
        X
            Training input data array.
        y
            Training labels.
        train_steps
            Number of steps to train the counterfactual generator.
        predict_func
            Prediction function. This corresponds to the classifier.
        backend_flag
            Backend flag. Possible values: 'tf', 'pt'.
        update_every
            Number of env interactions that should elapse between gradient updates.
            Regardless of how long you wait between updates, the ratio of env steps
            to gradient steps is locked to 1.

        Returns
        -------
        self
            The explainer itself.
        """
        # Set backend according to the backend_flag.
        #  TODO: check if the packages are installed.
        Backend = TFCounterfactualRLBackend if backend_flag == "tf" else PTCounterfactualRLBackend

        # Define replay buffer (this will deal only with numpy arrays)
        replay_buff: ReplayBuff

        # Define reward function.
        reward_func: Callable

        # Define post-processing function (this should only be defined for tabular data).
        # TODO: see how to define it only for tabular datasets?
        # TODO: for the tabular setting I think we can look for categorical mapping and other specific parameters.
        # TODO: could this be user specified or is it tied to the internal mechanics?
        postprocessing_func: Callable

        # Define data generator
        data_generator: DataGenerator

        # Training loop
        for i in range(train_steps):
            # Sample batch of x, random target and conditional vector.
            x, y_T, c = next(data_generator)

            # compute model prediction.
            y_M = predict_func(x)

            # Compute embedding.
            z = self.enc(x)

            # Compute counterfactual embedding.
            z_CF = Backend.generate_cf(self.actor, z, y_M, y_T, c)

            # Add noise to the counterfactual embedding.
            z_CFt = Backend.add_noise(z_CF)

            # Decode counterfactual and apply postprocessing step to x_CF.
            # TODO: add if statement to perform conditioning only for tabular setting?
            # TODO: can we generalize to other modalities?
            x_CFt = postprocessing_func(self.dec(z_CFt).numpy(), c)

            # Compute reward. To compute reward, first we need to compute model's prediction
            # on the counterfactual generated.
            R = reward_func(predict_func(x_CFt), y_T)

            # Store experience in the replay buffer.
            replay_buff.store(x, z, y_M, y_T, c, z_CFt, R)

            if i % update_every == 0:
                for j in range(update_every):
                    # Sample batch of experience from the replay buffer.
                    z, y_M, y_T, c, z_CFt, R = replay_buff.sample()

                    # Update critic by one-step gradient descent.
                    Backend.update_critic(self.critic, z, y_M, y_T, c, z_CFt, R)

                    # Compute counterfactual embedding.
                    z_CF = Backend.generate_cf(self.actor, z, y_M, y_T, c)

                    # Compute counterfactual.
                    x_CF = self.dec(z_CF)

                    # Update actor
                    # TODO: more coeffs are needed (such as loss weights)
                    Backend.update_actor(self.actor, z, y_M, y_T, c, z_CF, x_CF, postprocessing_func, self.enc)

        return self
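
The sketch above only declares `replay_buff: ReplayBuff` and notes that it deals with NumPy arrays. A minimal illustration of what such a buffer might look like follows. The class name, the keyword-based `store` signature, and the `capacity` parameter are assumptions made for this example, not part of the proposed API, and every stored batch is assumed to have the same size.

```python
from typing import Dict, Optional

import numpy as np


class ReplayBuffer:
    """Fixed-capacity FIFO buffer of batches kept as NumPy arrays (hypothetical helper)."""

    def __init__(self, capacity: int = 1000, seed: Optional[int] = None) -> None:
        self.capacity = capacity  # maximum number of stored batches
        self.storage: Dict[str, np.ndarray] = {}
        self.batch_size = 0       # inferred from the first stored batch
        self.idx = 0              # next batch slot to write
        self.len = 0              # number of batches currently stored
        self.rng = np.random.default_rng(seed)

    def store(self, **arrays: np.ndarray) -> None:
        """Store one batch; all arrays must share the same leading dimension."""
        if not self.storage:
            # Allocate the storage lazily, once the shapes and dtypes are known.
            self.batch_size = next(iter(arrays.values())).shape[0]
            for name, arr in arrays.items():
                shape = (self.capacity * self.batch_size,) + arr.shape[1:]
                self.storage[name] = np.zeros(shape, dtype=arr.dtype)
        start = self.idx * self.batch_size
        for name, arr in arrays.items():
            self.storage[name][start:start + self.batch_size] = arr
        # Overwrite the oldest batch once the buffer is full.
        self.idx = (self.idx + 1) % self.capacity
        self.len = min(self.len + 1, self.capacity)

    def sample(self) -> Dict[str, np.ndarray]:
        """Sample `batch_size` stored rows uniformly at random, with replacement."""
        if self.len == 0:
            raise ValueError("Cannot sample from an empty replay buffer.")
        rows = self.rng.integers(0, self.len * self.batch_size, size=self.batch_size)
        return {name: buf[rows] for name, buf in self.storage.items()}
```

With a buffer like this, the `store` call in the loop would pass named arrays (for example `z=z, y_M=y_M, R=R`) and `sample` would return the same names, which keeps the stored and sampled fields from drifting apart.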

There are just a few places where the algorithm depends on the backend used. Thus, I was thinking of having two classes, one per backend, that look like:

class TFCounterfactualRLBackend(CounterfactualRLBackend):
    @staticmethod
    def generate_cf(actor: keras.Sequential,
                    z: tf.Tensor,
                    y_M: tf.Tensor,
                    y_T: tf.Tensor,
                    c: Optional[tf.Tensor]) -> tf.Tensor:
        pass

    @staticmethod
    def add_noise(z_CF: tf.Tensor) -> tf.Tensor:
        pass

    @staticmethod
    def update_critic(critic: keras.Sequential,
                      z: tf.Tensor,
                      y_M: tf.Tensor,
                      y_T: tf.Tensor,
                      c: tf.Tensor,
                      z_CFt: tf.Tensor,
                      R: tf.Tensor):
        pass

    @staticmethod
    def update_actor(actor: keras.Sequential,
                     z: tf.Tensor,
                     y_M: tf.Tensor,
                     y_T: tf.Tensor,
                     c: tf.Tensor,
                     z_CF: tf.Tensor,
                     x_CF: tf.Tensor,
                     postprocessing_func: Callable,
                     enc: keras.Sequential):
        pass

class PTCounterfactualRLBackend(CounterfactualRLBackend):
    @staticmethod
    def generate_cf(actor: nn.Sequential,
                    z: torch.Tensor,
                    y_M: torch.Tensor,
                    y_T: torch.Tensor,
                    c: Optional[torch.Tensor]) -> torch.Tensor:
        pass

    @staticmethod
    def add_noise(z_CF: torch.Tensor) -> torch.Tensor:
        pass

    @staticmethod
    def update_critic(critic: nn.Sequential,
                      z: torch.Tensor,
                      y_M: torch.Tensor,
                      y_T: torch.Tensor,
                      c: torch.Tensor,
                      z_CFt: torch.Tensor,
                      R: torch.Tensor):
        pass

    @staticmethod
    def update_actor(actor: nn.Sequential,
                     z: torch.Tensor,
                     y_M: torch.Tensor,
                     y_T: torch.Tensor,
                     c: torch.Tensor,
                     z_CF: torch.Tensor,
                     x_CF: torch.Tensor,
                     postprocessing_func: Callable,
                     enc: nn.Sequential):
        pass
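
The `fit` sketch leaves a `TODO: check if the packages are installed` next to the backend selection. One possible way to do that check without importing the heavy packages is shown below. The `BACKEND_PACKAGES` mapping and the function names are hypothetical and only serve this illustration.

```python
import importlib.util
from typing import Dict

# Hypothetical mapping from the `backend_flag` values used in the sketch to
# the package each backend would need.
BACKEND_PACKAGES: Dict[str, str] = {"tf": "tensorflow", "pt": "torch"}


def is_installed(package: str) -> bool:
    """Return True if a top-level `package` can be located without importing it."""
    return importlib.util.find_spec(package) is not None


def check_backend(backend_flag: str,
                  packages: Dict[str, str] = BACKEND_PACKAGES) -> str:
    """Validate the flag and make sure the package behind it is available."""
    if backend_flag not in packages:
        raise ValueError(
            f"Unknown backend '{backend_flag}', expected one of {sorted(packages)}."
        )
    package = packages[backend_flag]
    if not is_installed(package):
        raise ImportError(
            f"Backend '{backend_flag}' requires the '{package}' package."
        )
    return package
```

Calling `check_backend(backend_flag)` before picking `TFCounterfactualRLBackend` or `PTCounterfactualRLBackend` would turn a missing dependency into an early, readable error instead of a failure deep inside the training loop.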

jklaise commented 3 years ago

Looks like a good first step, I like that the core algorithm is backend-agnostic.

A few minor things:

Would logging calls be on the level of backends or on the level of the backend-agnostic code?
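
To make the question concrete, here is a rough sketch of one of the two options: the backend update only returns its metrics as plain floats, and the backend-agnostic loop does the logging. The function names, the logger name, and the `log_every` parameter are assumptions for illustration, and `dummy_update` stands in for a real TF or PyTorch update step.

```python
import logging
from typing import Callable, Dict, List

logger = logging.getLogger("cfrl_sketch")  # hypothetical logger name


def training_loop(train_steps: int,
                  update_step: Callable[[int], Dict[str, float]],
                  log_every: int = 10) -> List[Dict[str, float]]:
    """Backend-agnostic loop: the backend returns metrics, logging happens here."""
    history = []
    for i in range(train_steps):
        # Stand-in for the backend-specific critic/actor updates.
        losses = update_step(i)
        history.append(losses)
        if i % log_every == 0:
            metrics = " ".join(f"{k}={v:.4f}" for k, v in losses.items())
            logger.info("step=%d %s", i, metrics)
    return history


def dummy_update(step: int) -> Dict[str, float]:
    """Placeholder for a backend update that returns plain Python floats."""
    return {"critic_loss": 1.0 / (step + 1), "actor_loss": -float(step)}
```

Under this arrangement the backends stay free of logging code, at the cost of requiring every backend to convert its tensors to Python floats before returning them.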

RobertSamoilescu commented 3 years ago

jklaise commented 3 years ago

Closing as this is implemented in #457.