rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

LLaMA family issues #10

Open dvdblk opened 1 month ago

dvdblk commented 1 month ago

Hello again :) I have an issue with how the relevancy scores are computed for some LLaMA models for sequence classification.

I have a classification task that I am using in the following screenshots. All models in the screenshots correctly predicted the class, the same attnlrp rules were applied (lxt 0.6.0), the models were trained on the same data splits, and the example input is the same.

Everything from the LRP side works well with a finetuned SciBERT model. The model is finetuned without quantization, and whether I use lora or not, the relevancy scores are always faithful and look very plausible (good news)!

[screenshot: SciBERT relevance heatmap]

Roughly the same thing happens for TinyLLaMA (with lora or without lora, no quantization), although the punctuation marks (usually a dot .) seem to have more relevance than in the SciBERT case.

[screenshot: TinyLLaMA relevance heatmap]

However, as soon as I try AttnLRP with LLaMA-2 or LLaMA-3, I get very bad faithfulness scores, and mostly the first and last tokens get high relevancy scores.

[screenshot: LLaMA-2 relevance heatmap]

[screenshot: LLaMA-3 relevance heatmap]

Could you please provide some clues or leads on why this is happening? Considering the issue is less prominent in TinyLLaMA, is the depth of the original LLaMA models the problem? I.e. do the LRP relevance values lose meaning after so many layers? Not sure where the issue is. Cheers.

rachtibat commented 1 month ago

Hey,

this is awesome and why we built this toolbox! I'm happy to see it applied in the wild!

Why does a heatmap not look the way I expect? Attribution methods try to trace the model's reasoning. So, if a heatmap looks unexpected, it can be due to two things:

  1. the attribution method has low faithfulness or is wrongly implemented
  2. the model has bad performance / does unexpected things!

Why does AttnLRP highlight punctuation? I think this might be due to the phenomenon of attention sinks https://huggingface.co/blog/tomaarsen/attention-sinks#attention-sinks and also scratch pads https://arxiv.org/abs/2112.00114

I am also not sure how you applied LRP to the lora adapter. There, you also have some additions, matrix multiplications etc., and these must be explained with LXT too. So, I would start by "LRPfying" lora first. In addition, you should make sure that quantization does not replace the linear layers with something else, because then LXT doesn't work anymore and this must be handled too, like in the Mixtral example.
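For the quantization point, a quick sanity check could be to list every leaf module that is not a standard PyTorch layer, so replaced or quantized linear layers become visible. This is just a minimal sketch, not part of LXT, and report_nonstandard_layers is a hypothetical helper name:

import torch.nn as nn

def report_nonstandard_layers(model):
    # print every leaf module whose type is not a plain PyTorch layer,
    # so quantized or otherwise replaced linear layers become visible
    standard = (nn.Linear, nn.Embedding, nn.LayerNorm, nn.Dropout, nn.SiLU, nn.GELU)
    for name, module in model.named_modules():
        if len(list(module.children())) == 0 and not isinstance(module, standard):
            print(f"{name}: {type(module).__name__}")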

Hope it helps and we can find the problem

dvdblk commented 1 month ago

thanks for the prompt response :D some more info:

  1. llama3, llama2, tinyllama, scibert, in order of performance (best to worst F1 and accuracy, both metrics in the 70-80% range) on the classification task, i.e. the better performing models have lower faithfulness
  2. I have LRPfied lora and tested it with the scibert model. It produces almost the same relevancy scores with or without lora. However, there is actually no need to implement lora in the LRP framework, as you can just merge the lora weights back into the original model, getting rid of the lora layers before applying AttnLRP, like so:
from peft import PeftModel

# model is loaded with AutoModelForSequenceClassification.from_pretrained and has lora layers;
# attnlrp is the AttnLRP rule composite imported from lxt for the corresponding architecture
peft_model = PeftModel.from_pretrained(model, model_id=model_name)
model = peft_model.merge_and_unload()  # folds the lora weights into the base linear layers
attnlrp.register(model)
  1. "quantization does not replace linear layers" I have checked that this is not the case, as I am loading the finetuned models without quantization for inference / getting relevancy scores. Moreover, llama2 didn't have any quantization during training.

I have seen that in your paper you used LLaMA-2 7B on the IMDB dataset. Did you fine-tune LlamaForSequenceClassification without lora? Have you noticed behavior similar to what I see on my classification task?

Besides attention sinks or scratch pads, I could look into individual layers to see where it goes "wrong".
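Maybe something along these lines (a rough sketch: model.model.layers is the usual layer list for LLaMA-style models, the forward/backward pass is the usual AttnLRP one, and whether the retained gradients behave as expected under the lxt rules is an assumption on my side):

import torch

# retain the gradient of each decoder layer's output so that, after the AttnLRP
# backward pass, a per-layer relevance map over the tokens can be inspected
captured = {}

def make_hook(idx):
    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden.retain_grad()
        captured[idx] = hidden
    return hook

handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(model.model.layers)]

# ... run the usual forward pass and the backward pass from the explained logit here ...

for idx, hidden in sorted(captured.items()):
    layer_relevance = hidden.grad.float().sum(-1)[0]   # one score per token at this layer
    print(idx, layer_relevance.abs().topk(5).indices.tolist())

for h in handles:
    h.remove()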

Tomsawyerhu commented 1 month ago

Could you please provide the test case?

rachtibat commented 1 month ago

Hey,

Thanks for the in-depth response! You mention that AttnLRP has low faithfulness in the finetuned llama 3 case. How did you assess faithfulness: is it because the heatmap does not match up with our human intuition, or did you perform token perturbation experiments? I think tokens such as "water" are still highlighted correctly, only the start and end tokens get a lot of relevance. I wonder, if you remove the start and end tokens from the heatmap, whether it will look "normal" again.
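Something like this quick check (a minimal sketch; it assumes relevance is a 1-D tensor with one score per input token and tokens holds the decoded tokens, as in the quickstart example):

# drop the first and last token from the heatmap and renormalize, to see whether
# the remaining relevances look "normal" without the start/end token contribution
inner_relevance = relevance[1:-1]
inner_tokens = tokens[1:-1]
inner_relevance = inner_relevance / inner_relevance.abs().max()   # rescale for plotting

for tok, r in zip(inner_tokens, inner_relevance.tolist()):
    print(f"{tok}\t{r:+.3f}")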

We could try some things:

  1. Comparing against other attribution scores to see if they also highlight the start and end tokens. Since the context size of your sample is quite small, you could benchmark against perturbation-based methods like https://captum.ai/tutorials/Llama2_LLM_Attribution (see the sketch after this list).
  2. In my paper, we only finetuned the last linear layer of LlamaForSequenceClassification for llama 2. This will probably result in lower faithfulness than lora, but you could also try it out!
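For the first point, if the Captum tutorial is inconvenient for a classification head, even a plain leave-one-token-out occlusion would be a useful sanity check. This is a minimal sketch with hypothetical names, not the Captum API:

import torch

@torch.no_grad()
def occlusion_scores(model, tokenizer, text, replace_token_id):
    # leave-one-token-out occlusion: replace each input token in turn and measure
    # how much the logit of the originally predicted class drops
    enc = tokenizer(text, return_tensors="pt")
    input_ids = enc["input_ids"]
    base_logits = model(**enc).logits
    target = base_logits[0].argmax()
    base_score = base_logits[0, target]

    scores = []
    for i in range(input_ids.shape[1]):
        occluded = input_ids.clone()
        occluded[0, i] = replace_token_id   # e.g. the pad or unk token id
        logits = model(input_ids=occluded, attention_mask=enc["attention_mask"]).logits
        scores.append((base_score - logits[0, target]).item())
    return scores  # higher = more important for the prediction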

Good to know that the lora weights are merged into the linear layers, so we can exclude a possible wrong LXT implementation here (assuming you do in fact merge the weights and did not attempt to "LRPfy" lora).

(Maybe posting a small test script here would be nice, to check whether lora introduces some non-LRP-compatible operations, if that makes sense and is possible.)

rachtibat commented 1 month ago
We could also try to explain the softmax output. You could simply add:

output = lxt.functional.softmax(logits, temperature=2)

and try some different temperature scaling values.
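Roughly like this (a minimal sketch: variable names are placeholders, attnlrp is assumed to be registered on the model already, and the exact lxt.functional.softmax signature should be checked against the docs):

import lxt.functional

# explain the softmax probability of the predicted class instead of the raw logit
input_embeds = model.get_input_embeddings()(input_ids)
logits = model(inputs_embeds=input_embeds.requires_grad_()).logits   # shape (1, num_classes)

probs = lxt.functional.softmax(logits, temperature=2)   # try different temperature values
pred_prob = probs[0, probs[0].argmax()]

pred_prob.backward(pred_prob)                     # seed the backward pass with the explained output
relevance = input_embeds.grad.float().sum(-1)[0]  # one relevancy score per input token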