ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Make tensor creation functions like torch.zeros traceable via nnsight #140

Closed by JadenFiotto-Kaufman 3 weeks ago

Butanium commented 3 months ago

Isn't that already supported? I've run remote code with torch.arange and it worked, IIRC.

JadenFiotto-Kaufman commented 3 months ago

@Butanium Take this example:

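```python
from nnsight import LanguageModel
import torch

# device_map="auto" lets accelerate place the model across the available devices
nn_model = LanguageModel("EleutherAI/gpt-j-6b", device_map="auto")

with nn_model.generate("hello", remote=False, scan=False, validate=False) as tracer:
    # shape is a proxy, so its elements are proxies rather than ints
    shape = nn_model.transformer.h[0].mlp.output.shape

    # Fails: torch.zeros does not dispatch on proxies inside its size argument
    new_tensor = torch.zeros(shape)
```

```
    new_tensor = torch.zeros(shape)
TypeError: zeros(): argument 'size' (position 1) must be tuple of ints, but found element of type LanguageModelProxy at pos 0
```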

The nnsight proxy relies on each torch function checking for __torch_function__ on the nnsight.tracing.Proxy object, but some torch functions don't check for it. I think the rule is that the check only happens for arguments that are expected to be Tensors. torch.zeros expects its shape as a tuple of ints, not a Tensor, so the Proxy is rejected before __torch_function__ is ever consulted. torch.zeros_like(<Proxy>), whose input is a Tensor argument, works just fine.
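To illustrate that dispatch rule, here is a minimal sketch using a hypothetical `ToyProxy` class (not nnsight's Proxy) that implements `__torch_function__` and returns a string instead of computing anything. Assuming current PyTorch argument parsing, `torch.zeros_like` routes the call to the class because its input is a Tensor argument, while `torch.zeros` rejects the object inside its integer size tuple with the same kind of `TypeError` as in the traceback above.

```python
import torch


class ToyProxy:
    """Hypothetical stand-in for a tracing proxy; it is not a torch.Tensor."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Record the call by name instead of executing it
        return f"deferred {func.__name__}"


proxy = ToyProxy()

# zeros_like's input is a Tensor argument, so torch checks it for __torch_function__
print(torch.zeros_like(proxy))  # dispatched to ToyProxy.__torch_function__

# zeros expects its size as a tuple of ints, so the proxy is never consulted
try:
    torch.zeros((proxy,))
except TypeError as err:
    print("TypeError:", err)
```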

I just need to patch those functions with Proxy.proxy_wrapper versions of them. However, if I remember correctly, this interfered with accelerate, which also patches those functions temporarily.
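The patching idea can be sketched roughly as follows. This is a hypothetical illustration, not nnsight's `Proxy.proxy_wrapper`: `ToyProxy`, `proxy_aware`, and `patched` are names invented for this example, and a "deferred" call is represented by a plain tuple rather than a node in a trace graph.

```python
import contextlib

import torch


class ToyProxy:
    """Hypothetical placeholder for a traced value (not nnsight's real Proxy)."""

    def __init__(self, name):
        self.name = name


def proxy_aware(fn):
    """Wrap a torch function so that calls involving a ToyProxy are deferred."""

    def wrapper(*args, **kwargs):
        candidates = list(args) + list(kwargs.values())
        # Look one level into tuples/lists, since sizes are often passed as a tuple
        for arg in list(candidates):
            if isinstance(arg, (tuple, list)):
                candidates.extend(arg)
        if any(isinstance(c, ToyProxy) for c in candidates):
            # A real tracer would record a graph node here; we only describe the call
            return ("deferred", fn.__name__, args, kwargs)
        # No proxies involved: run the original torch function eagerly
        return fn(*args, **kwargs)

    return wrapper


@contextlib.contextmanager
def patched(module, name):
    """Temporarily replace module.<name> with its proxy-aware version."""
    original = getattr(module, name)
    setattr(module, name, proxy_aware(original))
    try:
        yield
    finally:
        # Restore the function that was installed on entry, even if the body raised
        setattr(module, name, original)


with patched(torch, "zeros"):
    deferred = torch.zeros((ToyProxy("mlp_shape"), 4))
    eager = torch.zeros(2, 3)

print(deferred[:2])  # ('deferred', 'zeros')
print(eager.shape)   # torch.Size([2, 3])
```

Because `patched` restores whatever was installed on entry, two libraries that temporarily patch the same function only compose cleanly if their patches are exited in the reverse order they were entered; interleaving them can leave the wrong function in place, which is one plausible way such patches could clash.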

JadenFiotto-Kaufman commented 3 weeks ago

@JadenFiotto-Kaufman This works in the latest release.