inseq-team / inseq

Interpretability for sequence generation models
https://inseq.org
Apache License 2.0

Show only part of the input when calling show() method #281

Closed: RibinMTC closed this issue 4 months ago

RibinMTC commented 5 months ago

Question

Hi, thank you for the awesome library. How can I show only part of my input prompt in the heatmap? For example, my prompt has the structure "instruction - {input_text}", and I want to ignore the scores for the instruction and show the heatmap only for input_text. I have used the following code:

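import inseq
from inseq.data.aggregator import SubwordAggregator
from transformers import AutoModelForCausalLM

model_path = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, device_map="auto")
attrib_model = inseq.load_model(
    model=model,
    attribution_method="attention",
)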
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic."""

full_output_prompt = input_prompt + "Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."

out = attrib_model.attribute(
    input_texts=input_prompt,
    generated_texts=full_output_prompt
)
subw_sqa_agg = out.aggregate(SubwordAggregator, special_chars="▁").aggregate()
subw_viz = subw_sqa_agg.show(return_html=True, do_aggregation=False)

gsarti commented 5 months ago

Hi @RibinMTC,

Thanks for reaching out! This was a notable missing option in the current version of Inseq, so I added a PR (#282) to introduce a SliceAggregator class to handle this behavior. You can try it out (pip install git+https://github.com/inseq-team/inseq.git@viz-slice) and let me know whether it addresses your concern:

import inseq

attrib_model = inseq.load_model("google/gemma-2b", "attention")
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.
Summary:"""

full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."
out = attrib_model.attribute(input_prompt, full_output_prompt)[0]

# Slice to the token range of interest (13:71, the input text here) -> aggregate subwords -> default attention aggregation (mean head, mean layer) + show
out[13:71].aggregate("subwords").show()
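
The bounds 13:71 above are token indices, so they change with the prompt and the tokenizer. As a minimal sketch (not part of Inseq), the hypothetical helper `char_span_to_token_span` below maps a character span of the prompt to a token index range, assuming per-token character offsets such as those a Hugging Face fast tokenizer returns with `return_offsets_mapping=True`. The whitespace splitting is only a stand-in for a real tokenizer so that the example runs on its own.

```python
from typing import List, Tuple


def char_span_to_token_span(
    offsets: List[Tuple[int, int]], start_char: int, end_char: int
) -> Tuple[int, int]:
    """Return (first, last + 1) indices of tokens overlapping [start_char, end_char).

    Hypothetical helper for illustration; `offsets` holds one (start, end)
    character pair per token.
    """
    hits = [
        i
        for i, (tok_start, tok_end) in enumerate(offsets)
        # Skip zero-width entries, e.g. special tokens reported as (0, 0).
        if tok_end > tok_start and tok_start < end_char and tok_end > start_char
    ]
    if not hits:
        raise ValueError("no token overlaps the requested character span")
    return hits[0], hits[-1] + 1


# Toy demonstration: whitespace "tokens" stand in for real tokenizer output.
prompt = "Instruction: Summarize this.\nInput_text: A tree whispered.\nSummary:"
toy_offsets = [(0, 0)]  # pretend BOS token with an empty character span
pos = 0
for word in prompt.split():
    start = prompt.index(word, pos)
    toy_offsets.append((start, start + len(word)))
    pos = start + len(word)

text = "A tree whispered."
start_char = prompt.index(text)
token_span = char_span_to_token_span(toy_offsets, start_char, start_char + len(text))
print(token_span)
```

With a real model, the offsets would instead come from `tokenizer(prompt, return_offsets_mapping=True)["offset_mapping"]`, and the resulting pair could be used as `out[start:end]`. Whether these indices line up exactly with the tokens Inseq displays (for example around the BOS token) is an assumption worth checking against the heatmap labels.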

RibinMTC commented 4 months ago

This was exactly what I was looking for, thank you very much :)