ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Fix disconnect between Iterable values and their Node dependencies #207

AdamBelfki3 closed this 2 weeks ago

AdamBelfki3 commented 2 weeks ago

Motivation

When Nodes are added to an Iterable data type (such as a list), the creation of that Iterable cannot be traced as an operation on the InterventionProxy. This results in a mismatch between the arguments and the dependencies of the Node Proxy in question.

This can be illustrated with the following example from the previous implementation, in which we build a list from the output of a layer.

```python
from collections import OrderedDict

import torch

import nnsight
from nnsight import NNsight

input_size = 5
hidden_dims = 10
output_size = 2

torch.manual_seed(423)

# A small three-layer MLP with named submodules
net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, hidden_dims)),
            ("layer3", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

input = torch.rand((1, input_size))
input_2 = torch.rand((1, input_size))

model = NNsight(net)

with model.trace(input) as tracer:
    # Wrap layer1's output in a traced list and save it
    l = nnsight.list([model.layer1.output])
    l = l.save()

    # Render the intervention graph built so far
    tracer.vis(title="Tracer")
```

[Figure: "Tracer" graph rendered by tracer.vis(), previous implementation]

In this graph, the list value argument has no link to the InterventionProtocol node; it only holds the necessary reference to the Node inside it.
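To make the gap concrete, here is a toy model of the problem. It assumes, purely for illustration, that dependency edges are drawn only for Nodes that appear directly in another Node's argument list. The `Node` class and the `direct_dependencies` helper are hypothetical and are not nnsight's actual API. A Node wrapped in a plain Python list sits one level too deep for such a scan, so its dependency is silently dropped.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a traced graph node (not nnsight's real class)."""
    name: str
    args: list[Any] = field(default_factory=list)


def direct_dependencies(node: Node) -> list[str]:
    """Naive dependency scan: only inspects Nodes passed directly as arguments."""
    return [arg.name for arg in node.args if isinstance(arg, Node)]


layer1_output = Node("layer1_output")

# Passed directly, the dependency is visible.
apply_fn = Node("apply_fn", args=[layer1_output])
print(direct_dependencies(apply_fn))  # prints ['layer1_output']

# Wrapped in an ordinary list, the Node is nested one level below the
# argument being scanned, so the dependency is missed.
make_list = Node("list", args=[[layer1_output]])
print(direct_dependencies(make_list))  # prints []
```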

So how do we solve this?

Implementation

The cleanest way to visually capture the dependency relationship in the case described above appears to be creating an edge between the Node and the Iterable structure that references it, which in turn shows its dependency relationship with its listener node.

Therefore, we loop over all the Node elements of an Iterable and visualize them recursively, then create a gray dashed edge between each one and the Iterable structure directly. The end result looks like this:

[Figure: "Tracer" graph rendered by tracer.vis(), new implementation]
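The recursive walk over an Iterable's Node elements can be sketched in plain Python that emits Graphviz DOT text. This is a minimal sketch, not nnsight's implementation: the `Node` class, the `iterable_edges` and `to_dot` functions, and the `<name>_arg<i>` vertex naming are all hypothetical. It assumes each iterable argument gets its own vertex linked to the listener node, and that every Node found at any nesting depth gets a gray dashed edge to that outermost container vertex.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a traced graph node (not nnsight's real class)."""
    name: str
    args: list[Any] = field(default_factory=list)


def iterable_edges(value: Any, container_id: str, edges: list[str]) -> None:
    """Recursively find Nodes inside an iterable argument and record a gray
    dashed edge from each one to the vertex representing the container."""
    if isinstance(value, Node):
        edges.append(f'"{value.name}" -> "{container_id}" [style=dashed, color=gray];')
    elif isinstance(value, dict):
        for item in value.values():
            iterable_edges(item, container_id, edges)
    elif isinstance(value, (list, tuple, set)):
        for item in value:
            iterable_edges(item, container_id, edges)


def to_dot(node: Node) -> str:
    """Build a DOT digraph showing how `node` depends on its arguments."""
    edges: list[str] = []
    for i, arg in enumerate(node.args):
        if isinstance(arg, Node):
            # Direct dependency: ordinary solid edge to the listener.
            edges.append(f'"{arg.name}" -> "{node.name}";')
        elif isinstance(arg, (list, tuple, set, dict)):
            # Give the iterable its own vertex, link it to the listener,
            # then walk its contents for nested Nodes.
            container_id = f"{node.name}_arg{i}"
            edges.append(f'"{container_id}" -> "{node.name}";')
            iterable_edges(arg, container_id, edges)
    return "digraph G {\n  " + "\n  ".join(edges) + "\n}"


layer1_output = Node("layer1_output")
make_list = Node("list", args=[[layer1_output]])
print(to_dot(make_list))
```

For the list wrapping `layer1_output`, this yields a solid edge from the container vertex `list_arg0` to `list` and a gray dashed edge from `layer1_output` to `list_arg0`. Attaching nested Nodes to the outermost container, rather than to each inner iterable, keeps the graph compact; one vertex per nested iterable would be an equally reasonable choice.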