cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

Error Using LLama-2 with Fine-Tuned LoRA Adapters: Tensor Size Mismatch in apply_rotary_pos_emb Function #147

Open montygole opened 5 months ago

montygole commented 5 months ago

I encountered a runtime error while using the transformers-interpret library with a fine-tuned LLama-2 model that includes LoRA adapters for sequence classification. The error occurs when invoking the SequenceClassificationExplainer and seems related to tensor size mismatches during the rotary positional embedding application.

Traceback (most recent call last):
  File "/home/input_attr_proj/src/input_attr.py", line 32, in <module>
    word_attributions = cls_explainer("Hello")
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/explainers/text/sequence_classification.py", line 316, in __call__
    return self._run(text, index, class_name, embedding_type=embedding_type)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/explainers/text/sequence_classification.py", line 270, in _run
    self._calculate_attributions(embeddings=embeddings, index=index, class_name=class_name)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/explainers/text/sequence_classification.py", line 226, in _calculate_attributions
    lig = LIGAttributions(
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/attributions.py", line 51, in __init__
    self._attributions, self.delta = self.lig.attribute(
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/log/__init__.py", line 42, in wrapper
    return func(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py", line 390, in attribute
    baselines_layer = _forward_layer_eval(
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/_utils/gradient.py", line 182, in _forward_layer_eval
    return _forward_layer_eval_with_neuron_grads(
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/_utils/gradient.py", line 445, in _forward_layer_eval_with_neuron_grads
    saved_layer = _forward_layer_distributed_eval(
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/_utils/gradient.py", line 294, in _forward_layer_distributed_eval
    output = _run_forward(
  File "/home/input_attr_env/lib/python3.10/site-packages/captum/_utils/common.py", line 531, in _run_forward
    output = forward_func(
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/explainers/text/sequence_classification.py", line 181, in _forward
    preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers_interpret/explainer.py", line 197, in _get_preds
    preds = self.model(
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1352, in forward
    transformer_outputs = self.model(
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 968, in forward
    layer_outputs = decoder_layer(
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 713, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 624, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  File "/home/input_attr_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 182, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 2

Code sample:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer

id2label = {0: "No", 1: "Yes"}
label2id = {"No": 0, "Yes": 1}
model = AutoModelForSequenceClassification.from_pretrained(
    "outputs/2024-04-21/04-27-20/outputs/checkpoint-2564/",
    device_map="auto",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
# NOTE: model_name is not defined anywhere in the snippet as posted.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The Llama-2 tokenizer has no pad token by default, so reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

cls_explainer = SequenceClassificationExplainer(model,tokenizer)
word_attributions = cls_explainer("Hello")
print(word_attributions)

Additional Context:

The error seems to occur in the apply_rotary_pos_emb function, indicating a tensor size mismatch. This might be due to the integration of LoRA adapters with the LLama-2 model. Any help to resolve this issue or guidance on proper compatibility would be greatly appreciated.
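
For readers tracing the error message, here is a minimal sketch in plain Python of the broadcasting rule that the failing line `q * cos` appears to violate. The shapes are assumptions chosen for illustration (query states laid out as batch, heads, sequence, head_dim with 3 positions, and a cos table built for 2 positions), and the helper name `broadcast_conflicts` is hypothetical; it is not part of torch or transformers.

```python
def broadcast_conflicts(shape_a, shape_b):
    """List (dim, size_a, size_b) for every axis where two shapes cannot broadcast.

    Two sizes are compatible when they are equal or when either one is 1.
    Shapes are right-aligned first, as in NumPy/PyTorch broadcasting.
    """
    rank = max(len(shape_a), len(shape_b))
    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    return [
        (dim, size_a, size_b)
        for dim, (size_a, size_b) in enumerate(zip(a, b))
        if size_a != size_b and size_a != 1 and size_b != 1
    ]


# Assumed shapes for illustration only: (batch, heads, seq_len, head_dim).
q_shape = (1, 32, 3, 128)    # query states covering 3 token positions
cos_shape = (1, 1, 2, 128)   # rotary cos table covering 2 positions

conflicts = broadcast_conflicts(q_shape, cos_shape)
print(conflicts)
```

Under these assumed shapes the only conflict is on dimension 2, the sequence axis, with sizes 3 and 2, which matches the RuntimeError. If that reading is right, the position information reaching the attention layer covers one fewer token than the input IDs do, which may be worth checking independently of the LoRA adapters.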

nicolas-richet commented 3 months ago

Hi @montygole, did you manage to make it work? I am also trying to use the transformers-interpret library with a fine-tuned Llama-2 model for sequence classification.

montygole commented 3 months ago

Hi @nicolas-richet. No, I didn't get this library to work. Instead, I used captum directly. This is the code I used for layer integrated gradients:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-2-13b-hf")  # load model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")  # load tokenizer

# Set pad token as end of sentence token for tokenizer and model
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id 

# Put model in evaluation mode (disables layer dropout so we can get reproducible input attribution)
model.eval()

# Ref token id is the baseline token to produce integrated gradient against (basically an informationless token, this could also be 0 instead of pad_token_id)
ref_token_id = tokenizer.pad_token_id

def predict(input_ids):
    outputs = model(input_ids)
    return outputs.logits.max(1).values

def construct_input_ref_pair(text, ref_token_id):
    # Tokenize the text and build a same-shaped baseline filled with the reference token
    input_ids = tokenizer.encode(text, return_tensors="pt")
    ref_input_ids = torch.full_like(input_ids, ref_token_id)
    return input_ids, ref_input_ids

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

# Create instance of LIG object. Use model.model.embed_tokens to take LIG from the embeddings layer
lig = LayerIntegratedGradients(predict, model.model.embed_tokens)

# Prepare visualization of results
viz_records = []

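# your_input_sequence (the text to explain) and true_class (its gold label) are placeholders to define before running
model_input = tokenizer(your_input_sequence, return_tensors="pt")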

# Inference for classification model for analysis
with torch.no_grad():
    logits = model(**model_input).logits

predicted_class_id = logits.argmax().item()
output = model.config.id2label[predicted_class_id]
input_ids, ref_input_ids = construct_input_ref_pair(your_input_sequence, ref_token_id)

attributions_start, delta_start = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    n_steps=50,
    return_convergence_delta=True,
    attribute_to_layer_input=False,  # assumed value: the posted code passed an undefined variable; False is captum's default
)

attributions_start_sum = summarize_attributions(attributions_start)

vis_record = viz.VisualizationDataRecord(
    attributions_start_sum,
    torch.max(torch.softmax(logits[0], dim=0)),
    predicted_class_id,
    true_class,
    predicted_class_id,
    attributions_start_sum.sum(),       
    tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist()),
    delta_start
)
viz_records.append(vis_record)
# results (a list) and row (the current dataset record) come from a surrounding loop that was not posted
results.append({
    "true_label": row["y"],
    "pred_label": predicted_class_id,
    "pred_probs": torch.max(torch.softmax(logits[0], dim=0)),
    "attr_score": attributions_start_sum.sum(),
    "raw_input_ids": tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist()),
    "word_attrs": attributions_start_sum,
    "convergence_score": delta_start
})

attribution_viz = viz.visualize_text(viz_records)

html_str = attribution_viz.data
with open("attr_vis.html", "w") as file:
    file.write(html_str)

nicolas-richet commented 3 months ago

Hi @montygole, thank you for the code! I was able to run it, but the attribution score of the first 'begin_of_text' token is much higher than the rest. Did you know of or run into this problem? I tried setting the attention mask of the first token to 0, but I'm not sure this is a correct approach.
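
One display-level option, sketched here in plain Python under stated assumptions, is to leave the model inputs untouched and instead exclude special tokens when normalizing the per-token scores, so that a dominant BOS score does not flatten everything else. The helper name `renormalize_without` and the example scores are hypothetical. This only changes how scores are scaled for reading; it does not change what the model computes.

```python
import math


def renormalize_without(tokens, scores, ignored):
    """Zero the scores of ignored tokens and L2-normalize the remaining scores."""
    kept = [0.0 if tok in ignored else s for tok, s in zip(tokens, scores)]
    norm = math.sqrt(sum(s * s for s in kept))
    if norm == 0.0:
        # Nothing left to normalize (every token was ignored or scored zero).
        return kept
    return [s / norm for s in kept]


# Made-up per-token attribution scores, with the BOS token dominating.
tokens = ["<s>", "▁Hello", "▁world"]
scores = [0.95, 0.3, -0.4]

display_scores = renormalize_without(tokens, scores, ignored={"<s>", "</s>"})
print(list(zip(tokens, display_scores)))
```

The same idea applies to a torch tensor of summed attributions by masking the special-token positions before dividing by the norm.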

montygole commented 3 months ago

Hey @nicolas-richet. I actually had the same problem. I experimented with different n_steps values (1, 50, 100, and 200) and with removing the BOS token and EOS sequences, but the attributions remained unintuitive and difficult to interpret: they either credited the classification to newline characters, the BOS token, and/or the EOS delimiter I used, or produced unpredictable scores.
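
To make the role of n_steps concrete, here is a small self-contained sketch of integrated gradients on a toy function, not on the Llama model. Everything in it is an assumption for illustration: the function `f` stands in for a class logit, the gradient is written by hand, and the integral is approximated with a plain right Riemann sum (Captum supports several approximation methods). The convergence delta is the gap between the summed attributions and f(input) - f(baseline).

```python
def f(x):
    """Toy scalar model standing in for a class logit: f(x) = x0^2 * x1."""
    return x[0] ** 2 * x[1]


def grad_f(x):
    """Analytic gradient of f with respect to x0 and x1."""
    return [2 * x[0] * x[1], x[0] ** 2]


def integrated_gradients(x, baseline, n_steps):
    """Right Riemann approximation of integrated gradients along a straight path."""
    attributions = [0.0] * len(x)
    for step in range(1, n_steps + 1):
        alpha = step / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)
        for i in range(len(x)):
            attributions[i] += g[i] * (x[i] - baseline[i]) / n_steps
    return attributions


def convergence_delta(x, baseline, attributions):
    """Sum of attributions minus the actual change in output (0 when exact)."""
    return sum(attributions) - (f(x) - f(baseline))


x, baseline = [2.0, 3.0], [0.0, 0.0]
deltas = {}
for n_steps in (1, 50, 200):
    attrs = integrated_gradients(x, baseline, n_steps)
    deltas[n_steps] = convergence_delta(x, baseline, attrs)
    print(n_steps, attrs, deltas[n_steps])
```

In this toy case the delta shrinks as n_steps grows and a single step is a poor approximation. A small delta only indicates that the integral has converged; it does not by itself make the resulting attributions intuitive, which is consistent with the experience described above.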

I think it makes sense that it would attribute to these seemingly meaningless tokens and sequences, because they are present in every training sequence and are important for the model to process inputs properly.

Anyway, I have since switched to the shap library for input attribution, which gives me much more intuitive results. I'd recommend checking it out! Here's a classification example with transformers: https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/sentiment_analysis/Positive%20vs.%20Negative%20Sentiment%20Classification.html

Let me know how it goes! Good luck 😃