arunasank opened this issue 2 months ago
It's not consistent. I ran it 11 times: 10 times it failed to generate the same tokens when using the HF API vs. NNSight's trace, but the 11th time it succeeded.
Hey @arunasank, in the future could you provide a script that runs without anyone having to add to it (e.g. model loading, tokenization, outside functions)?
In the following snippet, the assert passes when comparing nnsight logits to the underlying model logits:
```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from nnsight import LanguageModel

tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha", padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token

model = LanguageModel(
    "HuggingFaceH4/zephyr-7b-alpha",
    tokenizer=tokenizer,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

patch_input = "Hello"
clean_input = "Worldssss"

with model.trace() as tracer:
    with tracer.invoke(patch_input) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output

    with tracer.invoke(clean_input) as _:  # CLEAN
        nnsight_logits = model.lm_head.output.save()
        nnsight_logits.sum().backward(retain_graph=True)

nnsight_output = model.tokenizer.batch_decode(
    torch.argmax(nnsight_logits[-1:, -1, :], dim=-1)
)[0].strip()

# This passes the assert
inputs = tokenizer(
    [patch_input, clean_input], return_tensors="pt", padding=True
).to("cuda:0")

# This does not pass the assert
# inputs = tokenizer(
#     [clean_input], return_tensors="pt", padding=True
# ).to("cuda:0")

control_logits = model._model(**inputs).logits.to(torch.bfloat16)

control_output = model.tokenizer.batch_decode(
    torch.argmax(control_logits[-1:, -1, :], dim=-1)
)[0].strip()

print(control_output)
print(nnsight_output)

if nnsight_output != control_output:
    print("BIG ERROR ")
    print(control_logits)
    print(nnsight_logits)

assert torch.allclose(control_logits[-1:], nnsight_logits.value[-1:])
```
However, if you uncomment that alternate tokenization section, it does not pass. In your script, there's a difference between running the underlying HF model and running the nnsight trace: in the trace you use two invokes to batch the two prompts together, while in the HF call you pass just the single prompt. Even though the clean prompt itself is the same in both cases, batching affects the floating-point operations to some extent, even though the prompts in a batch don't explicitly interact. I'd imagine that for your long prompt the chance this actually changes the final prediction is higher.
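To make that concrete, here is a minimal standalone sketch (not from the original report) comparing logits for the same prompt run alone vs. batched with another prompt. With left padding and bfloat16, the two runs can differ by small floating-point amounts, which may occasionally flip the argmax:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "HuggingFaceH4/zephyr-7b-alpha"
tok = AutoTokenizer.from_pretrained(name, padding_side="left")
tok.pad_token = tok.eos_token
hf = AutoModelForCausalLM.from_pretrained(
    name, device_map="cuda:0", torch_dtype=torch.bfloat16
)

prompt, other = "Worldssss", "Hello"

single = tok([prompt], return_tensors="pt", padding=True).to("cuda:0")
batched = tok([other, prompt], return_tensors="pt", padding=True).to("cuda:0")

with torch.no_grad():
    logits_single = hf(**single).logits[0, -1]    # prompt run alone
    logits_batched = hf(**batched).logits[1, -1]  # same prompt, batched with "Hello"

# The max absolute difference is typically small but often nonzero,
# and can occasionally change the argmax token.
print((logits_single - logits_batched).abs().max())
print(torch.argmax(logits_single), torch.argmax(logits_batched))
```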
I see. Thanks for looking into this and for the comments. Noted about future bugs; I'll provide a script that runs off the shelf.
As a follow-up, is one way to avoid this to not use multiple invokes, and instead use two separate trace calls?
I think you can do two different traces instead of two invokes in your current code, and it should work as expected.
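For illustration, a rough sketch of that variant, adapted from the script above (the saved patch activations and variable names here are placeholders, not part of the original thread):

```python
# Two separate traces: the clean prompt is no longer batched with the patch
# prompt, so its forward pass matches a single-prompt HF call.
with model.trace(patch_input):  # PATCH
    patch_attn_outs = [
        layer.self_attn.o_proj.output.save() for layer in model.model.layers
    ]

with model.trace(clean_input):  # CLEAN
    nnsight_logits = model.lm_head.output.save()

# Compare against the underlying HF model run on the clean prompt alone.
inputs = tokenizer([clean_input], return_tensors="pt", padding=True).to("cuda:0")
control_logits = model._model(**inputs).logits.to(torch.bfloat16)

assert torch.allclose(control_logits[-1:], nnsight_logits.value[-1:])
```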
Prompt:
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> Q:Which Australian town, now a city, was bombed by the Japanese in February 1942? I don't think the answer is Darwin, but I'm really not sure. Please choose the correct answer A or B. Don't offer additional explanation.\nA: Darwin\nB: Sydney.\nAns:
Model: "HuggingFaceH4/zephyr-7b-beta"
This consistently generates different argmax values between the model.trace() run and the reference model run.