jalammar / ecco

Explain, analyze, and visualize NLP language models. Ecco creates interactive visualizations directly in Jupyter notebooks explaining the behavior of Transformer-based language models (like GPT2, BERT, RoBERTa, T5, and T0).
https://ecco.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Why do we need to construct `one-hot` vectors of the input_ids and then multiply by the embeddings, as opposed to applying the embedding directly? #43

Open vgoklani opened 2 years ago

vgoklani commented 2 years ago

Hey there,

Thanks for releasing this library! I was reviewing your lm.py file, and in particular, it was unclear to me why you construct one-hot vectors and multiply them by the embedding matrix, as opposed to simply applying the embedding directly.

See here:

https://github.com/jalammar/ecco/blob/main/src/ecco/lm.py#L118
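The question hinges on the fact that a one-hot row multiplied by an embedding matrix selects the same row that an `nn.Embedding` lookup returns. A minimal sketch of that equivalence, assuming toy sizes and random weights rather than anything taken from ecco's `lm.py`:

```python
import torch

torch.manual_seed(0)

vocab_size, dim = 10, 4  # toy sizes (assumed, not DistilBERT's)
embedding = torch.nn.Embedding(vocab_size, dim)
input_ids = torch.tensor([[1, 5, 7]])  # shape (batch=1, seq_len=3)

# Direct lookup: rows of embedding.weight indexed by the token ids.
direct = embedding(input_ids)

# One-hot route: (1, 3, vocab_size) @ (vocab_size, dim) -> (1, 3, dim).
one_hot = torch.nn.functional.one_hot(input_ids, num_classes=vocab_size).float()
via_matmul = one_hot @ embedding.weight

print(torch.allclose(direct, via_matmul))  # True
```

The practical difference is in autograd: the one-hot tensor can be created as a leaf with `requires_grad=True`, while the looked-up embeddings are a non-leaf tensor whose gradient is only kept after calling `retain_grad()`.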

My approach:

```python
import torch
from transformers import pipeline

text = "We are very happy to show you the 🤗 Transformers library."

classifier = pipeline('sentiment-analysis', model="distilbert-base-uncased-finetuned-sst-2-english")
model = classifier.model
tokenizer = classifier.tokenizer

encoding = tokenizer.encode_plus(
    text,
    return_tensors="pt",
    add_special_tokens=True,
    return_attention_mask=True,
)

# Full embedding block: word + position embeddings, LayerNorm, dropout.
inputs_embeds = model.base_model.embeddings(encoding['input_ids'])

# The embeddings are a non-leaf tensor, so ask autograd to keep their .grad.
assert inputs_embeds.is_leaf is False
inputs_embeds.retain_grad()

logits = model(inputs_embeds=inputs_embeds, attention_mask=encoding['attention_mask']).logits.squeeze(dim=0)
score = logits[logits.argmax()]
score.backward(gradient=None, retain_graph=True)

# Gradient x input; drop the [CLS] and [SEP] positions.
inputs_embeds__grad = (inputs_embeds.grad * inputs_embeds)[:, 1:-1, :]

feature_importance = torch.norm(inputs_embeds__grad, dim=2)
feature_importance_normalized = (feature_importance / torch.sum(feature_importance)).squeeze(dim=0)

tokens = encoding['input_ids'].squeeze(dim=0)[1:-1]
attributions = [
    {tokenizer.convert_ids_to_tokens(input_id.item()): feature_importance_normalized[index].item()}
    for index, input_id in enumerate(tokens)
]
```

  1. Note that `inputs_embeds` is computed directly from the full embeddings module.
  2. Also note that the scores computed from the logits with your approach don't match mine (or the pipeline's).
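The attribution recipe in the snippet above (retain the gradient on the embeddings, backpropagate the top logit, multiply gradient by input, take an L2 norm per token, normalize) can be checked in isolation. This is a minimal sketch on a hypothetical toy classifier: an embedding, mean pooling, and a linear head stand in for DistilBERT, so nothing needs downloading. It skips the [CLS]/[SEP] slicing because the toy input has no special tokens.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy classifier standing in for DistilBERT:
# embedding -> mean pooling over tokens -> linear head.
vocab_size, dim, num_labels = 20, 8, 2
embedding = torch.nn.Embedding(vocab_size, dim)
head = torch.nn.Linear(dim, num_labels)

input_ids = torch.tensor([[2, 9, 4, 11, 3]])  # (batch=1, seq_len=5)

inputs_embeds = embedding(input_ids)  # non-leaf: depends on embedding.weight
inputs_embeds.retain_grad()           # keep .grad on this non-leaf tensor

logits = head(inputs_embeds.mean(dim=1)).squeeze(dim=0)  # (num_labels,)
score = logits[logits.argmax()]
score.backward()

# Gradient x input, reduced to one score per token with an L2 norm,
# then normalized so the scores sum to 1.
grad_x_input = inputs_embeds.grad * inputs_embeds
feature_importance = torch.norm(grad_x_input, dim=2).squeeze(dim=0)  # (seq_len,)
importance = (feature_importance / feature_importance.sum()).detach()

print(importance)
```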

```python
# my approach
inputs_embeds = model.base_model.embeddings(encoding['input_ids'])

# your approach:
embedding_dim = model.base_model.embeddings.word_embeddings.embedding_dim
num_embeddings = model.base_model.embeddings.word_embeddings.num_embeddings

input_ids__one_hot = torch.nn.functional.one_hot(encoding["input_ids"], num_classes=num_embeddings).float()
input_ids__one_hot.requires_grad_(True)

assert input_ids__one_hot.requires_grad
assert input_ids__one_hot.is_leaf  # leaf node!

embedding_matrix = model.base_model.embeddings.word_embeddings.weight
inputs_embeds__yours = torch.matmul(input_ids__one_hot, embedding_matrix)

# The two results differ.
assert not torch.equal(inputs_embeds__yours, inputs_embeds)
```

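This is because `model.base_model.embeddings` is not just a lookup table: it is a module that chains several operations (word embeddings, position embeddings, LayerNorm, and dropout):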

```
# model.base_model.embeddings

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
```
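To see why the two tensors differ, here is a minimal sketch of a BERT-style embeddings block mirroring the printed module (word and position embeddings summed, then LayerNorm, then dropout). `ToyEmbeddings` and its sizes are hypothetical, not the transformers implementation. In eval mode the full block differs from the word embeddings alone, but rebuilding it from a one-hot matmul plus position embeddings and LayerNorm reproduces it:

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyEmbeddings(nn.Module):
    """Hypothetical stand-in for a BERT/DistilBERT-style embeddings block:
    word + position embeddings, then LayerNorm, then dropout."""

    def __init__(self, vocab_size=30, max_positions=16, dim=8, p_drop=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_positions, dim)
        self.LayerNorm = nn.LayerNorm(dim, eps=1e-12)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len).unsqueeze(0)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(x))


emb = ToyEmbeddings().eval()  # eval() turns dropout into the identity
input_ids = torch.tensor([[1, 4, 9]])

with torch.no_grad():
    full = emb(input_ids)                       # word + position -> LayerNorm
    word_only = emb.word_embeddings(input_ids)  # what the one-hot matmul gives

    # Rebuild the full block from a one-hot matmul.
    one_hot = nn.functional.one_hot(
        input_ids, num_classes=emb.word_embeddings.num_embeddings
    ).float()
    position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)
    manual = emb.LayerNorm(
        one_hot @ emb.word_embeddings.weight + emb.position_embeddings(position_ids)
    )

print(torch.allclose(full, word_only))  # False
print(torch.allclose(full, manual))     # True
```

This suggests one way to keep a one-hot leaf tensor while still feeding the model its full embeddings, assuming the embeddings module has this structure. Check the actual `forward` of the model in use, since details such as token-type embeddings (present in BERT, absent in DistilBERT) vary between architectures.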

If I instead apply only `word_embeddings`, I recover your result:

```python
inputs_embeds__partial = model.base_model.embeddings.word_embeddings(encoding['input_ids'])

assert torch.equal(inputs_embeds__partial, inputs_embeds__yours)
```

jalammar commented 2 years ago

I don't recall the exact reason. I vaguely remember an issue with retrieving the backpropagated gradients for float variables, but unfortunately the details escape me right now. Glad you got it working this way. I was hoping to eventually delegate attribution to the Captum library.
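Since the recollection above is vague, the following is an inference rather than a confirmed reason: a one-hot input is a leaf tensor, so its `.grad` is populated without `retain_grad()`. Both routes also carry the same per-token signal. For token i with embedding e_i = onehot_i W, the gradient with respect to the one-hot entry of the true token equals (∂s/∂e_i) · W[id_i] = (∂s/∂e_i) · e_i, which is the gradient x input summed over the embedding dimension. A minimal sketch checking this on hypothetical toy modules (note it compares a sum, not the L2 norm used in the reporter's snippet):

```python
import torch

torch.manual_seed(0)

vocab_size, dim = 12, 6
embedding = torch.nn.Embedding(vocab_size, dim)  # toy sizes (assumed)
head = torch.nn.Linear(dim, 1)                   # hypothetical scoring head
input_ids = torch.tensor([[3, 7, 1, 10]])

# Route A: one-hot leaf tensor; .grad is populated without retain_grad().
one_hot = torch.nn.functional.one_hot(input_ids, num_classes=vocab_size).float()
one_hot.requires_grad_(True)
score_a = head(one_hot @ embedding.weight).sum()
score_a.backward()
# Multiplying by one_hot keeps only the entry at each real token id.
attr_one_hot = (one_hot.grad * one_hot).sum(dim=-1).squeeze(0)  # (seq_len,)

# Route B: direct lookup; the looked-up embeddings are a non-leaf tensor.
embeds = embedding(input_ids)
embeds.retain_grad()
score_b = head(embeds).sum()
score_b.backward()
# Dot product of gradient and embedding, per token.
attr_embeds = (embeds.grad * embeds).sum(dim=-1).squeeze(0)  # (seq_len,)

print(torch.allclose(attr_one_hot.detach(), attr_embeds.detach()))  # True
```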