ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

`lm.input.save()` throws error in the latest version `0.3.6` #276

Open arnab-api opened 2 days ago

arnab-api commented 2 days ago

The code below works perfectly if I just remove `lm.input.save()`; I can still grab the input to a submodule directly. I am currently on nnsight version 0.3.6. This code was working fine on a previous version (I forget which version I was on before).

```python
import torch
from nnsight import LanguageModel

lm = LanguageModel(
    model_key="EleutherAI/pythia-410m",
    device_map="auto",
    dispatch=True,
)

prompt = "A quick brown fox jumps over the lazy"
tokens = prepare_input(prompt, tokenizer=lm.tokenizer)

with torch.inference_mode():
    with lm.trace(tokens) as tr:
        resid_out = lm.gpt_neox.layers[3].output[0].save()
        input = lm.input.save()  # <--- grabbing the inputs doesn't work anymore.
        resid_in = lm.gpt_neox.layers[3].input.save()  # <--- works fine!
```

I am getting the error below.

```
--------------------------------------------------------------------------
IndexError: Above exception when execution Node: 'getitem_1' in Graph: '139994817901520'
```
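Until the regression in `lm.input` is fixed, the same inputs can be captured outside of nnsight with a plain PyTorch forward pre-hook. This is only a minimal sketch on a toy module (not Pythia, and not the nnsight API); the hook mechanism is identical for any `nn.Module` submodule:

```python
import torch
from torch import nn

# Toy stand-in for a transformer stack; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = {}

def grab_input(module, args):
    # args is the tuple of positional inputs passed to the module's forward().
    captured["resid_in"] = args[0].detach()

# Attach to the submodule whose input we want (here: the second Linear).
handle = model[2].register_forward_pre_hook(grab_input)

x = torch.randn(3, 4)
with torch.inference_mode():
    _ = model(x)

handle.remove()
print(tuple(captured["resid_in"].shape))  # input to model[2] is (batch, 8)
```

The hook fires before the submodule's `forward()` runs, so `captured["resid_in"]` holds exactly the tensor the submodule received, analogous to what `.input.save()` returns inside a trace.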
AdamBelfki3 commented 2 days ago

Which `prepare_input(...)` function are you using?