keisen / tf-keras-vis

Neural network visualization toolkit for tf.keras
https://keisen.github.io/tf-keras-vis-docs/
MIT License

Reconstruction Error Score? #92

Open inesws opened 1 year ago

inesws commented 1 year ago

Hi! Thanks for this XAI visualization package!

I am trying to use the saliency method with an autoencoder and its reconstruction error. Also, my data are not images but a feature array for each observation. My goal is to check the "importance" of each input feature to the overall reconstruction, so I was trying to use the SmoothGrad method.

In this case, I do not understand what I should pass as 'score'. The output of the model is just the reconstructed input (n_samples x n_features).

I also couldn't pass a custom function as the score argument, as the documentation suggests. From the example in the repository:

Instead of using CategoricalScore object, you can also define the function from scratch as follows:

def score_function(output):
    # The output variable refers to the output of the model,
    # so, in this case, `output` shape is `(3, 1000)` i.e., (samples, classes).
    return (output[0][1], output[1][294], output[2][413])

But then, how can you pass the function to the method? The score needs to be callable, and I get "ValueError: Score object must be callable!". Could you add an example that explicitly passes score_function to a saliency() instance (or any other visualizer), instead of an instance of one of the predefined score classes (BinaryScore, CategoricalScore)?

Thank you in advance!
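A minimal plain-Python illustration of what the error message is checking, assuming it comes from passing the result of score_function(...) (a tuple of numbers) instead of the function object itself. check_score is a hypothetical stand-in that only mimics the error message; it is not tf-keras-vis code.

```python
def score_function(output):
    # One value per sample, as in the repository example.
    return (output[0][1], output[1][294], output[2][413])


def check_score(score):
    """Hypothetical stand-in for a library-side validation of `score`."""
    if not callable(score):
        raise ValueError("Score object must be callable!")
    return score


# Fake model output with shape (samples, classes) = (3, 1000).
fake_output = [[0.0] * 1000 for _ in range(3)]

check_score(score_function)  # OK: the function object itself is callable

try:
    # Wrong: this passes the tuple returned by the function, not the function.
    check_score(score_function(fake_output))
except ValueError as err:
    print(err)
```

The library calls the score on the model output itself, so the function should be handed over uncalled.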

keisen commented 1 year ago

Hi, @inesws . I'm sorry I didn't make it clear enough.

You can pass score_function directly to the Saliency instance, as below:

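from tf_keras_vis.saliency import Saliency

# `model` and `X` (a batch of 3 inputs) are assumed to be defined already.
def score_function(output):
    # One score per sample: the model output for that sample's class of interest.
    return (output[0][1], output[1][294], output[2][413])

saliency = Saliency(model)
# Pass the function object itself, not the result of calling it.
saliency_map = saliency(score_function, X)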

Thanks!
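With smooth_samples greater than zero, the same Saliency call computes SmoothGrad: the gradient is averaged over several noisy copies of the input. The NumPy sketch below shows only that averaging step, for a hypothetical linear autoencoder x_hat = A @ x, whose squared reconstruction error against the clean input x0 has the closed-form gradient 2 * A.T @ (A @ x - x0). A, noise_scale, recon_error_grad and smoothgrad are illustrative assumptions, not tf-keras-vis internals.

```python
import numpy as np


def recon_error_grad(A, x, x0):
    """Gradient of ||A @ x - x0||^2 w.r.t. x, for a linear autoencoder x_hat = A @ x.

    Works row-wise on a batch: x may have shape (n_samples, n_features).
    """
    return 2.0 * (x @ A.T - x0) @ A


def smoothgrad(A, x0, n_samples=50, noise_scale=0.2, seed=0):
    """SmoothGrad: average the gradient over noisy copies of x0, then take |.|.

    Only the input is perturbed; the reconstruction target stays the clean x0.
    """
    rng = np.random.default_rng(seed)
    sigma = noise_scale * (x0.max() - x0.min())  # noise std relative to input range
    noisy = x0 + rng.normal(0.0, sigma, size=(n_samples, x0.size))
    # Mean over the noisy copies gives one importance value per input feature.
    return np.abs(recon_error_grad(A, noisy, x0).mean(axis=0))


# Toy linear autoencoder 6 -> 2 -> 6 (random weights, for illustration only).
rng = np.random.default_rng(1)
n_features = 6
W_enc = rng.normal(size=(2, n_features))
W_dec = rng.normal(size=(n_features, 2))
A = W_dec @ W_enc
x0 = rng.normal(size=n_features)

print(smoothgrad(A, x0).round(3))
```

Because this gradient is affine in x, the noise averages out and SmoothGrad converges to the plain gradient for a linear model; the smoothing only changes the map for non-linear encoders (for example with ReLU layers). With tf-keras-vis itself the equivalent call would be something like saliency(score_function, X, smooth_samples=20, smooth_noise=0.20, keepdims=True). As far as I can tell from the Saliency docs, the default keepdims=False reduces the last (channels) axis, which for a (samples, features) input would collapse the per-feature map into one value per sample, so keepdims=True is worth checking for tabular data.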

kallydimitrova commented 1 year ago

Hi, @inesws, did you manage to figure out a custom score function for the autoencoder? I'm facing the same problem right now and haven't had any luck solving it so far.
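One possible score for the autoencoder case discussed in this thread, sketched under assumptions: the score function receives only the model output (the reconstruction), so it closes over the original inputs and returns one mean-squared reconstruction error per sample. make_reconstruction_score and the small Dense autoencoder are hypothetical stand-ins for the real model (TensorFlow 2.x assumed). The gradient is taken here with tf.GradientTape to show that the score is differentiable with respect to the input, which is what a vanilla saliency map computes.

```python
import numpy as np
import tensorflow as tf


def make_reconstruction_score(target):
    """Hypothetical helper: build a score function that closes over the clean inputs.

    The returned function maps the model output (the reconstruction, shape
    (samples, features)) to one score per sample: its mean squared error
    against `target`. Higher score means worse reconstruction.
    """
    target = tf.constant(target, dtype=tf.float32)

    def score_function(output):
        # Only TF ops here, so gradients can flow back to the model input.
        return tf.reduce_mean(tf.square(output - target), axis=-1)

    return score_function


# Toy autoencoder (assumption: stands in for the real model).
n_features = 8
inputs = tf.keras.Input(shape=(n_features,))
encoded = tf.keras.layers.Dense(3, activation="relu")(inputs)
decoded = tf.keras.layers.Dense(n_features)(encoded)
autoencoder = tf.keras.Model(inputs, decoded)

X = np.random.default_rng(0).normal(size=(4, n_features)).astype("float32")
score_function = make_reconstruction_score(X)

# Vanilla saliency by hand: d(score) / d(input). Each sample's score depends
# only on its own row, so row i of `grads` is the gradient for sample i.
x = tf.constant(X)
with tf.GradientTape() as tape:
    tape.watch(x)
    scores = score_function(autoencoder(x))
grads = tape.gradient(scores, x)
print(scores.shape, grads.shape)
```

The resulting score_function would then be passed as the first argument of a Saliency call, like the function in the maintainer's example. Two caveats of this design: it assumes every call sees the full batch in the same order as X, and with SmoothGrad only the model input is perturbed while the closure keeps comparing against the clean X, so the score measures how well the noisy inputs reconstruct the original data.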