inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0

Issue with value_zeroing attribution and SubwordAggregator() #291

Open rafikg opened 23 hours ago

rafikg commented 23 hours ago

Question

Here is a minimal reproducible example (MRE):

import time

import inseq
from inseq.data.aggregator import SubwordAggregator
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Load the translation model and configure the language pair (en -> fr).
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
tokenizer.src_lang = "en"
tokenizer.tgt_lang = "fr"

# Wrap the model with inseq using the value_zeroing attribution method.
model = inseq.load_model(
    model=model,
    tokenizer=tokenizer,
    attribution_method="value_zeroing",
)

# Attribute a fixed translation to the source tokens only, timing the call.
start = time.time()
out = model.attribute(
    input_texts="Life is like a box of chocolates.",
    generated_texts="La vie est comme une boite de chocolats.",
    attribute_target=False,
    show_progress=True,
)
end = time.time()

# out.show()  # this works
agg = SubwordAggregator()
agg_out = agg.aggregate(attr=out[0].sequence_attributions)
agg_out.show(do_aggregation=True)  # the error is raised here

Error:

image

value_zeroing: source_length X target_length X n_heads

image

Attention

image

gsarti commented 3 hours ago

Hi @rafikg, thank you for reporting this, good catch! I just opened PR #292, which should fix this issue with the aggregation step. Could you check out that branch and verify that it works for you?

rafikg commented 2 hours ago

@gsarti Thanks, it is working on my side.

One more question: I annotated some translated sentences by highlighting the error spans.

I want to calculate the importance of each translated token to see whether tokens with higher importance correspond to the error spans.

In the example above, I compute the contribution of each source token to the generation of each translated token. I am not sure how to leverage this to get the importance of each translated token.
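To make the question concrete, here is a minimal sketch with made-up numbers (not real inseq output, and not an inseq API). It assumes the attribution has already been reduced to a plain source x target score matrix, and it shows two naive ways of collapsing the source axis into one number per translated token. The helper name `summarize_per_target` and the toy scores are hypothetical; whether either reduction is a meaningful "importance" for error-span detection is exactly what I am unsure about.

```python
# Hypothetical toy data: rows are source tokens, columns are translated tokens.
source_tokens = ["Life", "is", "like"]
target_tokens = ["La", "vie", "est", "comme"]
scores = [
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.6, 0.2],
    [0.2, 0.1, 0.1, 0.5],
]


def summarize_per_target(scores):
    """Collapse a source x target matrix to one (total, peak) value per target token."""
    columns = list(zip(*scores))  # transpose: one tuple of source scores per target token
    totals = [sum(col) for col in columns]  # total source contribution per target token
    peaks = [max(col) for col in columns]   # strongest single source contribution
    return totals, peaks


totals, peaks = summarize_per_target(scores)
for token, total, peak in zip(target_tokens, totals, peaks):
    print(f"{token}: total={total:.2f} peak={peak:.2f}")
```

Both reductions only describe how strongly each translated token depends on the source, so a low value might flag a token that is weakly grounded in the input rather than one that is "important" in itself.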