Open montygole opened 5 months ago
Hi @montygole, did you manage to get it to work? I am also trying to use the transformers-interpret library with a fine-tuned Llama-2 model for sequence classification.
Hi @nicolas-richet. No, I didn't get this library to work. Instead I used captum. This is the code I used for layer integrated gradients:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-2-13b-hf") #load model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf") # load tokenizer
# Set pad token as end of sentence token for tokenizer and model
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Put model in evaluation mode (disables layer dropout so we can get reproducible input attribution)
model.eval()
# Ref token id is the baseline token to produce integrated gradient against (basically an informationless token, this could also be 0 instead of pad_token_id)
ref_token_id = tokenizer.pad_token_id
def predict(input_ids):
    outputs = model(input_ids)
    return outputs.logits.max(1).values
def construct_input_ref_pair(text, ref_token_id):
    # Tokenize the input; shape is (1, seq_len)
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # Baseline of the same shape, filled with the reference token id
    ref_input_ids = torch.full_like(input_ids, ref_token_id)
    return input_ids, ref_input_ids
def summarize_attributions(attributions):
    # Sum over the embedding dimension to get one score per token, then L2-normalize
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
# Create instance of LIG object. Use model.model.embed_tokens to take LIG from the embeddings layer
lig = LayerIntegratedGradients(predict, model.model.embed_tokens)
# Prepare visualization of results
viz_records = []
model_input = tokenizer(your_input_sequence, return_tensors="pt")
# Inference for classification model for analysis
with torch.no_grad():
    logits = model(**model_input).logits
predicted_class_id = logits.argmax().item()
output = model.config.id2label[predicted_class_id]
input_ids, ref_input_ids = construct_input_ref_pair(your_input_sequence, ref_token_id)
attributions_start, delta_start = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    n_steps=50,
    return_convergence_delta=True,
    # The original snippet passed an undefined variable here; False is captum's default
    attribute_to_layer_input=False,
)
attributions_start_sum = summarize_attributions(attributions_start)
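# true_class is the gold label for this example and must be defined by the caller
vis_record = viz.VisualizationDataRecord(
    attributions_start_sum,                      # per-token attributions
    torch.max(torch.softmax(logits[0], dim=0)),  # predicted probability
    predicted_class_id,                          # predicted class
    true_class,                                  # true class
    predicted_class_id,                          # class the attributions refer to
    attributions_start_sum.sum(),                # total attribution score
    tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist()),
    delta_start,                                 # convergence delta
)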
viz_records.append(vis_record)
# `results` (a list) and `row` (the current dataset row) come from a surrounding loop not shown here
results.append({
    "true_label": row["y"],
    "pred_label": predicted_class_id,
    "pred_pros": torch.max(torch.softmax(logits[0], dim=0)),
    "attr_score": attributions_start_sum.sum(),
    "raw_input_ids": tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist()),
    "word_attrs": attributions_start_sum,
    "convergence_score": delta_start
})
attribution_viz = viz.visualize_text(viz_records)
html_str = attribution_viz.data
with open("attr_vis.html", "w") as file:
    file.write(html_str)
Hi @montygole, thank you for the code! I was able to run it, but the attribution score of the first 'begin_of_text' token seems to be much higher than the rest. Have you run into this problem? I tried setting the attention mask of the first token to 0, but I'm not sure this is a correct approach.
Hey @nicolas-richet. I actually had the same problem. I experimented with different n_steps values (1, 50, 100, and 200) and also tried removing the BOS token and the EOS sequences, but the attributions remained unintuitive and difficult to interpret: they either credited the classification to newline characters, the BOS token, and/or the EOS delimiter I used, or produced unpredictable attribution scores.
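For anyone wondering what n_steps actually controls: it is the number of points used to approximate the path integral from the baseline to the input, and the convergence delta measures how far the summed attributions are from f(input) - f(baseline). Below is a minimal pure-Python sketch on a toy function. The function `f`, its gradient, and the midpoint rule are all assumptions chosen for illustration; this is not captum's implementation, which uses its own integration rule.

```python
import math

def f(v):
    # Toy stand-in for a model's scalar output (not a real classifier)
    return math.tanh(2.0 * v[0] + v[1] ** 2)

def grad_f(v):
    # Analytic gradient of the toy function f
    s = 1.0 - math.tanh(2.0 * v[0] + v[1] ** 2) ** 2
    return [2.0 * s, 2.0 * v[1] * s]

def integrated_gradients(x, baseline, n_steps):
    """Midpoint-rule approximation of integrated gradients for f."""
    total_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # position along the straight-line path
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            total_grad[i] += g
    # Attribution_i = (x_i - baseline_i) * average gradient_i along the path
    attributions = [
        (xi - b) * tg / n_steps for xi, b, tg in zip(x, baseline, total_grad)
    ]
    # Convergence delta: gap between summed attributions and the output change
    delta = sum(attributions) - (f(x) - f(baseline))
    return attributions, delta

x, baseline = [1.0, 0.5], [0.0, 0.0]
for n_steps in (1, 50, 200):
    attrs, delta = integrated_gradients(x, baseline, n_steps)
    print(n_steps, [round(a, 4) for a in attrs], f"delta={delta:+.2e}")
```

On this toy function the delta shrinks as n_steps grows. A small delta only says the integral was approximated well; it does not by itself make the attributions easier to interpret, which is consistent with what was observed above.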
I think it makes sense that it would attribute the classification to these seemingly meaningless tokens and sequences, because they are present in every training sequence and the model needs them to process inputs properly.
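If the goal is only to compare the content tokens with each other, one possible workaround is to zero out the special-token positions after summarizing and re-normalize what is left. This is a display-side sketch only: the helper name `drop_special_tokens`, the tokens, and the scores are all made up for illustration, it does not change what the model relies on, and the rescaled scores no longer sum to the model's output change.

```python
import math

def drop_special_tokens(tokens, scores, special_tokens):
    """Zero the scores at special-token positions, then L2-normalize the rest."""
    kept = [0.0 if tok in special_tokens else s for tok, s in zip(tokens, scores)]
    norm = math.sqrt(sum(s * s for s in kept))
    if norm == 0.0:
        # Nothing left to normalize (all tokens were special or scored zero)
        return kept
    return [s / norm for s in kept]

# Made-up tokens and per-token scores; the BOS token dominates, as described above
tokens = ["<s>", "The", "movie", "was", "great", "</s>"]
scores = [0.95, 0.02, 0.10, -0.01, 0.28, 0.05]

rescaled = drop_special_tokens(tokens, scores, {"<s>", "</s>"})
for tok, s in zip(tokens, rescaled):
    print(f"{tok:>6}  {s:+.3f}")
```

With the real code, the token strings would come from tokenizer.convert_ids_to_tokens and the special-token set from the tokenizer's own special tokens.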
Anyways, I have since used the shap library for my input attribution, which gives me much more intuitive results. I'd recommend checking it out! Here's a classification example with transformers: https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/sentiment_analysis/Positive%20vs.%20Negative%20Sentiment%20Classification.html
Let me know how it goes! Good luck 😃
I encountered a runtime error while using the transformers-interpret library with a fine-tuned Llama-2 model that includes LoRA adapters for sequence classification. The error occurs when invoking the SequenceClassificationExplainer and seems to be related to a tensor size mismatch when the rotary positional embedding is applied.
Code sample:
Additional Context:
The error seems to occur in the apply_rotary_pos_emb function and indicates a tensor size mismatch. This might be due to the integration of LoRA adapters with the Llama-2 model. Any help resolving this issue, or guidance on compatibility, would be greatly appreciated.