ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Envoy.iter! #274

Closed JadenFiotto-Kaufman closed 6 days ago

JadenFiotto-Kaufman commented 3 weeks ago

New Feature ! @AdamBelfki3 @cadentj

New paradigm to specify module iterations.

Here I specify that an intervention should apply to every iteration, using a global .all():

import torch
from nnsight import LanguageModel
from nnsight import list  # nnsight's list for use inside the tracing context; shadows the built-in list

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto", torch_dtype=torch.bfloat16)

with model.generate("hello world", max_new_tokens=5):
    values = list().save()

    model.all()

    values.append(model.lm_head.output)

print(len(values.value)) # Prints 5

To apply only some interventions on every iteration, use .all() as a context manager:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.all():

        values.append(model.lm_head.output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 5
print(len(other_values.value)) # Prints 1
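
The difference between the two lists above can be pictured with a toy model in plain Python. This is only a sketch and not how nnsight is implemented: ToyTracer, record, and run are hypothetical names, and the assumption that an intervention outside .all() fires only on the first iteration is inferred from the "Prints 1" comment rather than confirmed by the library.

```python
from contextlib import contextmanager


class ToyTracer:
    """Hypothetical stand-in for a generation loop; not part of nnsight."""

    def __init__(self, steps):
        self.steps = steps
        self._active = [0]   # assumption: default scope is the first iteration only
        self._pending = []   # recorded (iterations, callback) pairs

    @contextmanager
    def all(self):
        # Widen the scope to every iteration, then restore the previous scope.
        previous = self._active
        self._active = list(range(self.steps))
        try:
            yield
        finally:
            self._active = previous

    def record(self, callback):
        # Remember which iterations this intervention should run on.
        self._pending.append((list(self._active), callback))

    def run(self):
        # Replay the recorded interventions once per generation step.
        for step in range(self.steps):
            for iterations, callback in self._pending:
                if step in iterations:
                    callback(step)


tracer = ToyTracer(steps=5)
values, other_values = [], []

with tracer.all():
    tracer.record(values.append)      # recorded for all 5 iterations

tracer.record(other_values.append)    # recorded for the default iteration only

tracer.run()
print(len(values))        # 5
print(len(other_values))  # 1
```

In this toy, the context manager only changes which iterations are attached to interventions recorded inside it, which is why code after the with block falls back to the default scope.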

.all() is an alias for .iter[:]. That's right: you can select a single iteration with an int, several iterations with a list of ints, or a range of iterations with a slice:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.iter[2:4]:

        values.append(model.lm_head.output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.iter[[0,1,4]]:

        values.append(model.lm_head.output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 3
print(len(other_values.value)) # Prints 1
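
As a rough illustration of how the three kinds of index (int, list of ints, slice) map to iteration numbers, here is a plain-Python sketch. resolve_iterations is a hypothetical helper written for this example, not an nnsight function, and how nnsight handles out-of-range indices is not covered by this PR description; the sketch simply ignores them.

```python
from typing import List, Union

IterSpec = Union[int, List[int], slice]


def resolve_iterations(spec: IterSpec, total: int) -> List[int]:
    """Return the iteration indices in range(total) selected by spec (hypothetical helper)."""
    if isinstance(spec, slice):
        # A slice selects a contiguous (or strided) range of iterations.
        return list(range(total))[spec]
    if isinstance(spec, int):
        # A single int selects exactly one iteration, if it is in range.
        return [spec] if 0 <= spec < total else []
    # A list of ints selects each listed iteration that is in range.
    return [i for i in spec if 0 <= i < total]


every = resolve_iterations(slice(None), 5)      # like .iter[:] / .all()
window = resolve_iterations(slice(2, 4), 5)     # like .iter[2:4]
picked = resolve_iterations([0, 1, 4], 5)       # like .iter[[0,1,4]]
single = resolve_iterations(2, 5)               # like .iter[2]

print(len(every), len(window), len(picked), len(single))  # 5 2 3 1
```

The lengths line up with the "Prints 5", "Prints 2", and "Prints 3" comments in the examples above for five generated tokens.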

This also works inline:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    values.append(model.lm_head.iter[2:4].output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1

The same usage shown for .all() also applies to .next().