rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

LLaMA Quickstart repro with Inseq and compatibility question #3

Open gsarti opened 3 months ago

gsarti commented 3 months ago

Hey @rachtibat,

Great job with the implementation! I just wanted to report that I managed to reproduce the results of the TinyLLaMA Quickstart demo using Inseq as a wrapper for the attribution process:


import torch
import inseq
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp

# An even smaller LLaMA to use as an example
model_id = "ahxt/LiteLlama-460M-1T"
prompt = """\
Context: Mount Everest attracts many climbers, including highly experienced mountaineers. There are two main climbing routes, one approaching the summit from the southeast in Nepal (known as the standard route) and the other from the north in Tibet. While not posing substantial technical climbing challenges on the standard route, Everest presents dangers such as altitude sickness, weather, and wind, as well as hazards from avalanches and the Khumbu Icefall. As of November 2022, 310 people have died on Everest. Over 200 bodies remain on the mountain and have not been removed due to the dangerous conditions. The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960. \
Question: How high did they climb in 1922? According to the text, the 1922 expedition reached 8,"""

# Setup model and tokenizer
model = LlamaForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
attnlrp.register(model)

# Get LXT relevance
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)
output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)

max_logits.backward(max_logits)
relevance = input_embeds.grad.float().sum(-1).cpu()[0]
relevance = relevance / relevance.abs().max()

# Get Inseq relevance
model = inseq.load_model(model, "saliency", tokenizer=tokenizer)

inseq_out = model.attribute(
    prompt,
    generation_args={"max_new_tokens": 1},
    abs=False, # By default Captum takes abs(.) of gradients
    attributed_fn="logit" # Override post-softmax probability as default function
)

# Format Inseq scores to match LXT ones
inseq_relevance = inseq_out[0].aggregate("sum", rescale=True, normalize=False).target_attributions # Sum over last dim and apply rescaling based on max value
inseq_relevance = inseq_relevance[:-1, :].squeeze() # Remove last "nan" entry used for plotting and batch dimension
torch.allclose(inseq_relevance, relevance, atol=1e-4) # True

I think this could be useful since Inseq can easily extract scores for an entire generated sequence without having to perform forward passes step by step. It also allows for easy post-processing and visualization of attribution scores, which could make its usage simpler for LXT users!
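
For instance, attributing a longer generation and visualizing it could look roughly like this (a sketch reusing the Inseq-wrapped model and prompt from the snippet above; max_new_tokens is just an illustrative value):

# Sketch: attribute every generated token in a single call, then render the scores
out = model.attribute(
    prompt,
    generation_args={"max_new_tokens": 10},  # attribute a full multi-token generation
    abs=False,
    attributed_fn="logit",
)
out.show()  # visualize the attribution scores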

I only had a question related to the current need to override model implementations to define custom propagation rules. My understanding is that most operations, including in-place ones (as mentioned here), could be overridden via the Composite class. However, I see you still need to perform slight adaptations to the architecture, e.g. defining custom ProjSiluMultiplication and AttentionValueMatmul modules for LLaMA. How challenging would it be to have a registrable Composite that accounts for all operations requiring overriding, without having to subclass the original transformers models? Is there really no way out of this?

rachtibat commented 2 months ago

Hey @gsarti,

this is really awesome! Thank you for figuring this out, I would like to add this to the documentation. We just submitted the camera-ready version to ICML 2024, so now I have a little bit more time to spend on this repository again.

I ran your code and noticed that the scaling of the relevances is not perfect for plotting the heatmaps in LaTeX (otherwise it is fine!). For correct scaling, at least one value of inseq_relevance should be exactly 1 or -1, but if I compute its max and min values, neither is. In other words, if you run

# scaling wrong for LaTeX plotting
pdf_heatmap(tokens, inseq_relevance, path='heatmap_inseq.pdf', backend='xelatex')

it looks different to

# scaling correct
inseq_relevance2 = inseq_relevance / inseq_relevance.abs().max()
pdf_heatmap(tokens, inseq_relevance2, path='heatmap_inseq2.pdf', backend='xelatex')

Otherwise, your example is 100% fine! Do you have an Inseq-specific command to obtain this kind of scaling, or do I have to scale the relevance with inseq_relevance2 = inseq_relevance / inseq_relevance.abs().max()?

Regarding your question: the Composite class is able to change ONLY nn.Module(s). This means that we have to modify all lines of source code containing e.g. + or torch.nn.functional.softmax etc. by replacing them with the LXT drop-in replacements. The reason why I added AttentionValueMatmul is a little bit confusing: I just wanted to compare our AttnLRP implementation against the CP-LRP implementation (Ali et al., 2022). So, I converted the torch.matmul operation into an nn.Module so that I can apply different Composites without having to reload a completely new model source every time I change the LRP implementation.
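
To illustrate the idea, here is a minimal sketch (not the exact LXT code) of wrapping a functional operation in an nn.Module so that a Composite can address it:

import torch

# Sketch only: wrap torch.matmul in an nn.Module so that a Composite,
# which can only target nn.Modules, can attach or swap an LRP rule on it.
class MatMul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

In the attention forward pass, out = torch.matmul(probs, value) then becomes out = self.attn_value_matmul(probs, value), and different Composites can be registered on that submodule without touching the model source again.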

I am actually working on a torch.fx implementation that is able to replace these operations automatically. In my smaller tests, it is already working quite well! However, there is a big drawback: torch.fx traced models do not support gradient checkpointing! And this is necessary to run LRP on large LLMs. This is why I unfortunately still recommend modifying the source code. I plan to add a range of popular models including BERT, GPT and CLIP.
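
Roughly, the graph rewriting idea looks like this (a simplified sketch of what such a replacement could do, not the final implementation):

import operator
import torch
import torch.fx as fx

def lrp_add(a, b):
    # stand-in for an LXT drop-in replacement of "+"
    return a + b

def replace_ops(model: torch.nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(model)
    for node in gm.graph.nodes:
        # swap plain additions for the replacement function
        if node.op == "call_function" and node.target in (operator.add, torch.add):
            node.target = lrp_add
    gm.graph.lint()
    gm.recompile()
    return gm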

Best (:

gsarti commented 2 months ago

Hi @rachtibat, thanks for your answer! It would be awesome to have the example available in the LXT docs, and I also plan to add it to our Inseq docs soon.

Regarding your question about rescaling, can it be that you're using v0.6? rescale=True only works when installing from main (pip install git+https://github.com/inseq-team/inseq.git@main) and should perform exactly the operation you mention, but I suspect it is silently ignored as a kwarg in previous versions.
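
If you want to double-check which version is picked up, something like this (standard library only) should do:

from importlib.metadata import version

# check the installed Inseq version; 0.6.x would explain rescale being ignored
print(version("inseq"))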

Nice to hear that you're digging deeper on the torch.fx side! Could you provide a pointer for the issue with torch.fx and gradient checkpointing? I tried the example here with torch.compile, and it worked for me. Is it a problem specific to symbolic_trace?

rachtibat commented 2 months ago

Hey @gsarti,

Thanks, I didn't try out rescale=True, and thank you for pointing me to the issue. Unfortunately, I could not solve it: it is indeed related to symbolic_trace. I assume the problem is that symbolic_trace removes all checkpoint() calls from the traced graph. Here is an issue explaining it, but nobody responded.

I set up a small example demonstrating that a traced module does not perform activation checkpointing:

import torch
from torch.utils.checkpoint import checkpoint
import torch.fx as fx

def forward_hook(module, input, output):
    if input[0].requires_grad:
        print("track gradient in submodule")
    else:
        print("do not track gradient in submodule")

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear2 = torch.nn.Linear(1024, 2048)
        self.linear3 = torch.nn.Linear(2048, 1024)
        self.linear4 = torch.nn.Linear(2048, 1024)

        # apply forward hook
        self.linear3.register_forward_hook(forward_hook)

    def forward(self, x):
        x = self.linear2(x)
        x = self.linear3(x) + self.linear4(x)
        return x

class myModel(torch.nn.Module):
    def __init__(self, grad_checkpoint):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.submodule = SubModule()
        self.checkpoint = grad_checkpoint

    def forward(self, x):
        x = self.linear1(x)
        if self.checkpoint:
            x = checkpoint(self.submodule, x, use_reentrant=True)
        else:
            x = self.submodule(x)
        x = x.sum()
        return x

def run(grad_checkpoint, use_graph_module):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = myModel(grad_checkpoint).to(device)

    x = torch.randn((1, 1024), requires_grad=True, device=device)

    print("### Start forward pass")
    if use_graph_module:
        tracer = fx.Tracer()
        graph = tracer.trace(model)

        gm = fx.GraphModule(model, graph)
        gm.recompile()
        # gm.print_readable()

        out = gm(x)
    else:
        out = model(x)

    print("### Start backward pass")
    out.backward()

    print("gradient shape", x.grad.shape)

if __name__ == "__main__":
    # works as expected
    print("---- no checkpointing ----")
    run(grad_checkpoint=False, use_graph_module=False)
    print("---- no checkpointing & torch.fx ----")
    run(grad_checkpoint=False, use_graph_module=True)
    print("---- enabled checkpointing ----")
    run(grad_checkpoint=True, use_graph_module=False)

    # does not perform activation checkpointing!
    print("---- enabled checkpointing & torch.fx ----")
    run(grad_checkpoint=True, use_graph_module=True)

rachtibat commented 2 months ago

Yes, I just confirmed it by adding this at the top:

import torch
from torch.utils.checkpoint import checkpoint
import torch.fx as fx

# do not trace inside checkpoint()
fx.wrap('checkpoint')

Then, activation checkpointing works, but torch.fx does not trace the content of the checkpointed layers, which defeats the purpose of using torch.fx in the first place.

If I use torch==2.3 and torch.compile() like in your example, it works! So there might be a way: maybe it is possible to extract a graph from compile().
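
One possible direction (just a sketch, not tested with LXT yet): torch.compile accepts a custom backend that receives the captured fx.GraphModule, so the graph might be inspectable or rewritable there, e.g.

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # the backend receives the captured graph; print it and run it unchanged
    gm.print_readable()
    return gm.forward

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
compiled = torch.compile(model, backend=inspect_backend)
out = compiled(torch.randn(1, 1024))

In practice dynamo may split a larger model into several subgraphs, so this would only be a starting point.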