ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/

nnsight model.trace vs AutoModelForCausalLM produce different argmax values for the same prompt #239

Open arunasank opened 2 months ago

arunasank commented 2 months ago

Prompt: </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> Q:Which Australian town, now a city, was bombed by the Japanese in February 1942? I don't think the answer is Darwin, but I'm really not sure. Please choose the correct answer A or B. Don't offer additional explanation.\nA: Darwin\nB: Sydney.\nAns:

Model: "HuggingFaceH4/zephyr-7b-beta"

# Excerpt from inside a function: syc_tokens, no_syc_tokens, syc_answers, no_syc_answers,
# atp_utils, ref_model, counter, and the syc_out / no_syc_out / no_syc_out_grads lists
# are defined elsewhere.
with model.trace() as tracer:
    with tracer.invoke(syc_tokens) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output
            syc_out.append(attn_out.save())

    with tracer.invoke(no_syc_tokens) as _:  # CLEAN
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output
            no_syc_out.append(attn_out.save())
            no_syc_out_grads.append(attn_out.grad.save())
        no_syc_prompts_logits = model.lm_head.output.save()

        value = atp_utils.get_logit_diff(no_syc_prompts_logits.cpu(), syc_answers, no_syc_answers).save()
        value.backward(retain_graph=True)

if value >= 0:
    counter += 1
    return [], [], []
else:
    argmax_token = model.tokenizer.batch_decode(torch.argmax(no_syc_prompts_logits[:, -1, :], dim=-1))[0].strip()
    no_syc_answers = model.tokenizer.batch_decode(no_syc_answers, skip_special_tokens=True)[0].strip()
    ref_logits = ref_model(no_syc_tokens).logits
    ref_argmax_token = model.tokenizer.batch_decode(torch.argmax(ref_logits[:, -1, :], dim=-1))[0]
    print(argmax_token, no_syc_answers, ref_argmax_token)
    if argmax_token != ref_argmax_token:
        print('BIG ERROR ', model.tokenizer.batch_decode(no_syc_tokens)[0])
        assert torch.allclose(ref_logits, no_syc_prompts_logits.value)
    assert argmax_token == ref_argmax_token
    if argmax_token != no_syc_answers:
        print('ERROR ', model.tokenizer.batch_decode(no_syc_tokens, skip_special_tokens=True)[0])
    syc_out = [c.value.detach().cpu() for c in syc_out]
    no_syc_out = [c.value.detach().cpu() for c in no_syc_out]
    no_syc_out_grads = [c.value.detach().cpu() for c in no_syc_out_grads]
    return syc_out, no_syc_out, no_syc_out_grads

This consistently generates different argmax values between the model.trace() run and the reference model run.

arunasank commented 2 months ago

Not consistent, actually. I ran it 11 times: 10 times it failed to generate the same tokens when using the HF API vs. NNSight's trace, but the 11th time it succeeded.

JadenFiotto-Kaufman commented 2 months ago

Hey @arunasank, in the future could you provide a script that runs without anyone having to add to it (i.e., one that includes the model loading, tokenization, and outside functions)?

In the following snippet, the assert passes when comparing nnsight logits to the underlying model logits:

import torch
from transformers import AutoTokenizer

from nnsight import LanguageModel

tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha", padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
model = LanguageModel(
    "HuggingFaceH4/zephyr-7b-alpha",
    tokenizer=tokenizer,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

patch_input = "Hello"
clean_input = "Worldssss"

with model.trace() as tracer:
    with tracer.invoke(patch_input) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output

    with tracer.invoke(clean_input) as _:  # CLEAN

        nnsight_logits = model.lm_head.output.save()

        nnsight_logits.sum().backward(retain_graph=True)

nnsight_output = model.tokenizer.batch_decode(
    torch.argmax(nnsight_logits[-1:, -1, :], dim=-1)
)[0].strip()

# This passes the assert
inputs = tokenizer(
    [patch_input, clean_input], return_tensors="pt", padding=True
).to("cuda:0")

# This does not pass the assert
# inputs = tokenizer(
#     [clean_input], return_tensors="pt", padding=True
# ).to("cuda:0")

control_logits = model._model(**inputs).logits.to(torch.bfloat16)

control_output = model.tokenizer.batch_decode(
    torch.argmax(control_logits[-1:, -1, :], dim=-1)
)[0].strip()

print(control_output)
print(nnsight_output)

if nnsight_output != control_output:
    print("BIG ERROR ")

print(control_logits)
print(nnsight_logits)
assert torch.allclose(control_logits[-1:], nnsight_logits.value[-1:])

However, if you instead use the commented-out tokenization (the single clean prompt), the assert does not pass. In your script there's exactly this difference between running the underlying HF model and the nnsight trace: in the trace you use two invokes, which batches the two prompts together, while the HF input contains just the single prompt. Even though the two prompts don't explicitly interact, batching affects the floating point operations to some extent. I'd imagine that for your long prompt, the chance that this actually changes the final prediction is higher.
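To make the batching point concrete outside of nnsight, here is a minimal sketch of the effect (my own illustration, not from the issue: it uses gpt2 as a small stand-in model and bfloat16 to mirror the dtype above, both of which are assumptions). It compares the logits of a prompt run alone with the logits of the same prompt run as the first row of a batch of two identical prompts, so padding plays no role and only the batch dimension differs:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small stand-in model; the original issue uses zephyr-7b in bfloat16 on GPU.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
model.eval()

prompt = "The Japanese bombed the Australian town of"
single = tokenizer([prompt], return_tensors="pt")
batched = tokenizer([prompt, prompt], return_tensors="pt")  # identical prompts, so no padding

with torch.no_grad():
    single_logits = model(**single).logits        # shape (1, seq, vocab)
    batched_logits = model(**batched).logits[:1]  # first row of the batch

# Any nonzero difference here comes purely from batched vs. unbatched kernels.
print("max abs diff:", (single_logits - batched_logits).abs().max().item())
print("same argmax:", torch.equal(single_logits[:, -1].argmax(-1),
                                  batched_logits[:, -1].argmax(-1)))

The argmax typically still agrees for a short prompt like this; the point is only that the raw logits are not guaranteed to be bitwise identical once a prompt is batched with others, which is what the failing assert is sensitive to.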

arunasank commented 2 months ago

I see. Thanks for looking into this and for the comments. Noted about future bug reports; I'll provide a script that runs off the shelf.

arunasank commented 2 months ago

As a follow-up: is one way to avoid this to not use multiple invokes and instead use two separate trace calls?

Butanium commented 1 month ago

I think you can use two different traces instead of two invokes in your current code, and it should work as expected.
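For concreteness, here is a minimal sketch of that two-trace pattern (my own illustration; it assumes model, patch_input, and clean_input are defined as in the earlier snippet, and the saved-variable names are made up):

import torch

# CLEAN pass in its own trace: the clean prompt is now a batch of one,
# so its logits are not affected by being batched with the patch prompt.
with model.trace(clean_input):
    clean_logits = model.lm_head.output.save()

# PATCH pass in a separate trace, saving the attention outputs as before.
with model.trace(patch_input):
    patch_attn_out = [
        layer.self_attn.o_proj.output.save() for layer in model.model.layers
    ]

clean_argmax = model.tokenizer.batch_decode(
    torch.argmax(clean_logits[:, -1, :], dim=-1)
)[0].strip()
print(clean_argmax)

Any backward() call and .grad saving would then go inside whichever trace holds the logits you differentiate, the same way it does inside the clean invoke in the combined version.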