jessevig / bertviz

BertViz: Visualize Attention in NLP Models (BERT, GPT2, BART, etc.)
https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1
Apache License 2.0
6.99k stars · 786 forks

There is no result or figure output when running model_view of BART. #110

Closed Junpliu closed 1 year ago

Junpliu commented 2 years ago

I didn't see any figure when running the code below. Is there something I missed? Help me, please.

from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view

utils.logging.set_verbosity_error()  # Remove line to see warnings

# Initialize tokenizer and model. Be sure to set output_attentions=True.
# Load BART fine-tuned for summarization on CNN/Daily Mail dataset
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# get encoded input vectors
encoder_input_ids = tokenizer(utterances, return_tensors="pt", add_special_tokens=True).input_ids

# create ids of encoded input vectors
decoder_input_ids = tokenizer("Jane made a 9 PM reservation for 6 people tonight at Vegano Resto .", return_tensors="pt", add_special_tokens=True).input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)
Junpliu commented 2 years ago

utterances = ' '.join(["Jane: Hello", "Vegano Resto: Hello, how may I help you today?", "Jane: I would like to make a reservation.", "Jane: For 6 people, tonight around 20:00", "Vegano Resto: Let me just check.", "Vegano Resto: Ah, I'm afraid that there is no room at 20:00.", "Vegano Resto: However, I could offer you a table for six at 18:30 or at 21:00", "Vegano Resto: Would either of those times suit you?", "Jane: Oh dear.", "Jane: Let me just ask my friends.", "Vegano Resto: No problem.", "Jane: 21:00 will be ok.", "Vegano Resto: Perfect. So tonight at 21:00 for six people under your name.", "Jane: great, thank you!"])
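For reference, a quick way to see why this input is problematic is to count how many tokens the full dialogue produces. This is a sketch that loads only the tokenizer (not the model) for the same "facebook/bart-large-cnn" checkpoint used above:

```python
from transformers import AutoTokenizer

# Same checkpoint as in the snippet above; only the tokenizer is needed here.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

utterances = ' '.join([
    "Jane: Hello",
    "Vegano Resto: Hello, how may I help you today?",
    "Jane: I would like to make a reservation.",
    "Jane: For 6 people, tonight around 20:00",
    "Vegano Resto: Let me just check.",
    "Vegano Resto: Ah, I'm afraid that there is no room at 20:00.",
    "Vegano Resto: However, I could offer you a table for six at 18:30 or at 21:00",
    "Vegano Resto: Would either of those times suit you?",
    "Jane: Oh dear.",
    "Jane: Let me just ask my friends.",
    "Vegano Resto: No problem.",
    "Jane: 21:00 will be ok.",
    "Vegano Resto: Perfect. So tonight at 21:00 for six people under your name.",
    "Jane: great, thank you!",
])

# Count the encoder-side tokens (including special tokens).
num_tokens = len(tokenizer(utterances).input_ids)
print(num_tokens)  # the full dialogue is far longer than a short test input
```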

Junpliu commented 2 years ago

I ran the code and the program just crashed. However, the attention weights were shown as expected: [screenshot of the attention visualization]

jessevig commented 2 years ago

Hi @Junpliu, the visualization may fail for longer inputs like the one in your example. See: https://github.com/jessevig/bertviz#%EF%B8%8F-limitations In a future version I will add a warning message for these cases.

I was able to get this to work with a shorter input as a test. Does it work for you?

from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view

utils.logging.set_verbosity_error()  # Remove line to see warnings

# Initialize tokenizer and model. Be sure to set output_attentions=True.
# Load BART fine-tuned for summarization on CNN/Daily Mail dataset
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# get encoded input vectors
utterances = "test"
encoder_input_ids = tokenizer(utterances, return_tensors="pt", add_special_tokens=True).input_ids

# create ids of encoded input vectors
decoder_input_ids = tokenizer("Jane made a 9 PM reservation for 6 people tonight at Vegano Resto .", return_tensors="pt", add_special_tokens=True).input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)
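If the full dialogue is needed, one possible workaround (a sketch, not something the maintainer confirmed) is to cap the encoder input with the tokenizer's built-in truncation, so the attention matrices that model_view has to render stay small:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

# Truncate the encoder input to a manageable length. max_length=64 is an
# arbitrary choice for illustration, not a documented BertViz limit.
utterances = "Jane: Hello Vegano Resto: Hello, how may I help you today?"
encoder_input_ids = tokenizer(
    utterances, return_tensors="pt", truncation=True, max_length=64
).input_ids

# These ids can then be fed to the model exactly as in the snippets above.
print(encoder_input_ids.shape)
```

With the truncated ids, the rest of the snippet (the forward pass and the model_view call) is unchanged.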