ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Cannot use proxy values from previous runs with `remote=True` #104

Closed (tvhong closed this issue 7 months ago)

tvhong commented 7 months ago

Description

I noticed a discrepancy between `remote=False` and `remote=True` when reusing a proxy saved in a previous run.

With `remote=False`, I can reuse a proxy saved in one tracer/generator context inside a later one. With `remote=True`, the same code does not work: the second run behaves as if the saved value had never been applied.

Expected Behavior

The behavior should be the same whether `remote` is `True` or `False`.

Reproduction Steps

```python
from nnsight import CONFIG, LanguageModel

CONFIG.set_default_api_key("<your-api-key>")
model = LanguageModel('openai-community/gpt2-xl')

def run(remote: bool):
    # First run: save the token embeddings and the generated token ids.
    with model.generate("a dog is a dog, a cat is", max_new_tokens=4, remote=remote):
        embedding = model.transformer.wte.output.save()
        output1 = model.generator.output.save()

    print(embedding.shape)
    print("All token ids: ", output1)
    print("All prediction: ", model.tokenizer.batch_decode(output1))

    # Build a placeholder prompt ("_ _ _ ...") with one underscore per token
    # of the original prompt, so the saved embedding can take its place.
    tokens_cnt = embedding.shape[1]
    stub_prompt = " ".join("_" * tokens_cnt)

    # Second run: overwrite the embedding output with the proxy saved above.
    with model.generate(stub_prompt, max_new_tokens=4, remote=remote):
        model.transformer.wte.output = embedding
        output2 = model.generator.output.save()

    print("All token ids: ", output2)
    print("All prediction: ", model.tokenizer.batch_decode(output2))
```
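The stub prompt in the reproduction relies on two facts that can be checked without loading a model. First, `" ".join("_" * tokens_cnt)` yields `tokens_cnt` underscores separated by spaces. Second, as the `run(False)` output shows, GPT-2 encodes that string as exactly one id per underscore (62 for the leading `_`, 4808 for each following ` _`). The sketch below checks both against the ids copied from this issue. `make_stub_prompt` is a hypothetical helper name, and no tokenizer or nnsight call is made.

```python
def make_stub_prompt(tokens_cnt: int) -> str:
    # "_" * n is a string of n underscores; iterating over it yields each
    # underscore, so joining with spaces gives "_ _ ... _" (n underscores).
    return " ".join("_" * tokens_cnt)


# Sequence length of the saved embedding, from torch.Size([1, 9, 1600]).
TOKENS_CNT = 9
stub_prompt = make_stub_prompt(TOKENS_CNT)
print(repr(stub_prompt))  # '_ _ _ _ _ _ _ _ _'

# output2 from run(False), copied from this issue. The first TOKENS_CNT ids
# are the stub prompt; the remaining ids are the 4 generated tokens.
output2_local = [62, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 257, 3797, 11, 290]
prompt_ids = output2_local[:TOKENS_CNT]

# One id per underscore: 62 for the leading "_", 4808 for each " _".
assert prompt_ids == [62] + [4808] * (TOKENS_CNT - 1)
assert len(stub_prompt.split(" ")) == TOKENS_CNT
print("stub prompt has", len(prompt_ids), "tokens, matching the saved embedding")
```

If the stub tokenized to a different length, the saved `[1, 9, 1600]` tensor would not match the shape of the `wte` output for the stub prompt.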

run(False) output:

```
torch.Size([1, 9, 1600])
All token ids:  tensor([[  64, 3290,  318,  257, 3290,   11,  257, 3797,  318,  257, 3797,   11,
          290]])
All prediction:  ['a dog is a dog, a cat is a cat, and']

All token ids:  tensor([[  62, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808,  257, 3797,   11,
          290]])
All prediction:  ['_ _ _ _ _ _ _ _ _ a cat, and']
```

run(True) output:

```
torch.Size([1, 9, 1600])
All token ids:  tensor([[  64, 3290,  318,  257, 3290,   11,  257, 3797,  318,  257, 3797,   11,
          290]])
All prediction:  ['a dog is a dog, a cat is a cat, and']

All token ids:  tensor([[  62, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808,
         4808]])
All prediction:  ['_ _ _ _ _ _ _ _ _ _ _ _ _']
```
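One way to make the discrepancy concrete is to compare only the generated continuations. Assuming greedy decoding (which the identical continuations of the first run and the local second run suggest), a successfully injected embedding should make the stub-prompt run continue exactly like the original prompt. The sketch below applies that check to the token ids copied from the two outputs above. The helper names are hypothetical, and nothing here calls nnsight.

```python
# Token ids copied from the outputs in this issue.
FIRST_RUN = [64, 3290, 318, 257, 3290, 11, 257, 3797, 318, 257, 3797, 11, 290]
SECOND_RUN_LOCAL = [62] + [4808] * 8 + [257, 3797, 11, 290]   # run(False)
SECOND_RUN_REMOTE = [62] + [4808] * 12                       # run(True)

MAX_NEW_TOKENS = 4  # matches max_new_tokens=4 in the reproduction


def continuation(ids: list[int], n: int = MAX_NEW_TOKENS) -> list[int]:
    # The generator output is the prompt ids followed by the new tokens,
    # so the last n ids are the ones the model generated.
    return ids[-n:]


def injection_took_effect(reference: list[int], patched: list[int]) -> bool:
    # If the saved embedding replaced the stub's embedding, the patched run
    # should produce the same continuation as the original prompt.
    return continuation(reference) == continuation(patched)


print(continuation(FIRST_RUN))                               # [257, 3797, 11, 290]
print(injection_took_effect(FIRST_RUN, SECOND_RUN_LOCAL))    # True
print(injection_took_effect(FIRST_RUN, SECOND_RUN_REMOTE))   # False
```

In the remote run, the continuation is four more `4808` (` _`) tokens. In other words, the model simply continued the stub prompt.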
tvhong commented 7 months ago

I attempted a fix in https://github.com/ndif-team/nnsight/pull/105. Please let me know whether the approach seems reasonable.