rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

Classification tasks example #4

Closed dvdblk closed 2 months ago

dvdblk commented 2 months ago

Hello, could you please provide a simple example of how to use LlamaForSequenceClassification for classification tasks?

I am using a finetuned llama-3 (LlamaForSequenceClassification) to classify some text into 17 classes. However, using the TinyLLaMA code in the documentation doesn't seem to provide any meaningful relevance scores despite the prediction being correct.

rachtibat commented 2 months ago

Hey @dvdblk,

I can show you an example for the BERT model that @pkhdipraja just added. For that, please pull the newest version of LXT from GitHub and install it via `pip install -e ./lxt`.

The "textattack/bert-base-uncased-CoLA" model is trained to predict, whether a sentence is grammatically correct or wrong. The sentence 'I are a student' is wrong because of the 'are'. You will see this in the heatmap.

```python
import torch
from transformers import AutoTokenizer
from lxt.models.bert import attnlrp, BertForSequenceClassification
from lxt.utils import pdf_heatmap, clean_tokens

def clean_wordpiece_split(tokens):
    """BERT-specific cleaning of WordPiece tokens. Workaround, not working perfectly yet."""
    return ["▁" + word.replace("##", "") for word in tokens]

def seq_cls():
    """AttnLRP for BERT sequence classification task."""
    tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-CoLA")
    model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-CoLA").to(torch.device("cuda"))
    model.eval()

    # apply AttnLRP rules
    attnlrp.register(model)

    inputs = "I are a student."

    input_ids = tokenizer(inputs, return_tensors="pt").input_ids.to(torch.device("cuda"))
    inputs_embeds = model.bert.get_input_embeddings()(input_ids)

    logits = model(inputs_embeds=inputs_embeds.requires_grad_()).logits

    # We explain the sequence label: acceptable or unacceptable
    max_logits, max_indices = torch.max(logits, dim=-1)

    out = model.config.id2label[max_indices.item()]
    print("The label of the sequence is: ", out)

    # backpropagate the winning logit; with the AttnLRP rules registered,
    # the gradient of the input embeddings carries the relevance
    max_logits.backward(max_logits)

    # sum over the embedding dimension to get one relevance value per token
    relevance = inputs_embeds.grad.float().sum(-1).cpu()[0]
    # normalize relevance between [-1, 1] for plotting
    relevance = relevance / relevance.abs().max()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    tokens = clean_tokens(clean_wordpiece_split(tokens))

    pdf_heatmap(tokens, relevance, path="./heatmap_seq_cls.pdf", backend="xelatex")
```

EDITED: changed the import line to `from lxt.models.bert import attnlrp, BertForSequenceClassification`.
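
As a side note: if you want to explain a particular class instead of the predicted one, you can backpropagate that class's logit in the same way (a small sketch; the class index 3 is just a placeholder):

```python
# hypothetical: explain class index 3 instead of the argmax prediction
target_logit = logits[0, 3]
target_logit.backward(target_logit)
# then read out inputs_embeds.grad exactly as in the example above
```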

fallcat commented 2 months ago

Hi! I'm trying to use this example, but I ran into a bug. It says `LayerNorm.__init__()` takes at most 6 positional arguments but was given 7:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_2184/47476406.py in <module>
----> 1 seq_cls()

/tmp/ipykernel_2184/3838064053.py in seq_cls()
     13     """AttnLRP for BERT sequence classification task."""
     14     tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-CoLA")
---> 15     model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-CoLA").to(torch.device("cuda"))
     16     model.eval()
     17 

/opt/conda/envs/rapids/lib/python3.10/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3083 
   3084         with ContextManagers(init_contexts):
-> 3085             model = cls(config, *model_args, **model_kwargs)
   3086 
   3087         # Check first if we are `from_pt`

/shared_data0/weiqiuy/github/LRP-eXplains-Transformers/lxt/models/bert.py in __init__(self, config)
   1526         self.config = config
   1527 
-> 1528         self.bert = BertModel(config)
   1529         classifier_dropout = (
   1530             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob

/shared_data0/weiqiuy/github/LRP-eXplains-Transformers/lxt/models/bert.py in __init__(self, config, add_pooling_layer)
    897         self.config = config
    898 
--> 899         self.embeddings = BertEmbeddings(config)
    900         self.encoder = BertEncoder(config)
    901 

/shared_data0/weiqiuy/github/LRP-eXplains-Transformers/lxt/models/bert.py in __init__(self, config)
    202         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
    203         # any TensorFlow checkpoint file
--> 204         self.LayerNorm = lm.LayerNormEpsilon(config.hidden_size, eps=config.layer_norm_eps)
    205         self.dropout = nn.Dropout(config.hidden_dropout_prob)
    206         # position_ids (1, len position emb) is contiguous in memory and exported when serialized

/shared_data0/weiqiuy/github/LRP-eXplains-Transformers/lxt/modules.py in __init__(self, normalized_shape, eps, elementwise_affine, bias, device, dtype)
     45 
     46     def __init__(self, normalized_shape, eps: float = 0.00001, elementwise_affine: bool = True, bias: bool = True, device=None, dtype=None):
---> 47         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
     48 
     49     def forward(self, x):

TypeError: LayerNorm.__init__() takes from 2 to 6 positional arguments but 7 were given
```

Also, I made one small change to the import line, `from lxt.models.bert import attnlrp, BertForSequenceClassification`, to match the actual directory.

dvdblk commented 2 months ago

FYI, the example from @rachtibat works for me (with the import modification that @fallcat mentioned), thank you.

Regarding LLaMA: I downgraded my transformers version to 4.34.1 to match the reference implementation for TinyLlama, and I can confirm that AttnLRP relevance scores for the 17-class classification task work with LlamaForSequenceClassification on LLaMA-2 models.
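
For anyone else looking for this, here is a minimal sketch of the adaptation (assuming `lxt.models.llama` exposes `LlamaForSequenceClassification` and `attnlrp` analogously to the BERT module; the checkpoint path is a placeholder):

```python
# sketch only: "path/to/finetuned-llama" is a placeholder for your own checkpoint
import torch
from transformers import AutoTokenizer
from lxt.models.llama import attnlrp, LlamaForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-llama")
model = LlamaForSequenceClassification.from_pretrained(
    "path/to/finetuned-llama", torch_dtype=torch.bfloat16
).to("cuda")
model.eval()

# apply AttnLRP rules
attnlrp.register(model)

input_ids = tokenizer("some text to classify", return_tensors="pt").input_ids.to("cuda")
inputs_embeds = model.get_input_embeddings()(input_ids)

logits = model(inputs_embeds=inputs_embeds.requires_grad_()).logits
max_logits, max_indices = torch.max(logits, dim=-1)
max_logits.backward(max_logits)

# one relevance score per input token, normalized to [-1, 1]
relevance = inputs_embeds.grad.float().sum(-1).cpu()[0]
relevance = relevance / relevance.abs().max()
```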

However, I would still like to get it working with the latest transformers version so that it also works with LLaMA-3 models. What would you recommend, @rachtibat, as the best way to keep the library updated for future transformers versions while staying backwards compatible? Updating each lxt/models/<model>.py for a specific transformers version seems a bit tedious :(

EDIT: Upon further inspection, I think my issue with using a finetuned llama-3 (as opposed to a finetuned tinyllama) might be related to using LoRA for training, which adds its own Linear modules to the model and thus bypasses the attnlrp rules.
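
If that turns out to be the cause, one possible workaround (a sketch, assuming the adapter was trained with the peft library and the paths below are placeholders) is to merge the LoRA weights back into the base Linear modules before registering the attnlrp rules:

```python
# sketch only: merge a peft/LoRA adapter into the base model so that only
# plain nn.Linear layers remain, which the attnlrp rules can handle
import torch
from peft import PeftModel
from lxt.models.llama import attnlrp, LlamaForSequenceClassification

base = LlamaForSequenceClassification.from_pretrained("path/to/base-llama", num_labels=17)
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# fold the LoRA matrices into the base weights and drop the adapter wrappers
model = model.merge_and_unload()
model.eval()

attnlrp.register(model)
```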

fallcat commented 2 months ago

@dvdblk thank you for the confirmation. I just realized that my PyTorch version was 2.0.1; after updating to 2.3.0, it works.

dvdblk commented 2 months ago

Closing as the original issue is solved 👍

rachtibat commented 2 months ago

Thanks @fallcat! Sorry @dvdblk, I copied my script into GitHub too hastily (: I edited my initial comment here, so my BERT example now works correctly.

Regarding @dvdblk's question: I actually got Llama 3 working for LlamaForCausalLM with transformers==4.36.1 and torch==2.1.0. Yes, the problem is that Hugging Face is constantly changing the modeling_*.py files.

I am working on a way to trace all model operations with torch.fx and replace them on the fly, as explained in this issue: https://github.com/rachtibat/LRP-eXplains-Transformers/issues/3#issuecomment-2162778579. I am quite optimistic that it will work out. The challenge right now is enabling activation/gradient checkpointing for torch.fx-traced models. If you have any ideas or suggestions on how we could automatically replace autograd functions or stay up to date with Hugging Face, feel free to chat (:
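
To sketch the general mechanism (a toy illustration only, not the actual implementation; `my_softmax` just stands in for an LRP-aware replacement):

```python
import torch
import torch.nn.functional as F
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return F.softmax(x, dim=-1)

def my_softmax(x, dim=-1):
    # stand-in for an LRP-aware autograd function
    return F.softmax(x, dim=dim)

# trace the model into a graph and swap the softmax call on the fly
gm = fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is F.softmax:
        node.target = my_softmax
gm.recompile()

print(gm(torch.randn(2, 4)))
```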