jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/

Shapley Values on Attention Heads or Causal Edges Via Ablation #79

Open jbloomAus opened 1 year ago

jbloomAus commented 1 year ago

The basic concept is that we can randomly sample which heads we actually compute in order to see which ones matter. Shapley values are usually computed over all subsets of heads, and the total number of subsets is 2^n. With 24 heads that's 2^24 = 16,777,216 subsets, so it's clear we can't enumerate them for larger models. One question we might ask is: how accurately can we estimate the Shapley value of any given head if we only randomly subsample subsets?

Papers:

- https://arxiv.org/pdf/2210.05709.pdf
- https://arxiv.org/pdf/2205.12672.pdf

Code: https://github.com/slundberg/shap
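
For concreteness, here is a minimal sketch (not tied to this repo's code) of how Monte Carlo permutation sampling could estimate per-head Shapley values instead of enumerating all subsets. The `value_fn` callable is a placeholder for whatever metric we care about, e.g. mean episode return or action logit accuracy with only the coalition's heads active; its name and signature are assumptions for illustration.

```python
import random
from typing import Callable, Dict, FrozenSet, List, Tuple

Head = Tuple[int, int]  # (layer, head_index)


def estimate_shapley_values(
    heads: List[Head],
    value_fn: Callable[[FrozenSet[Head]], float],  # hypothetical: scores the model with only these heads active
    n_permutations: int = 200,
    seed: int = 0,
) -> Dict[Head, float]:
    """Monte Carlo Shapley estimate: average each head's marginal
    contribution over randomly sampled orderings of the heads."""
    rng = random.Random(seed)
    totals = {h: 0.0 for h in heads}

    for _ in range(n_permutations):
        order = heads[:]
        rng.shuffle(order)
        coalition: set = set()
        prev_value = value_fn(frozenset(coalition))  # value with no heads active
        for h in order:
            coalition.add(h)
            new_value = value_fn(frozenset(coalition))
            totals[h] += new_value - prev_value  # marginal contribution of h in this ordering
            prev_value = new_value

    return {h: total / n_permutations for h, total in totals.items()}
```

With 24 heads and 200 permutations this needs 200 × 25 = 5,000 evaluations of `value_fn` rather than 2^24 subset evaluations, and caching `value_fn` on previously seen coalitions would cut that further.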

GPT4: Shapley Values, derived from cooperative game theory, have found considerable application in the domain of machine learning to quantify the contribution of each feature towards the prediction of a particular instance. For context, attention is a concept that allows neural networks, particularly in NLP models like transformers, to focus on specific parts of the input when generating output. Each 'attention head' in these models can be thought of as a different part of the model that's learning to pay attention to different aspects of the input data.

Applying Shapley values to assess attention heads can be insightful: each head is assigned a score reflecting how much it contributes to the final output. When combined with ablation studies (which systematically remove, or 'ablate', parts of the model to see how performance changes), this can give a more comprehensive picture of each head's importance.
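
As a rough illustration of the ablation side, a value function like the one sketched above could be built by zero-ablating every head outside the coalition with forward hooks. The sketch below assumes a TransformerLens-style `HookedTransformer` (hook names `blocks.{layer}.attn.hook_z`, with `hook_z` shaped `[batch, pos, head_index, d_head]`) and a hypothetical `score_fn` mapping logits to a scalar metric; both are assumptions for illustration, not this repo's actual interface.

```python
from functools import partial
from typing import Callable, FrozenSet, Tuple

import torch

Head = Tuple[int, int]  # (layer, head_index)


def make_ablation_value_fn(
    model,                 # assumed: a TransformerLens-style HookedTransformer
    tokens: torch.Tensor,  # input batch to evaluate on
    score_fn: Callable[[torch.Tensor], float],  # hypothetical: maps logits to a scalar metric
    n_layers: int,
) -> Callable[[FrozenSet[Head]], float]:
    """Return a value function that scores the model with only the heads
    in the coalition left intact; all other heads are zero-ablated."""

    def zero_absent_heads(z, hook, *, layer: int, coalition: FrozenSet[Head]):
        # z: [batch, pos, head_index, d_head] -- per-head attention outputs
        for head_index in range(z.shape[2]):
            if (layer, head_index) not in coalition:
                z[:, :, head_index, :] = 0.0
        return z

    def value_fn(coalition: FrozenSet[Head]) -> float:
        fwd_hooks = [
            (
                f"blocks.{layer}.attn.hook_z",
                partial(zero_absent_heads, layer=layer, coalition=coalition),
            )
            for layer in range(n_layers)
        ]
        with torch.no_grad():
            logits = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
        return score_fn(logits)

    return value_fn
```

Zero-ablation is used here only because it is simple; mean-ablation or resampling could be swapped in by replacing the `z[:, :, head_index, :] = 0.0` line.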

However, a few points should be kept in mind:

  1. Computational complexity: Computing Shapley values can be computationally expensive, especially for large models with many attention heads. This might not be practical in all scenarios.

  2. Interpretability: While Shapley values can provide some level of interpretability, it's important to remember that deep learning models, especially those with attention mechanisms, are complex. Even if an attention head has a high Shapley value, it might not directly relate to human-understandable reasoning.

  3. Different aspects of attention: Attention heads often pay attention to different aspects of the data. An attention head with a low Shapley value isn't necessarily unimportant: it might capture a less frequent, but still important, feature whose contribution gets diluted when averaged over the dataset.

  4. Ablation complexity: While ablation can show the importance of an attention head, removing an attention head can cause the model to change how it uses its remaining heads, leading to potential misinterpretations about the original role of the removed head.

Overall, while Shapley values and ablation studies can be helpful tools for understanding the importance of attention heads, they should be used with care and in conjunction with other interpretability methods to get a comprehensive understanding of the model's behavior.