NorskRegnesentral / shapr

Explaining the output of machine learning models with more accurately estimated Shapley values
https://norskregnesentral.github.io/shapr/

Asymmetric causal Shapley values with adaptive sampling #400

Closed. LHBO closed this 1 month ago.

LHBO commented 1 month ago

Extension of PR #395 that builds on the adaptive sampling introduced in PR #396. The latter PR completely rewrote all of the main functions in shapr, and merging #396 into #395 produced many merge conflicts, so I found it easier to start a new PR.

In this PR, we add support for computing asymmetric and/or causal Shapley values. The asymmetric version can be used with all approaches, while the causal version is limited to the Monte Carlo-based approaches. The implementation extends #273, which was adapted from the CauSHAPley package but was restricted to the gaussian approach and the old version of shapr.

Asymmetric Shapley values were proposed by Frye et al. (2020) as a way to incorporate causal knowledge about the real world. They do this by restricting the feature permutations considered when computing the Shapley values to those consistent with a (partial) causal ordering.
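The following brute-force Python sketch makes the restriction concrete. shapr itself is an R package, and none of these names belong to its API. The sketch enumerates only the permutations in which every feature of an earlier component precedes every feature of a later component, then averages the marginal contributions over those permutations. The value function `v` is a hypothetical toy game. Exhaustive enumeration is only feasible for a handful of features.

```python
from itertools import permutations, product


def consistent_permutations(ordering):
    """Yield every permutation in which all features of an earlier component
    come before all features of a later component."""
    for blocks in product(*(permutations(comp) for comp in ordering)):
        yield tuple(f for block in blocks for f in block)


def asymmetric_shapley(value, ordering):
    """Average marginal contributions over the ordering-consistent permutations."""
    features = [f for comp in ordering for f in comp]
    phi = {f: 0.0 for f in features}
    perms = list(consistent_permutations(ordering))
    for perm in perms:
        coalition = frozenset()
        for f in perm:
            phi[f] += value(coalition | {f}) - value(coalition)
            coalition = coalition | {f}
    return {f: total / len(perms) for f, total in phi.items()}


def v(S):
    # Hypothetical pure-interaction game: payoff only when features 1 and 2 are both present.
    return 1.0 if {1, 2} <= S else 0.0


print(asymmetric_shapley(v, [[1], [2]]))  # {1: 0.0, 2: 1.0}
print(asymmetric_shapley(v, [[1, 2]]))    # {1: 0.5, 2: 0.5}
```

With a single component (the symmetric case), both permutations are allowed and the interaction is split evenly. When feature 1 is ordered before feature 2, the whole interaction is attributed to the later feature.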

Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features on the prediction, taking their causal relationships into account. In shapr, this is achieved by adapting the sampling procedure.

The two ideas can be combined to obtain asymmetric causal Shapley values. See Heskes et al. (2020) for more details.

Usage (assuming N_features = 7):

(Symmetric) conditional Shapley values: asymmetric = FALSE (default), causal_ordering = list(1:7) (default), and confounding = FALSE (default).

Marginal Shapley values: either 1) the same as above, but with approach = "independence", or 2) asymmetric = FALSE (default), causal_ordering = list(1:7) (default), and confounding = TRUE.

Asymmetric conditional Shapley values with respect to a specific ordering: asymmetric = TRUE, causal_ordering = list(1, c(2, 3), 4:7), and confounding = FALSE (default).

Causal Shapley values (compute all coalitions, but use chains of sampling steps): asymmetric = FALSE (default), causal_ordering = list(1, c(2, 3), 4:7), and confounding = c(FALSE, TRUE, FALSE).

Asymmetric causal Shapley values (compute only the coalitions that respect the ordering, and use chains of sampling steps): asymmetric = TRUE, causal_ordering = list(1, c(2, 3), 4:7), and confounding = c(FALSE, TRUE, FALSE).

Main differences: The user can now specify asymmetric, causal_ordering, and confounding in the explain function. The first argument, asymmetric, specifies whether to consider all feature combinations/coalitions, or only those that respect the (partial) causal ordering given by causal_ordering. The second argument, causal_ordering, is a list specifying the (partial) causal ordering of the features (or groups). For example, causal_ordering = list(1:3, 4:5) implies that features one to three are the ancestors of features four and five. The third argument, confounding, specifies whether the user assumes that each component is subject to confounding, e.g., confounding = c(FALSE, TRUE). Note that practitioners are responsible for correctly identifying the causal structure.
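The Python sketch below shows how the three arguments fit together. It checks that causal_ordering is a partition of the features and that confounding has one flag per component. The function name, the 1-based integer feature indices, and the recycling of a single boolean to every component are assumptions for illustration, not shapr's validation code. A check like this can only catch malformed input, not a wrong causal structure.

```python
def check_causal_args(causal_ordering, confounding, n_features):
    """Hypothetical sanity check for causal_ordering and confounding.

    Returns the confounding flags as one boolean per component."""
    flat = [f for comp in causal_ordering for f in comp]
    if sorted(flat) != list(range(1, n_features + 1)):
        raise ValueError("causal_ordering must list every feature 1..n_features exactly once")
    if isinstance(confounding, bool):
        # Assumption: a single flag applies to every component.
        confounding = [confounding] * len(causal_ordering)
    if len(confounding) != len(causal_ordering):
        raise ValueError("confounding needs one flag per component of causal_ordering")
    return list(confounding)


print(check_causal_args([[1, 2, 3], [4, 5]], [False, True], 5))  # [False, True]
print(check_causal_args([list(range(1, 8))], False, 7))           # [False]
```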

When causal_ordering is not list(1:N_features), the causal structure implies that some coalitions/feature combinations do not respect the ordering. For example, in the setting above, a combination cannot condition on (include) feature four without also including all of features one to three, since they are feature four's ancestors. If asymmetric = TRUE, we only use the combinations that respect the ordering; if asymmetric = FALSE, we use all combinations. Furthermore, generating the MC samples for each coalition used introduces a chain of sampling steps, which is influenced by the confounding argument.
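The ordering constraint on coalitions can be checked directly. In this Python sketch, the helper names are hypothetical, the enumeration is brute force, and none of it is shapr code. A coalition is kept only if every component it touches has all of its ancestors included, which is the rule described above for asymmetric = TRUE.

```python
from itertools import combinations


def respects_ordering(S, ordering):
    """True if, whenever S contains a feature of some component, it also
    contains every feature of all earlier components (its ancestors)."""
    ancestors = set()
    for comp in ordering:
        if S & set(comp) and not ancestors <= S:
            return False
        ancestors |= set(comp)
    return True


def asymmetric_coalitions(ordering):
    """All coalitions (including the empty and full sets) that respect the ordering."""
    features = [f for comp in ordering for f in comp]
    return [
        frozenset(c)
        for r in range(len(features) + 1)
        for c in combinations(features, r)
        if respects_ordering(frozenset(c), ordering)
    ]


ordering = [[1, 2, 3], [4, 5]]
print(respects_ordering(frozenset({4}), ordering))           # False
print(respects_ordering(frozenset({1, 2, 3, 4}), ordering))  # True
print(len(asymmetric_coalitions(ordering)))                  # 11 (out of 2**5 = 32)
```

For causal_ordering = list(1, c(2, 3), 4:7), the same filter keeps 20 of the 128 coalitions. In that setting, asymmetric = TRUE therefore needs far fewer v(S) evaluations.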

For example, if S = {2} and confounding = c(FALSE, TRUE), we would first sample X1, X3 | X2, and then sample X4, X5 | X1, X2, X3. As explained by Heskes et al. (2020), the confounding assumption determines whether the features in S that belong to the same component are used as conditioning features. See also the examples in get_S_causal(), which demonstrate how changing the confounding assumption changes the data-generation steps.
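The chain of sampling steps can be written down without doing any sampling. The Python helper below is hypothetical and is not the get_S_causal() implementation. It follows the description above and the PR's example, in which each component is conditioned on all features of earlier components. Whether the same-component features of S are also conditioned on depends on that component's confounding flag.

```python
def causal_sampling_steps(S, ordering, confounding):
    """Return one (targets, conditioning) pair per sampling step for coalition S.

    Components are processed in causal order. Features of earlier components are
    always conditioned on (they are either in S or were sampled in an earlier
    step). Features of S in the current component are conditioned on only when
    that component is NOT assumed to be confounded."""
    S = set(S)
    steps = []
    earlier = set()
    for comp, confounded in zip(ordering, confounding):
        comp = set(comp)
        targets = comp - S
        if targets:  # nothing to sample if the whole component is in S
            cond = set(earlier) if confounded else earlier | (comp & S)
            steps.append((sorted(targets), sorted(cond)))
        earlier |= comp
    return steps


print(causal_sampling_steps({2}, [[1, 2, 3], [4, 5]], [False, True]))
# [([1, 3], [2]), ([4, 5], [1, 2, 3])]
```

Setting the first flag to TRUE instead turns the first step into sampling X1, X3 without conditioning on X2.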

To reuse most of the shapr code, we iteratively call prepare_data() with different values of S to generate the data. This introduces many redundant computations: in the first step, we generate X1, X3, X4, X5 | X2 but then throw away X4 and X5. To generate MC samples only for the relevant features, we would have to rewrite all prepare_data.approach functions to also take an Sbar argument, as they currently assume that Sbar consists of all features not in S.
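The cost of reusing prepare_data() this way can be counted per step. The Python function below is hypothetical. It takes the steps of the S = {2} example, written out by hand, and counts how many generated features are discarded because the sampler treats Sbar as every feature outside the conditioning set.

```python
def discarded_features_per_step(steps, n_features):
    """For each (targets, conditioning) step, count the features that a sampler
    assuming Sbar = "all features not in S" would generate and then throw away."""
    all_features = set(range(1, n_features + 1))
    return [len(all_features - set(cond)) - len(targets) for targets, cond in steps]


steps = [([1, 3], [2]), ([4, 5], [1, 2, 3])]  # the S = {2} example above
print(discarded_features_per_step(steps, n_features=5))  # [2, 0]
```

With n_samples MC samples per explicand, the first step therefore wastes 2 * n_samples feature draws (X4 and X5). An Sbar argument would avoid this waste.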

The independence, empirical, and ctree approaches cannot necessarily generate exactly n_samples samples; instead, they weight the samples. It is not obvious how to combine these weights in an iterative sampling process. We solve this by drawing n_samples samples with replacement, using the weights as sampling probabilities. This produces duplicates, which introduce extra computations.
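The resampling workaround amounts to weighted sampling with replacement. Here is a minimal Python sketch using the standard library's random.Random.choices. The function name and the toy rows and weights are hypothetical.

```python
import random


def resample_to_n(samples, weights, n_samples, seed=None):
    """Draw exactly n_samples rows with replacement, with probability
    proportional to each row's weight; duplicates are expected."""
    rng = random.Random(seed)
    return rng.choices(samples, weights=weights, k=n_samples)


rows = ["a", "b", "c"]   # stand-ins for weighted MC samples
w = [0.7, 0.3, 0.0]      # hypothetical weights attached to the samples
draws = resample_to_n(rows, w, n_samples=10, seed=1)
print(len(draws))        # 10
```

Rows with zero weight are never drawn. Because sampling is with replacement, the result contains repeated rows, which is the source of the extra computations.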

Plot: Additionally, we have introduced the include_group_feature_means argument (default FALSE) in plot.shapr and plot_SV_several_approaches, as some plots require a feature value, which we do not have for group-wise Shapley values. When TRUE, we use the average feature value among the features in each group (for each explicand). In plot_SV_several_approaches, we also add the index_explicands_sort argument, which determines the order of the explicands in the plots.
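A few lines of Python illustrate what include_group_feature_means = TRUE computes. The sketch uses hypothetical names and data, assumes numeric features, and is not shapr's plotting code. It averages the feature values within each group for one explicand.

```python
def group_feature_means(explicand, groups):
    """Average feature value within each group for a single explicand."""
    return {
        group: sum(explicand[f] for f in features) / len(features)
        for group, features in groups.items()
    }


x = {"A": 1.0, "B": 3.0, "C": 10.0}        # hypothetical explicand
groups = {"g1": ["A", "B"], "g2": ["C"]}   # hypothetical grouping
print(group_feature_means(x, groups))      # {'g1': 2.0, 'g2': 10.0}
```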

References:

LHBO commented 1 month ago

Ready to merge into #402 when discussed with @martinju

martinju commented 1 month ago

We're aware that the macOS test fails for the categorical approach: one observation gets a different Shapley value due to a change in the dt_vS value of v(S = {3}). Further debugging is required, but it is not straightforward, as the failure only happens on the GitHub Actions (GHA) macOS runner and not locally on a Mac. We would probably need to add a new test with keep_vS_output = TRUE to pinpoint where this occurs.