Closed. jklaise closed this 3 years ago.

This issue is for design and discussion of implementing the counterfactual method (from here on referred to as CF-RL) from our recent paper Model-agnostic and Scalable Counterfactual Explanations via Reinforcement Learning.

A sketch of the CF-RL algorithm:
# Imports assumed by the sketch; Explainer, FitMixin, ReplayBuff and
# DataGenerator are alibi-internal / to-be-defined components.
from typing import Callable, Optional, Union

import numpy as np
import tensorflow as tf
from tensorflow import keras
import torch
import torch.nn as nn


class CounterfactualRL(Explainer, FitMixin):
    def __init__(self,
                 enc: Union[keras.Sequential, nn.Sequential],
                 dec: Union[keras.Sequential, nn.Sequential],
                 actor: Union[keras.Sequential, nn.Sequential],
                 critic: Union[keras.Sequential, nn.Sequential]):
        # Set auto-encoder components.
        self.enc = enc
        self.dec = dec

        # Set DDPG components.
        self.actor = actor
        self.critic = critic

    def fit(self,
            X: np.ndarray,
            y: np.ndarray,
            train_steps: int,
            predict_func: Callable,
            backend_flag: str,
            update_every: int) -> "Explainer":
        """
        Fit the model-agnostic counterfactual generator.

        TODO: add other DDPG parameters.
        TODO: should those parameters be passed in fit or in the constructor?

        Parameters
        ----------
        X
            Training input data array.
        y
            Training labels.
        train_steps
            Number of steps to train the counterfactual generator.
        predict_func
            Prediction function. This corresponds to the classifier.
        backend_flag
            Backend flag. Possible values: 'tf', 'pt'.
        update_every
            Number of env interactions that should elapse between gradient updates.
            Regardless of how long you wait between updates, the ratio of env steps
            to gradient steps is locked to 1.

        Returns
        -------
        self
            The explainer itself.
        """
        # Set the backend according to backend_flag.
        # TODO: check if the packages are installed.
        Backend = TFCounterfactualRLBackend if backend_flag == "tf" else PTCounterfactualRLBackend
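        # One way to address the TODO above (a sketch): in a real implementation
        # the frameworks would be optional dependencies, so fail fast if the
        # requested backend's package is not importable.
        import importlib.util
        pkg = "tensorflow" if backend_flag == "tf" else "torch"
        if importlib.util.find_spec(pkg) is None:
            raise ImportError(f"'{pkg}' must be installed to use the '{backend_flag}' backend.")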
        # Define the replay buffer (this will deal only with numpy arrays).
        replay_buff: ReplayBuff

        # Define the reward function.
        reward_func: Callable
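        # For example (a sketch, not necessarily the paper's exact definition):
        # a sparse reward of 1 when the classifier's prediction on the
        # counterfactual matches the target class, and 0 otherwise.
        def reward_func(y_pred: np.ndarray, y_target: np.ndarray) -> np.ndarray:
            return (np.argmax(y_pred, axis=1) == np.argmax(y_target, axis=1)).astype(np.float32)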
        # Define the postprocessing function (this should only be defined for tabular data).
        # TODO: see how to define it only for tabular datasets?
        # TODO: for the tabular setting I think we can look for categorical mappings and other specific parameters.
        # TODO: could this be user-specified or is it tied to the internal mechanics?
        postprocessing_func: Callable

        # Define the data generator.
        data_generator: DataGenerator

        # Training loop.
        for i in range(train_steps):
            # Sample a batch of x, random targets and conditional vectors.
            x, y_T, c = next(data_generator)

            # Compute the model prediction.
            y_M = predict_func(x)

            # Compute the embedding.
            z = self.enc(x)

            # Compute the counterfactual embedding.
            z_CF = Backend.generate_cf(self.actor, z, y_M, y_T, c)

            # Add noise to the counterfactual embedding.
            z_CFt = Backend.add_noise(z_CF)

            # Decode the counterfactual and apply the postprocessing step.
            # TODO: add an if statement to perform conditioning only in the tabular setting?
            # TODO: can we generalize to other modalities?
            x_CFt = postprocessing_func(self.dec(z_CFt).numpy(), c)

            # Compute the reward. To do so, we first need the model's prediction
            # on the generated counterfactual.
            R = reward_func(predict_func(x_CFt), y_T)

            # Store the experience in the replay buffer (including the reward,
            # which the critic update below needs).
            replay_buff.store(x, z, y_M, y_T, c, z_CFt, R)

            if i % update_every == 0:
                for j in range(update_every):
                    # Sample a batch of experience from the replay buffer.
                    x, z, y_M, y_T, c, z_CFt, R = replay_buff.sample()

                    # Update the critic by one-step gradient descent.
                    Backend.update_critic(self.critic, z, y_M, y_T, c, z_CFt, R)

                    # Compute the counterfactual embedding.
                    z_CF = Backend.generate_cf(self.actor, z, y_M, y_T, c)

                    # Compute the counterfactual.
                    x_CF = self.dec(z_CF)

                    # Update the actor.
                    # TODO: more coefficients are needed (such as loss weights).
                    Backend.update_actor(self.actor, z, y_M, y_T, c, z_CF, x_CF, postprocessing_func, self.enc)

        return self
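Usage would then look something like this (a sketch; the autoencoder, actor/critic networks and classifier clf are assumed to be defined elsewhere):

explainer = CounterfactualRL(enc=enc, dec=dec, actor=actor, critic=critic)
explainer.fit(X_train, y_train, train_steps=100_000,
              predict_func=clf.predict_proba, backend_flag="tf", update_every=100)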
There are just a few places where the algorithm depends on the backend used. Thus, I was thinking to have two classes, one for each backend.
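The shared interface could be captured by a small abstract base class (a sketch; the method set simply mirrors the four static methods fit calls above):

from abc import ABC, abstractmethod

class CounterfactualRLBackend(ABC):
    @staticmethod
    @abstractmethod
    def generate_cf(actor, z, y_M, y_T, c): ...

    @staticmethod
    @abstractmethod
    def add_noise(z_CF): ...

    @staticmethod
    @abstractmethod
    def update_critic(critic, z, y_M, y_T, c, z_CFt, R): ...

    @staticmethod
    @abstractmethod
    def update_actor(actor, z, y_M, y_T, c, z_CF, x_CF, postprocessing_func, enc): ...

The two backend implementations would then look like: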
class TFCounterfactualRLBackend(CounterfactualRLBackend):
    @staticmethod
    def generate_cf(actor: keras.Sequential,
                    z: tf.Tensor,
                    y_M: tf.Tensor,
                    y_T: tf.Tensor,
                    c: Optional[tf.Tensor]) -> tf.Tensor:
        pass

    @staticmethod
    def add_noise(z_CF: tf.Tensor) -> tf.Tensor:
        pass

    @staticmethod
    def update_critic(critic: keras.Sequential,
                      z: tf.Tensor,
                      y_M: tf.Tensor,
                      y_T: tf.Tensor,
                      c: tf.Tensor,
                      z_CFt: tf.Tensor,
                      R: tf.Tensor):
        pass

    @staticmethod
    def update_actor(actor: keras.Sequential,
                     z: tf.Tensor,
                     y_M: tf.Tensor,
                     y_T: tf.Tensor,
                     c: tf.Tensor,
                     z_CF: tf.Tensor,
                     x_CF: tf.Tensor,
                     postprocessing_func: Callable,
                     enc: keras.Sequential):
        pass
class PTCounterfactualRLBackend(CounterfactualRLBackend):
    @staticmethod
    def generate_cf(actor: nn.Sequential,
                    z: torch.Tensor,
                    y_M: torch.Tensor,
                    y_T: torch.Tensor,
                    c: Optional[torch.Tensor]) -> torch.Tensor:
        pass

    @staticmethod
    def add_noise(z_CF: torch.Tensor) -> torch.Tensor:
        pass

    @staticmethod
    def update_critic(critic: nn.Sequential,
                      z: torch.Tensor,
                      y_M: torch.Tensor,
                      y_T: torch.Tensor,
                      c: torch.Tensor,
                      z_CFt: torch.Tensor,
                      R: torch.Tensor):
        pass

    @staticmethod
    def update_actor(actor: nn.Sequential,
                     z: torch.Tensor,
                     y_M: torch.Tensor,
                     y_T: torch.Tensor,
                     c: torch.Tensor,
                     z_CF: torch.Tensor,
                     x_CF: torch.Tensor,
                     postprocessing_func: Callable,
                     enc: nn.Sequential):
        pass
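To make the intent of these hooks concrete, add_noise in the PyTorch backend could be as simple as additive Gaussian exploration noise (a sketch; the noise scale eps is an illustrative assumption, not a value from the paper):

import torch

def add_noise(z_CF: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    # Perturb the counterfactual embedding for DDPG-style exploration.
    return z_CF + eps * torch.randn_like(z_CF)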
Looks like a good first step, I like that the core algorithm is backend-agnostic. A few minor things:

- Some of the fit arguments might live better in __init__, e.g. predictor_func and backend.
- The model types could be the more general keras.Model and nn.Module rather than the Sequential variants.
- Are the labels y necessary?
- DDPG parameters should probably be set at __init__ to some defaults and allow the user to override via kwargs (or a specific kwarg just for a DDPG parameter dictionary, if we anticipate more types of parameters for customizing the whole object?); see the sketch after this list.
- Would logging calls be on the level of the backends or on the level of the backend-agnostic code?
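For the DDPG parameters, that could look like this (a sketch; the parameter names and default values are illustrative assumptions):

DDPG_DEFAULTS = {"actor_lr": 1e-3, "critic_lr": 1e-3, "batch_size": 128, "noise_scale": 0.1}

class CounterfactualRL(Explainer, FitMixin):
    def __init__(self, enc, dec, actor, critic, **kwargs):
        # Start from defaults; the user overrides any subset via kwargs.
        self.ddpg_params = {**DDPG_DEFAULTS, **kwargs}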
The labels y are not necessary. The postprocessing function can default to lambda x: x, and we can include in alibi the postprocessing steps we used in the paper for tabular datasets.
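For instance, the non-tabular default could simply be an identity function (a minimal sketch):

def identity_postprocessing(x_CF: np.ndarray, c: Optional[np.ndarray] = None) -> np.ndarray:
    # No-op default; tabular datasets would get the paper's postprocessing instead.
    return x_CF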
Without fit-bound methods the architecture could be much simpler, e.g. look at the Alibi Detect drift detection and models packages.

Closing as this is implemented in #457.