rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

[Desiderata] Captum-like implementation for Inseq compatibility #1

Open gsarti opened 7 months ago

gsarti commented 7 months ago

Hi @rachtibat,

Great job on AttnLRP, your LRP adaptation seems very promising for attributing the behavior of Transformer-based LMs!

Provided you guys are still working on the codebase, I was wondering whether it would be possible to have an implementation that is interoperable with the Captum LRP class. This would allow us to include the method in the inseq library (reference: inseq-team/inseq#122), enabling out-of-the-box support for:

inseq has already been used in a number of publications since its release last year, and having an accessible implementation of AttnLRP there would undoubtedly help democratize access to your method. From an implementation perspective, I'm not an LRP connoisseur, but my understanding is that ensuring Captum compatibility would only require specifying your custom propagation rules against the base class provided here.

Let me know your thoughts, and congrats again on the excellent work!

rachtibat commented 7 months ago

Hey @gsarti,

I'm glad that you like our paper! You do great work at inseq!

While LRP substantially outperforms other methods, it has an initial 'set-up' cost: there is currently no implementation in PyTorch that is able to automatically apply the rules to all operations in a PyTorch graph. For instance, in LRP we must apply the epsilon rule to every summation operation. This means that if we have a line of code such as `c = a + b`, we have to attach our LRP rule to this line of code somehow. In this repository, I am implementing custom PyTorch autograd functions, which means we have to replace the line of code with `c = epsilon_sum.apply(a, b)`. So, the user has to put in some extra effort.
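
To make this concrete, here is a minimal, hypothetical sketch of such a custom autograd function for the epsilon rule on a summation (the name `EpsilonSum` and the exact stabilization term are illustrative, not the repository's actual implementation):

```python
import torch

class EpsilonSum(torch.autograd.Function):
    """Illustrative sketch: propagate relevance through `c = a + b`
    with the LRP epsilon rule instead of the ordinary gradient."""

    @staticmethod
    def forward(ctx, a, b, epsilon=1e-6):
        c = a + b
        ctx.save_for_backward(a, b, c)
        ctx.epsilon = epsilon
        return c

    @staticmethod
    def backward(ctx, relevance_c):
        a, b, c = ctx.saved_tensors
        # Epsilon rule: distribute the incoming relevance proportionally
        # to each summand's contribution, with a sign-stabilized denominator.
        sign = torch.where(c >= 0, torch.ones_like(c), -torch.ones_like(c))
        denom = c + ctx.epsilon * sign
        return relevance_c * a / denom, relevance_c * b / denom, None

# Usage: replace `c = a + b` with
# c = EpsilonSum.apply(a, b)
```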

I'm not aware of a way to do this kind of code manipulation/graph manipulation on the fly. I just found this tutorial on torch.fx. Maybe this is the solution?
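
If torch.fx turns out to be viable, the rewrite could look roughly like this (a hedged sketch that assumes the `EpsilonSum` function above; symbolic tracing of large LLMs may need extra workarounds):

```python
import operator
import torch
import torch.fx as fx

def replace_additions(model: torch.nn.Module) -> fx.GraphModule:
    # Trace the model into an fx graph and swap every `a + b` node
    # for the epsilon-rule autograd function sketched above.
    traced = fx.symbolic_trace(model)
    for node in traced.graph.nodes:
        if node.op == "call_function" and node.target is operator.add:
            node.target = EpsilonSum.apply
    traced.graph.lint()
    traced.recompile()
    return traced
```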

As a consequence, I'm currently implementing LRP in the style of zennit, but I'm optimistic that we can somehow integrate it into Captum for pre-defined model architectures such as Llama 2, for instance just to run some benchmarks against other methods. (The LRP implementation in Captum is not optimal for our use case because it relies on hooks, which are quite inefficient, but maybe we can agree on a new adaptation of their LRP class?)

Best greetings, and thank you again (:

gsarti commented 7 months ago

Thanks for your prompt reply @rachtibat!

I see the issue with setup costs, thanks for clarifying! I had an in-depth look at torch.fx some time ago for inference-time mid-forward interventions (e.g. for the Value Zeroing method we're adding in PR inseq-team/inseq#173), and I also had a chance to chat about it with the Captum lead devs at EMNLP. In general, it is very cumbersome and counter-intuitive for performing very targeted interventions, but for replacing all operations of a specific type it might be manageable. I'd be very interested to see if you come up with a solution using torch.fx to make the implementation generalizable!

Would the zennit-style implementation you have in mind support multi-token generation? In my experience, this is the main limiting factor in applying such techniques to autoregressive LMs (which we address by looping attribution in inseq), especially since people usually want to customize generation parameters à-la-transformers without reinventing the wheel.

rachtibat commented 7 months ago

Hey @gsarti,

Awesome that you already have so much experience with torch.fx! Good to know; maybe it really is manageable with simple operation replacement, without any fancy graph manipulation.

I'm not quite sure what you mean by multi-token generation, but I'll try to give you an idea of what is possible if someone wants to explain several tokens at once. Assume an LLM generated a sequence of N tokens.

I hope this explains it; if it is unclear, you can ask again (:

gsarti commented 7 months ago

Thanks for the response @rachtibat! To clarify, the background to my question was that libraries like Captum typically provide an interface to streamline the attribution of a single forward output (the first bullet point you describe). However, there is no simple abstraction to automate the "one attribution per generation step" process you describe in the third bullet point (although in the case of Captum, they actually added something akin to this in v0.7). The main reason for inseq's existence was precisely to automate this process while enabling full customization of the model.generate method of 🤗 transformers.
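
For context, the generation-step loop that needs automating looks roughly like this (a hypothetical sketch; `attribute_step` stands in for whatever single-step attribution method is plugged in, and greedy decoding is used only for brevity):

```python
import torch

def attribute_generation(model, tokenizer, prompt, attribute_step, max_new_tokens=10):
    # One attribution per generation step: generate a token, attribute it,
    # append it to the input, and repeat.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    attributions = []
    for _ in range(max_new_tokens):
        with torch.no_grad():
            next_logits = model(input_ids).logits[:, -1, :]
        next_id = next_logits.argmax(dim=-1, keepdim=True)  # greedy step for brevity
        attributions.append(attribute_step(model, input_ids, next_id))
        input_ids = torch.cat([input_ids, next_id], dim=-1)
    return attributions
```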

The second approach you mention (the one proposing a "superposition" of 3 attributions) looks very interesting, and I think it's the first time I've seen this idea! But I have a doubt: this would mean, effectively, taking the output logits of previous tokens (e.g. 2 and 5 in your example) when computing the forward pass for token N-4 and using them to propagate relevance back into the model. Don't you think this is a bit unnatural for extracting rationales, given that only the last token is used when computing predictions at each generation step? I'm not sure what information the preceding embeddings would provide in this context. Curious to hear your thoughts!

rachtibat commented 7 months ago

Hey,

afaik, transformers are trained with a next-token prediction at every output position. If you look at the Hugging Face implementation of Llama 2, for instance, you see that the labels for the cross-entropy loss are the inputs shifted by one. So the model actually predicts at each output position and not just at the last token. Because of the causal masking in the attention heads, each output position N can only see the prior N-1 input tokens and makes an independent prediction. This is why we can actually do what I've described in bullet point 2 (:
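
For reference, a minimal sketch of that training objective (illustrative, not the actual Hugging Face code): the logits at position t are scored against the input token at position t+1, so every position makes its own next-token prediction.

```python
import torch.nn.functional as F

def causal_lm_loss(logits, input_ids):
    # Targets are the inputs shifted by one: the logits at position t
    # are compared against the token that appears at position t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
```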

I already tried it: computing a superimposed heatmap in one pass is equal to computing the attribution for each output token separately and adding them up. This is due to the fact that LRP is an additive explanatory model, i.e. the attribution can be disentangled into several independent relevance flows. We described this phenomenon in this paper: https://www.nature.com/articles/s42256-023-00711-8
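
To illustrate the superposition idea, here is a hedged sketch (assuming a gradient-times-input style relevance readout and a model that accepts `inputs_embeds`; the function name and interface are illustrative): seeding the backward pass with the logits of several output positions at once should yield the same heatmap as attributing each position separately and summing.

```python
import torch

def superimposed_attribution(model, inputs_embeds, positions, target_ids):
    # Seed the backward pass at several output positions simultaneously.
    inputs_embeds = inputs_embeds.clone().detach().requires_grad_(True)
    logits = model(inputs_embeds=inputs_embeds).logits        # (1, seq_len, vocab)
    selected = logits[0, positions, target_ids].sum()         # superposition of targets
    selected.backward()
    # Gradient-times-input heatmap, summed over the embedding dimension.
    return (inputs_embeds * inputs_embeds.grad).sum(-1).detach()

# For an additive explainer, this should equal the sum of
# superimposed_attribution(model, e, [p], [t]) over each (p, t) pair.
```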

So, I think this might be a feature only present in additive explanatory models. I hope it is somewhat clear (:

gsarti commented 7 months ago

This is very interesting, you're right! I was thinking of inference, but it is true that at training time the model does indeed predict a token at every position. The fact that it results in a simple sum of independent relevance flows is definitely an upside of additive models, looking forward to testing it out! :)