ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

NNsight should fail when an Unset proxy from a previous trace / future computation is used in a patching experiment #143

Open Butanium opened 3 months ago

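When a module output is patched with a proxy whose value is unset, nnsight does not raise an error. The patch simply has no effect. This happens when the proxy was not saved in a previous trace, or when it refers to a value that has not been computed yet in the current forward pass. This kind of silent failure makes nnsight very hard to debug: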

```python
import torch as th
from nnsight import LanguageModel

nn_model = LanguageModel("gpt2", device_map="cpu")

# Fails silently: `hidden` is never .save()d, so it is unset once the first trace exits
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits = nn_model.lm_head.output.save()

# Works: `hidden` is saved, so its value is available in the second trace
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output.save()
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits2 = nn_model.lm_head.output.save()

# Fails silently: h[10].output has not been computed yet when h[0] runs
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = nn_model.transformer.h[10].output
    corrupted_logits3 = nn_model.lm_head.output.save()

# Reference run without any patching
with nn_model.trace("b"):
    clean_logits = nn_model.lm_head.output.save()
assert not th.allclose(clean_logits, corrupted_logits2), "this assert passes"
assert not th.allclose(clean_logits, corrupted_logits3), "this assert fails"
assert not th.allclose(clean_logits, corrupted_logits), "this assert fails"
```
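To make the requested behavior concrete, here is a minimal toy sketch in plain Python (no nnsight or torch required). It is loosely modeled on the trace / `.save()` pattern above. The `Trace`, `Proxy`, and `UnsetValueError` names are hypothetical, and the code does not reflect nnsight's real API or internals. It only illustrates the check being asked for: a patch that reads an unset value raises an error instead of silently leaving the output unchanged. A value counts as unset either because it was not saved from a previous trace or because its layer has not run yet.

```python
from typing import Callable, Dict, List, Sequence

# Sentinel marking a value that was never computed or was discarded.
_UNSET = object()


class UnsetValueError(RuntimeError):
    """Raised when a patch reads a value that is not available (hypothetical)."""


class Proxy:
    """Placeholder for a layer output; filled in while a toy trace runs."""

    def __init__(self, layer: int) -> None:
        self.layer = layer
        self.saved = False
        self.value = _UNSET

    def save(self) -> "Proxy":
        # Keep the value after the trace exits, mimicking the .save() calls above.
        self.saved = True
        return self

    def get(self) -> float:
        # Fail loudly instead of silently skipping the patch.
        if self.value is _UNSET:
            raise UnsetValueError(
                f"output of layer {self.layer} is unset: it was not saved "
                "from a previous trace, or has not been computed yet"
            )
        return self.value


class Trace:
    """Toy forward pass over scalar 'layers' that supports output patching."""

    def __init__(self, layers: Sequence[Callable[[float], float]], x: float) -> None:
        self.layers = layers
        self.x = x
        self.reads: List[Proxy] = []
        self.patches: Dict[int, Proxy] = {}

    def output(self, layer: int) -> Proxy:
        # Request the output of `layer`; it is filled in when that layer runs.
        proxy = Proxy(layer)
        self.reads.append(proxy)
        return proxy

    def patch(self, layer: int, source: Proxy) -> None:
        # Replace the output of `layer` with the value held by `source`.
        self.patches[layer] = source

    def run(self) -> float:
        h = self.x
        for i, layer_fn in enumerate(self.layers):
            h = layer_fn(h)
            # Record the original output for any proxy that asked for this layer.
            for proxy in self.reads:
                if proxy.layer == i:
                    proxy.value = h
            if i in self.patches:
                h = self.patches[i].get()  # raises if the source is unset
        # Discard values that were not saved, so they are unset in later traces.
        for proxy in self.reads:
            if not proxy.saved:
                proxy.value = _UNSET
        return h
```

For example, with layers `h + 1`, `h * 2`, `h - 3` and input `5.0`, the clean run returns `9.0`. Patching layer 0 with a saved value of `2.0` from a previous trace returns `1.0`. Patching with an unsaved proxy from a previous trace, or with the output of layer 2 inside the same trace, raises `UnsetValueError` rather than quietly returning `9.0`.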