ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Non inplace model editing #206

Closed AdamBelfki3 closed 2 weeks ago

AdamBelfki3 commented 2 weeks ago

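Improvement! Edits defined within an edit context don't affect the base model by default, i.e. they are not applied in-place. The context manager yields a separate edited model (gpt2_edited in the example), and the base model (gpt2) is left unchanged: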

from nnsight import LanguageModel

gpt2 = LanguageModel("openai-community/gpt2", device_map="auto")

with gpt2.edit("") as gpt2_edited:
    gpt2.transformer.h[1].output[0][:] = 0

with gpt2.trace("Hello World"):
    l1_out = gpt2.transformer.h[1].output[0].save()

with gpt2_edited.trace("Hello World"):
    l1_out_edited = gpt2_edited.transformer.h[1].output[0].save()

print("L1 - Out: ", l1_out)
print("L1 - Out (edited): ", l1_out_edited)
"""
L1 - Out: tensor([[[ 1.0967, -1.8792,  1.0121,  ..., -1.0442, -0.5490, -1.1239],
         [-0.2300,  0.5204,  0.4756,  ..., -1.6795, -0.5285,  0.4190]]],
       device='mps:0', grad_fn=<AddBackward0>)

L1 - Out (edited): tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='mps:0',
       grad_fn=<CopySlices>)
"""

If you wish to make in-place edits, simply set inplace=True:

with model.edit(inplace=True):
    pass
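
To make the difference between the two modes concrete, here is a minimal toy sketch in plain Python. It is not nnsight's implementation: ToyModel, its scale "weight", its edits list, and run() are all hypothetical names invented for illustration. The sketch only assumes what is described above, namely that the default edit context records edits on a separate object and that inplace=True records them on the base model itself.

```python
import copy
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a wrapped model (not nnsight's implementation)."""

    def __init__(self, scale=2.0):
        self.scale = scale  # pretend "weights"
        self.edits = []     # functions applied to the output on every run

    def run(self, x):
        out = x * self.scale
        for edit in self.edits:
            out = edit(out)
        return out

    @contextmanager
    def edit(self, inplace=False):
        # Default: record edits on a shallow copy so the base model is untouched.
        target = self if inplace else copy.copy(self)
        if not inplace:
            # Give the copy its own edit list; the "weights" stay shared.
            target.edits = list(self.edits)
        yield target


base = ToyModel()

# Non-in-place (default): the edit lives only on `edited`.
with base.edit() as edited:
    edited.edits.append(lambda out: 0.0)  # zero the output, like `[:] = 0`

print(base.run(3.0))    # 6.0, base model unaffected
print(edited.run(3.0))  # 0.0, edited model applies the edit

# In-place: the context yields the base model itself.
with base.edit(inplace=True) as same:
    same.edits.append(lambda out: 0.0)

print(same is base)   # True
print(base.run(3.0))  # 0.0, base model now carries the edit
```

In this toy, the non-in-place path costs only a shallow copy, so the edited object shares its underlying "weights" with the base model. Whether nnsight does the same internally is not stated here.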

New feature! To access the Tracer context responsible for the edits, set return_context=True:

with model.edit(return_context=True) as (edited_model, editor):
    pass

Note that the return_context flag is only relevant when creating an edit context.

New feature! You can now clear edits from a model with clear_edits(). This is particularly useful after running in-place edits on a model.

from nnsight import LanguageModel

gpt2 = LanguageModel("openai-community/gpt2", device_map="auto")

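# Rebinding `gpt2` here makes the name refer to the edited model from now on.
with gpt2.edit() as gpt2:
    gpt2.transformer.h[1].output[0][:] = 0

# Remove the edits recorded above.
gpt2.clear_edits()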

with gpt2.trace("Hello World"):
    l1_out_unedited = gpt2.transformer.h[1].output[0].save()

print("L1 - Out (unedited): ", l1_out_unedited)
"""
L1 - Out (unedited): tensor([[[ 1.0967, -1.8792,  1.0121,  ..., -1.0442, -0.5490, -1.1239],
         [-0.2300,  0.5204,  0.4756,  ..., -1.6795, -0.5285,  0.4190]]],
       device='mps:0', grad_fn=<AddBackward0>)
"""

mitroitskii commented 2 weeks ago

@AdamBelfki3 does the editor object expose any methods besides .log()? .exit() does not seem to work.