iancovert / sage

For calculating global feature importance using Shapley values.
MIT License

Use of SAGE in studying latent space dimensions #35

Closed: CCranney closed this 1 month ago

CCranney commented 1 month ago

Hi SAGE developers,

I'm reading a 2022 paper, Variational autoencoders learn transferrable representations of metabolomics data, that uses the SAGE software. I'm trying to understand how SAGE works as applied to their research, but am failing to connect on two fronts. Reading their code, I see they made some outside-the-box decisions, and I'm wondering whether those applications are valid.

This is their code, as found here:

def get_sage_mets(model, ref_data, test_data, dim):

    def get_dim_vals(dat):
        return model.encode_mu(dat)[:,dim:(dim+1)]

    dim_output = get_dim_vals(test_data)

    # Setup and calculate
    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    imputer = sage.MarginalImputer(ref_data[:10])
    sampler = sage.PermutationSampler(get_dim_vals, imputer, 'mse')
    sage_values = sampler(test_data, Y=dim_output, batch_size = 10)

    # rename dimension 0 to dimension 18
    dim_idx = 18 if dim == 0 else dim

    save_sages(sage_values, f"results/sage_values/VAE/met_dim_{dim_idx}.csv")

    return sage_values

My questions are the following:

  1. In the paper, the authors evaluate the impact of the inputs on specific dimensions of a variational autoencoder's latent space. There are 18 dimensions in the latent space, and they run SAGE on each one individually. You can see this in the get_dim_vals function above, where only the dim-th dimension of the encoder output is passed through SAGE. Note that they are not applying SAGE to the model as a whole, just to a specific dimension of the latent layer. Would this work? It was my understanding that SAGE was meant to evaluate the model as a whole (global interpretability as opposed to local interpretability). Is effectively restricting SAGE to a single dimension an appropriate application, "localizing" the SAGE metrics to specific latent dimension outputs?
  2. Another choice they made was to feed the inputs through the encoder and then use the encoded output dim_output as 'Y' in the PermutationSampler. From the SAGE documentation, it looks like 'Y' is supposed to be the ground-truth target for the model's output. No such target exists when you are evaluating a latent-space output, so I understand why a workaround was used. I'm wondering whether this decision is a valid application of SAGE and, if so, whether it changes how the resulting SAGE values should be interpreted.

I like the general approach the authors are taking and hope these adjustments are valid; latent space interpretation is an interesting topic. However, the variations are distinct enough from what I see in the documentation that I thought I'd check with the developers that this is a valid approach.

iancovert commented 1 month ago

Hi there, thanks for reaching out and happy to help think through this.

To summarize the code, it looks like what you're saying is right: the function calls SAGE on a single output of a multi-dimensional model, which I guess is the encoder from an AE/VAE.

As for the questions you brought up:

  1. I think it's a bit clunky to run SAGE separately on each output dimension, because it would be more efficient to re-use predictions across all output dims (getting many predictions with different data/subsets is the main bottleneck of SAGE); see the rough sketch after this list for what the all-dimensions version could look like. But mathematically, I don't have any qualms with this: if you did the analysis for all dimensions simultaneously, you would probably compare the predicted values $f(x)$ against the predictions with a subset of features $f(x_S)$, and the distance measure you'd probably use is the squared 2-norm $||f(x) - f(x_S)||^2$. This happens to be the sum of the squared errors across output dimensions, $\sum_i (f_i(x) - f_i(x_S))^2$, so the sum of these per-dimension SAGE values is the same as the aggregate SAGE value that considers all dimensions. It's not obvious to me that it's worth investigating the SAGE values separately for each dimension, but I don't see anything especially wrong with it. For the record, you're right that SAGE is supposed to be global, but in the paper "global" means aggregate importance across many inputs/predictions, as opposed to "local" methods (like SHAP) that explain a single input/prediction.

  2. Using the full-input latent variables $f(x)$ as the $Y$ is interesting. There's a part in the SAGE paper where we talk about what to do when you don't have a $Y$, and the basic idea is that you can still tell which features are capable of really changing your model's prediction. That intuition seems valid here: this analysis will basically find which inputs have the most influence over the latent representation. It seems reasonable to me, with the caveat that we don't know whether a given change in the latent space leads to a very different reconstruction. With that in mind, I could also imagine running SAGE on the entire AE/VAE doing input reconstruction and checking how the reconstruction error depends on each feature. But I'm not quite sure; maybe there's a good reason to care about feature importance separately for each latent dimension.
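
For what it's worth, here's a rough, untested sketch of the all-dimensions version from the first point, reusing the calling pattern from the snippet you posted (model.encode_mu, the ref_data[:10] background set, the 'mse' loss, and the batch size are all taken from that code; I'm assuming the loss treats a vector-valued Y as a sum of squared errors):

def get_sage_all_dims(model, ref_data, test_data):

    def get_latents(dat):
        # Full encoder output: every latent dimension at once
        return model.encode_mu(dat)

    latent_output = get_latents(test_data)

    # Same setup as the per-dimension version; with 'mse' the loss is the
    # squared 2-norm over the latent vector, so these SAGE values should
    # match the sum of the per-dimension values
    imputer = sage.MarginalImputer(ref_data[:10])
    sampler = sage.PermutationSampler(get_latents, imputer, 'mse')
    sage_values = sampler(test_data, Y=latent_output, batch_size=10)

    return sage_values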

CCranney commented 1 month ago

This is perfect, thank you Ian! I appreciate the detailed response and theorizing. Correct me if I'm mistaken, but it sounds like the authors colored outside the lines a bit, though never in a way that compromises the integrity of the SAGE values. I do like your suggestion of running SAGE on the entire AE/VAE to do input reconstruction - in a sense, you'd get two complementary importance metrics for each input, which would give a fuller picture of the overall process.
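
For concreteness, here's roughly how I picture that reconstruction-based run, again reusing the calling pattern from the code above; the model.decode name is my assumption about the VAE's API, so treat this as an untested sketch:

def get_sage_reconstruction(model, ref_data, test_data):

    def reconstruct(dat):
        # Full round trip through the autoencoder
        # (decode is an assumed method name)
        return model.decode(model.encode_mu(dat))

    # Use the original inputs as Y so that 'mse' measures reconstruction error
    imputer = sage.MarginalImputer(ref_data[:10])
    sampler = sage.PermutationSampler(reconstruct, imputer, 'mse')
    sage_values = sampler(test_data, Y=test_data, batch_size=10)

    return sage_values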

I'm going to go ahead and close this issue. Thanks again for your help!

iancovert commented 1 month ago

Happy to help! And yep that's right, their usage of SAGE here is a bit different than what we intended, but it seems valid to me.