rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

How can I get each layer's LRP score? #8

Closed Patrick-Ni closed 2 weeks ago

Patrick-Ni commented 1 month ago

Hey @rachtibat: Thank you for your excellent work. I have one last question. I want to understand the contribution of the hidden states from each layer to the final result, not just the input embeddings. I tried to understand the code and the paper, and then made the following modifications:

# tokenize and embed the input so the embeddings can carry gradients/relevance
input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)
outputs = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False, output_hidden_states=True)
# print(input_embeds.requires_grad_())
# print(outputs.hidden_states[-3].requires_grad_())
logits = outputs.logits
max_logits, max_indices = torch.max(logits[0, -1, :], dim=-1)
# initialize the relevance with the maximum logit and run the backward pass
max_logits.backward(max_logits)
# fails: .grad is None because hidden_states are non-leaf tensors
relevance = outputs.hidden_states[-3].grad.float().sum(-1).cpu()[0]

But this does not work. How can I achieve this?

rachtibat commented 1 month ago

Hey,

you could look at this issue https://github.com/rachtibat/LRP-eXplains-Transformers/issues/2.

You can't simply read the gradients at the hidden_states, because PyTorch does not store gradients on non-leaf tensors. You must either modify the model source code and call hidden_states.retain_grad() at the position you want to record, or use backward hooks.

Keep in mind that LXT does nothing crazy! It behaves just like a normal PyTorch model, so all the usual PyTorch tools work as always.

I would recommend reading more about backward hooks and the retain_grad() method.
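
For illustration, here is a minimal sketch that combines both ideas: a forward hook calls retain_grad() on a chosen layer's output, so its .grad (which the LXT backward pass fills with relevance) can be read afterwards without editing the model source. The layer path model.model.layers[3] assumes a LLaMA-style model and is only an example; model, tokenizer, and text are the same objects as in the snippet above.

import torch

# store the hidden states tensor captured by the hook
captured = {}

def retain_hidden_grad(module, inputs, output):
    # decoder layers usually return a tuple whose first element is the hidden states
    hidden = output[0] if isinstance(output, tuple) else output
    hidden.retain_grad()  # keep .grad on this non-leaf tensor
    captured["hidden"] = hidden

# layer index 3 is an arbitrary example; adjust the attribute path to your architecture
handle = model.model.layers[3].register_forward_hook(retain_hidden_grad)

input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)
outputs = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False)

max_logits, max_indices = torch.max(outputs.logits[0, -1, :], dim=-1)
max_logits.backward(max_logits)

# one relevance score per token at the hooked layer
relevance_per_token = captured["hidden"].grad.float().sum(-1).cpu()[0]
handle.remove()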

Hope it helps

rachtibat commented 2 weeks ago

Added an example for obtaining relevance at the residual stream: https://lxt.readthedocs.io/en/latest/feature-visualization.html#latent-feature-attribution