Here I specify that an intervention should apply to all iterations by calling .all() globally:
from nnsight import LanguageModel
from nnsight import list  # nnsight's list, usable inside the trace (shadows the built-in list)
import torch

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto", torch_dtype=torch.bfloat16)

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    model.all()  # interventions defined after this apply to every iteration
    values.append(model.lm_head.output)

print(len(values.value))  # Prints 5
If I only want some of my interventions to apply on every iteration, I can use .all() as a context manager instead:
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()
    with model.all():
        # Inside the block: applied on every iteration
        values.append(model.lm_head.output)
    # Outside the block: applied only once
    other_values.append(model.lm_head.output)

print(len(values.value))  # Prints 5
print(len(other_values.value))  # Prints 1
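To make the scoping difference between the two usages concrete, here is a toy, pure-Python model of it. Everything in it (`ToyGenerator`, `_AllScope`, `add`, `run`) is hypothetical and is not how nnsight is implemented; it only assumes what the examples above show: after a bare `.all()` call every later intervention repeats, inside a `with .all():` block only the enclosed interventions repeat, and (an assumption made for this toy) an intervention outside that scope fires on the first step only.

```python
class _AllScope:
    """Hypothetical: returned by ToyGenerator.all(); switches on 'every step' mode."""

    def __init__(self, gen):
        self.gen = gen
        self.previous = gen.every_step
        gen.every_step = True  # takes effect immediately, even without `with`

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.gen.every_step = self.previous  # only a `with` block restores the mode
        return False


class ToyGenerator:
    """Hypothetical stand-in for model.generate(); this is NOT nnsight."""

    def __init__(self, max_new_tokens):
        self.max_new_tokens = max_new_tokens
        self.every_step = False
        self.interventions = []  # list of (callback, runs_on_every_step)

    def all(self):
        return _AllScope(self)

    def add(self, callback):
        # Record the mode that is active when the intervention is defined.
        self.interventions.append((callback, self.every_step))

    def run(self):
        for step in range(self.max_new_tokens):
            output = f"lm_head output at step {step}"  # placeholder value
            for callback, every_step in self.interventions:
                # Toy assumption: non-repeating interventions fire on step 0 only.
                if every_step or step == 0:
                    callback(output)


# Context-manager use: only the intervention inside the block repeats.
gen = ToyGenerator(max_new_tokens=5)
values, other_values = [], []
with gen.all():
    gen.add(values.append)
gen.add(other_values.append)
gen.run()
print(len(values), len(other_values))  # 5 1

# Global use: everything added after .all() repeats.
gen = ToyGenerator(max_new_tokens=5)
global_values = []
gen.all()
gen.add(global_values.append)
gen.run()
print(len(global_values))  # 5
```

The design choice worth noting is that the mode is captured when each intervention is defined, which is why one object can serve both as a global switch and as a scoped `with` block.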
.all() is an alias for .iter[:]. Yes, that's right: you can select a specific iteration with an int, multiple iterations with a list of ints, or a range with a slice:
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()
    with model.iter[2:4]:  # iterations 2 and 3 (the slice end is exclusive)
        values.append(model.lm_head.output)
    other_values.append(model.lm_head.output)

print(len(values.value))  # Prints 2
print(len(other_values.value))  # Prints 1

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()
    with model.iter[[0, 1, 4]]:  # iterations 0, 1 and 4
        values.append(model.lm_head.output)
    other_values.append(model.lm_head.output)

print(len(values.value))  # Prints 3
print(len(other_values.value))  # Prints 1
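The counts printed above (2 for `iter[2:4]`, 3 for `iter[[0, 1, 4]]`) follow ordinary Python indexing rules. As a rough sketch of those semantics only, assuming iterations are numbered from 0 as in the examples, the hypothetical `selected_steps` helper and `IterSelector` class below (made-up names, not nnsight's implementation) map each selector form to the generation steps it would cover:

```python
def selected_steps(selector, num_steps):
    """Hypothetical helper: the generation steps an iter[...] selector covers.

    int   -> one step
    slice -> a range, with Python's half-open (end-exclusive) semantics
    list  -> each listed step
    """
    steps = range(num_steps)
    if isinstance(selector, int):
        return [steps[selector]]  # raises IndexError if out of range, like a range
    if isinstance(selector, slice):
        return list(steps[selector])
    return [steps[i] for i in selector]


class IterSelector:
    """Mimics the `iter[...]` syntax: Python hands __getitem__ the raw selector."""

    def __init__(self, num_steps):
        self.num_steps = num_steps

    def __getitem__(self, selector):
        return selected_steps(selector, self.num_steps)


it = IterSelector(num_steps=5)  # like max_new_tokens=5
print(it[:])          # [0, 1, 2, 3, 4]  (what .all() covers)
print(it[2:4])        # [2, 3]
print(it[[0, 1, 4]])  # [0, 1, 4]
print(it[3])          # [3]
```

Because `iter[2:4]` arrives at `__getitem__` as `slice(2, 4)` and `iter[[0, 1, 4]]` as a plain list, a single indexing method is enough to support all three selector forms.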
This also works inline:
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()
    # Select iterations 2 and 3 for this single module access
    values.append(model.lm_head.iter[2:4].output)
    other_values.append(model.lm_head.output)

print(len(values.value))  # Prints 2
print(len(other_values.value))  # Prints 1
New feature! @AdamBelfki3 @cadentj
A new paradigm for specifying module iterations.
Everything described above for .all() also applies to .next().