jessevig / bertviz

BertViz: Visualize Attention in NLP Models (BERT, GPT2, BART, etc.)
https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1
Apache License 2.0

Takes too much time to run the model_view() visualization #121

Open hungkien05 opened 1 year ago

hungkien05 commented 1 year ago

Hi,

My model uses an encoder (CodeBERT) / decoder (GPT-2) architecture to summarize a code snippet. The inference code is as follows:

from transformers import AutoTokenizer
from bertviz import model_view

# `model` (the CodeBERT-encoder / GPT-2-decoder model) and `input_max_length`
# are defined earlier and omitted here.
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', model_max_length=input_max_length)

code_snippet = "private int currentDepth ( ) { try { Integer oneBased = ( ( Integer ) DEPTH_FIELD . get ( this ) ) ; return oneBased - _NUM ; } catch ( IllegalAccessException e ) { throw new AssertionError ( e ) ; } }"
decoder_input = tokenizer("returns a 0 - based depth within the object graph of the current object being serialized .",
                          max_length=input_max_length, truncation=True,
                          padding='max_length', return_tensors='pt')

encoder_input = tokenizer(code_snippet,
                          max_length=input_max_length, truncation=True,
                          padding='max_length', return_tensors='pt')
outputs = model(input_ids=encoder_input['input_ids'],
                decoder_input_ids=decoder_input['input_ids'],
                attention_mask=encoder_input['attention_mask'],
                output_attentions=True)
encoder_text = tokenizer.convert_ids_to_tokens(encoder_input['input_ids'][0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input['input_ids'][0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)

It has been running for 2 hours and model_view() still has not finished. I suspect the attention size might be the problem, so I printed the size of the attention matrix for one encoder layer (there are 12 encoder layers in total):

print(outputs.encoder_attentions[1].size())  # torch.Size([1, 12, 320, 320])
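
For what it's worth, a likely culprit is padding='max_length': both sequences get padded out to 320 tokens, so model_view has to render 320×320 maps for every head of every layer, most of which is padding. A minimal sketch of the same pipeline without forced padding, assuming the same tokenizer, model, code_snippet, and input_max_length as above (names are illustrative, not BertViz requirements):

# Sketch: drop padding='max_length' so attention covers only real tokens.
# Assumes `tokenizer`, `model`, `code_snippet`, and `input_max_length`
# are the ones defined above.
encoder_input = tokenizer(code_snippet, truncation=True,
                          max_length=input_max_length, return_tensors='pt')
decoder_input = tokenizer("returns a 0 - based depth within the object graph of the current object being serialized .",
                          truncation=True, max_length=input_max_length,
                          return_tensors='pt')
outputs = model(input_ids=encoder_input['input_ids'],
                decoder_input_ids=decoder_input['input_ids'],
                attention_mask=encoder_input['attention_mask'],
                output_attentions=True)
# Each layer's attention is now (1, num_heads, seq_len, seq_len), where
# seq_len is the real token count rather than the padded length of 320.
print(outputs.encoder_attentions[1].size())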

Can you help me with this issue? Thanks!

jden4524 commented 1 week ago

A similar problem happened to me: I tried a T5 model, and after a long time the visualization still had not finished.