inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0

Forced generations with decoder-only models must start with input texts #278

Closed · acDante closed this issue 4 months ago

acDante commented 5 months ago

Question

Hi @gsarti, I ran into this issue when using tokenizer.apply_chat_template() to build the prompt before computing attributions. How can I use the model.attribute() function when the input prompt starts with special tokens?

Here is my code snippet:

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import inseq

# Load the model and tokenizer
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build a summarisation prompt from an XSum test document
test_data = load_dataset("xsum", split="test")
doc = test_data[1]["document"]
input_prompt = f"Summarise the document below: {doc}"
messages = [{"role": "user", "content": input_prompt}]

# Apply the chat template (this prepends the special tokens)
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Load the model into inseq and attribute the generation
attr_type = "attention"
inseq_model = inseq.load_model(model, attr_type)
out = inseq_model.attribute(prompt, generation_args={"do_sample": False, "max_new_tokens": 100, "skip_special_tokens": False})
out.show()

The stack trace:

AssertionError                            Traceback (most recent call last)
<ipython-input-1-52b094a48219> in <cell line: 29>()
     27 attr_type="attention"
     28 inseq_model = inseq.load_model(model, attr_type)
---> 29 out = inseq_model.attribute(prompt, generation_args={"do_sample": False, "max_new_tokens": 100, "skip_special_tokens": False})
     30 out.show()

/usr/local/lib/python3.10/dist-packages/inseq/models/attribution_model.py in attribute(self, input_texts, generated_texts, method, override_default_attribution, attr_pos_start, attr_pos_end, show_progress, pretty_progress, output_step_attributions, attribute_target, step_scores, include_eos_baseline, attributed_fn, device, batch_size, generate_from_target_prefix, skip_special_tokens, generation_args, **kwargs)
    445         logger.debug(f"reference_texts={generated_texts}")
    446         if not self.is_encoder_decoder:
--> 447             assert all(
    448                 generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts))
    449             ), "Forced generations with decoder-only models must start with the input texts."

AssertionError: Forced generations with decoder-only models must start with the input texts.

Additional context

I checked this thread: https://github.com/inseq-team/inseq/issues/271 and tried adding skip_special_tokens=False in inseq_model.attribute(), but the issue still occurs.
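To make the failure concrete, here is a minimal sketch of the check reported in the stack trace, using hypothetical strings rather than the actual XSum output: my understanding is that apply_chat_template(tokenize=False) returns the prompt with the BOS token prepended, while the generated text it is compared against appears to be decoded without it, so the startswith check fails.

# Hypothetical strings illustrating the failing check from the stack trace above.
# apply_chat_template(tokenize=False) prepends "<s>" for Mistral-Instruct, but the
# decoded generation used in the comparison does not seem to retain that token.
input_text = "<s>[INST] Summarise the document below: ... [/INST]"
generated_text = "[INST] Summarise the document below: ... [/INST] The report says ..."

# inseq asserts generated_text.startswith(input_text) for decoder-only models;
# here it evaluates to False, which raises the AssertionError shown above.
print(generated_text.startswith(input_text))  # False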


gsarti commented 5 months ago

Hi @acDante,

The problem was due to apply_chat_template adding the BOS token, which was not being handled in model.attribute even though skip_special_tokens was specified. You can check out the fix-multidevice branch, as shown in my last comment on #276, and it should work now!
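For anyone hitting the same error, a rough sketch of how the fix might be tried follows; the install command and the exact placement of skip_special_tokens are assumptions on my part rather than statements from the thread (skip_special_tokens does appear as a direct parameter of attribute() in the signature shown in the stack trace). It reuses model and prompt from the snippet above.

# Assumption: install inseq from the fix-multidevice branch first, e.g.
#   pip install git+https://github.com/inseq-team/inseq.git@fix-multidevice
import inseq

inseq_model = inseq.load_model(model, "attention")

# Pass skip_special_tokens directly to attribute() instead of nesting it in
# generation_args, so the special tokens added by the chat template are not
# stripped from the texts being compared (per the suggestion in #271).
out = inseq_model.attribute(
    prompt,
    skip_special_tokens=False,
    generation_args={"do_sample": False, "max_new_tokens": 100},
)
out.show()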