rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io
Other
101 stars, 12 forks

Making Llama 3 quantized work #14

Closed: aymeric-roucher closed this issue 2 weeks ago

aymeric-roucher commented 1 month ago

Do you have an example of working with a quantized Llama 3? I'm trying the following:

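```python
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
# LXT's patched Llama model and its AttnLRP composite
from lxt.models.llama import LlamaForCausalLM, attnlrp

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # note: not a documented BitsAndBytesConfig option (only bnb_4bit_compute_dtype exists)
    bnb_8bit_compute_dtype=torch.bfloat16,
)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="cuda",
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
attnlrp.register(model)  # attach the AttnLRP rules to the model's layers
```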

But then I get NaNs for the relevances, whereas the TinyLlama model from the docs works just fine.

rachtibat commented 1 month ago

Hi @aymeric-roucher,

I think the TinyLlama model would not work correctly in its quantized version either. The reason is that BitsAndBytes replaces every nn.Linear layer with a Linear8bitLt layer, so the composite must explicitly state that the EpsilonRule applies to this layer type, by adding the entry Linear8bitLt: rules.EpsilonRule, as done in the Mixtral composite at https://github.com/rachtibat/LRP-eXplains-Transformers/blob/66221eddb7fb932e299f906a261feab8f1b9581e/lxt/models/mixtral.py#L1245
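
The EpsilonRule mentioned above refers to LRP's ε-rule for linear layers: with pre-activations z_j = Σ_i w_ji x_i + b_j, each output relevance R_j is redistributed to the inputs as R_i = Σ_j x_i w_ji / (z_j + ε·sign(z_j)) · R_j, where the small ε keeps the denominator away from zero. The following is a plain-Python sketch of that textbook formula only; it is not lxt's implementation, and the function name `lrp_epsilon_linear` and its signature are hypothetical.

```python
def lrp_epsilon_linear(x, weight, bias, relevance_out, eps=1e-6):
    """Textbook LRP-epsilon rule for y = W x + b (illustrative, not lxt's code).

    x:             list of n input activations
    weight:        list of m rows, each a list of n weights (weight[j][i] = w_ji)
    bias:          list of m biases
    relevance_out: list of m output relevances R_j
    Returns the list of n input relevances R_i.
    """
    n = len(x)
    relevance_in = [0.0] * n
    for j, row in enumerate(weight):
        # pre-activation of output neuron j
        z_j = sum(w * xi for w, xi in zip(row, x)) + bias[j]
        # stabilizer: push the denominator away from zero (z_j == 0 is treated as positive)
        denom = z_j + (eps if z_j >= 0 else -eps)
        for i in range(n):
            # share of R_j attributed to input i, proportional to its contribution x_i * w_ji
            relevance_in[i] += x[i] * row[i] / denom * relevance_out[j]
    return relevance_in


if __name__ == "__main__":
    R_in = lrp_epsilon_linear(
        x=[1.0, 2.0],
        weight=[[0.5, 0.25], [1.0, 1.0]],
        bias=[0.0, 0.0],
        relevance_out=[1.0, 3.0],
    )
    print(R_in, sum(R_in))
```

With zero bias, the input relevances come out at roughly [1.5, 2.5], summing to about 4.0, i.e. the total output relevance 1.0 + 3.0 is approximately conserved (up to the tiny ε term).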

Unfortunately, in lxt/models/llama.py I manually replaced all nn.Linear layers with lm.LinearEpsilon, as in https://github.com/rachtibat/LRP-eXplains-Transformers/blob/66221eddb7fb932e299f906a261feab8f1b9581e/lxt/models/llama.py#L245C26-L245C42

So, to make it work:

  1. Replace every lm.LinearEpsilon in llama.py with nn.Linear.
  2. Add Linear8bitLt: rules.EpsilonRule to the attnlrp composite at https://github.com/rachtibat/LRP-eXplains-Transformers/blob/66221eddb7fb932e299f906a261feab8f1b9581e/lxt/models/llama.py#L63
  3. Also add nn.Linear: rules.EpsilonRule to the attnlrp composite, so that the unquantized model remains supported.
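
Steps 2 and 3 amount to making the composite's layer-to-rule mapping cover both layer types. The sketch below illustrates this with hypothetical plain-Python stand-ins (`Linear`, `Linear8bitLt`, `RMSNorm`, `EpsilonRule` and `register` are placeholders, not the torch, bitsandbytes or lxt objects), and it assumes the composite matches rules on a module's exact type. That matching strategy is an assumption for illustration, not something stated in this thread.

```python
# Hypothetical stand-ins, for illustration only.
class Linear:
    """Stand-in for torch.nn.Linear."""


class Linear8bitLt(Linear):
    """Stand-in for bitsandbytes' Linear8bitLt, which subclasses nn.Linear."""


class RMSNorm:
    """Stand-in for a normalization layer that has no entry in the map."""


class EpsilonRule:
    """Stand-in for lxt's rules.EpsilonRule; just remembers the wrapped module."""

    def __init__(self, module):
        self.module = module


def register(modules, layer_map):
    """Attach a rule to each module, matching on EXACT type (assumed strategy).

    modules:   {name: module instance}
    layer_map: {module class: rule class}
    Returns {name: rule instance, or None if no entry matched}.
    """
    attached = {}
    for name, module in modules.items():
        # dict lookup on type(): a subclass does NOT inherit its parent's entry
        rule_cls = layer_map.get(type(module))
        attached[name] = rule_cls(module) if rule_cls is not None else None
    return attached


only_linear = {Linear: EpsilonRule}                                 # step 3 alone
linear_and_8bit = {Linear: EpsilonRule, Linear8bitLt: EpsilonRule}  # steps 2 + 3

quantized = {"q_proj": Linear8bitLt(), "norm": RMSNorm()}
unquantized = {"q_proj": Linear(), "norm": RMSNorm()}

if __name__ == "__main__":
    for label, layer_map in [("Linear only", only_linear),
                             ("Linear + Linear8bitLt", linear_and_8bit)]:
        rules_q = register(quantized, layer_map)
        rules_u = register(unquantized, layer_map)
        print(label,
              "| quantized q_proj:", type(rules_q["q_proj"]).__name__,
              "| unquantized q_proj:", type(rules_u["q_proj"]).__name__)
```

Under the exact-type assumption, the quantized `q_proj` gets no rule (`NoneType`) when only `Linear` is mapped and gets an `EpsilonRule` once `Linear8bitLt` is added, while the unquantized model is covered in both cases. With an `isinstance`-based lookup the `Linear` entry alone would already match the subclass, so the sketch only illustrates the exact-type case.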

I'm on vacation right now, but when I'm back in two weeks I can push an updated version.

I hope this helps,
Reduan

rachtibat commented 2 weeks ago

Hey @aymeric-roucher,

You can now find a quantized Llama example at https://github.com/rachtibat/LRP-eXplains-Transformers/tree/main/examples for the new release https://github.com/rachtibat/LRP-eXplains-Transformers/releases/tag/v0.6.1.

Have fun with it!

aymeric-roucher commented 2 weeks ago

This is great, thanks @rachtibat!