This kind of silent failure can make nnsight very hard to debug:
# Reproduction of three nnsight activation-patching scenarios. Two of them
# fail silently: the patch is ignored and the "corrupted" logits come out
# identical to the clean logits, so the final asserts fail.
import torch as th
from nnsight import LanguageModel
nn_model = LanguageModel("gpt2", device_map="cpu")
# Case 1 -- silent failure: `hidden` is captured WITHOUT .save(), so its
# value is not available once the first trace exits; assigning it inside the
# second trace patches nothing. (NOTE(review): presumably nnsight discards
# unsaved proxy values when a trace closes -- confirm against nnsight docs.)
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits = nn_model.lm_head.output.save()
# Case 2 -- works: .save() keeps the layer-0 output alive past the first
# trace, so the assignment in the second trace actually patches the model.
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output.save()
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits2 = nn_model.lm_head.output.save()
# Case 3 -- silent failure: h[10].output has not been computed yet at the
# point where h[0] runs, so the cross-layer assignment is silently ignored.
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = nn_model.transformer.h[10].output
    corrupted_logits3 = nn_model.lm_head.output.save()
# Clean baseline run with no patching, for comparison.
with nn_model.trace("b"):
    clean_logits = nn_model.lm_head.output.save()
# Only case 2 actually changed the logits; cases 1 and 3 left them equal to
# the clean baseline, which is why the last two asserts fail.
assert not th.allclose(clean_logits, corrupted_logits2), "this assert pass"
assert not th.allclose(clean_logits, corrupted_logits3), "this assert fails"
assert not th.allclose(clean_logits, corrupted_logits), "this assert fails"
This kind of silent failure can make nnsight very hard to debug.