Motivation
When a Node is added to an Iterable data type (such as a list), the creation of that Iterable cannot be traced as an operation on the InterventionProxy. This results in a mismatch between the arguments and the dependencies of the Node Proxy in question.
The following example from the previous implementation illustrates this. It tries to make a list from the output of a layer:
```python
from collections import OrderedDict

import torch

import nnsight
from nnsight import NNsight

input_size = 5
hidden_dims = 10
output_size = 2

torch.manual_seed(423)

# A small three-layer MLP with named submodules
net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, hidden_dims)),
            ("layer3", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

input = torch.rand((1, input_size))
input_2 = torch.rand((1, input_size))

model = NNsight(net)

with model.trace(input) as tracer:
    # Build a traced list whose only element is the output of layer1
    l = nnsight.list([model.layer1.output])
    l = l.save()
    # Render the intervention graph
    tracer.vis(title="Tracer")
```
Here, the list's value argument has no link to the InterventionProtocol; it only holds the necessary reference to the Node inside it.
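To make the mismatch concrete, here is a minimal sketch in plain Python. It assumes a hypothetical, simplified `Node` class (not nnsight's) and a hypothetical helper `direct_dependencies` that collects dependencies only from top-level arguments that are themselves Nodes.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Node:
    """Hypothetical, simplified stand-in for a graph node (not nnsight's class)."""
    name: str
    args: List[Any] = field(default_factory=list)


def direct_dependencies(node: Node) -> List[str]:
    # Only top-level arguments that are Nodes are counted as dependencies.
    return [arg.name for arg in node.args if isinstance(arg, Node)]


layer1_output = Node("layer1.output")
# The list node receives a Python list that *contains* the Node,
# so the Node sits one level below the top-level arguments.
list_node = Node("list", args=[[layer1_output]])

print(direct_dependencies(list_node))  # []
```

Because the Node is nested inside the list, a dependency check that only inspects top-level arguments finds nothing, even though the list clearly references the Node.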
So how do we solve this?
Implementation
The cleanest way to visually capture the dependency relationship in this case appears to be drawing an edge between the Node and the Iterable structure that references it, which in turn shows the dependency relationship with its listener node.
To do this, we loop over all the Node elements of an Iterable and visualize them recursively. Then we connect each Node directly to the Iterable structure with a gray dashed edge. The end result looks like this:
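The traversal described above can be sketched in plain Python as follows. Everything here is a hypothetical illustration, not nnsight's code: the simplified `Node` class, the helpers `iterable_edges` and `graph_edges`, and the `"<node>.arg<i>"` naming scheme for the Iterable's vertex are all assumptions. The `color`/`style` keys mirror Graphviz edge attributes, but nothing here calls Graphviz.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class Node:
    """Hypothetical, simplified stand-in for a graph node (not nnsight's class)."""
    name: str
    args: List[Any] = field(default_factory=list)


# (source, target, attributes); attribute names mirror Graphviz edge attributes.
Edge = Tuple[str, str, Dict[str, str]]

DASHED_GRAY = {"color": "gray", "style": "dashed"}


def iterable_edges(value: Any, iterable_id: str) -> List[Edge]:
    """Recursively find Nodes inside `value` and link each one to the Iterable vertex."""
    edges: List[Edge] = []
    if isinstance(value, Node):
        # A Node referenced from inside the Iterable: gray dashed edge.
        edges.append((value.name, iterable_id, dict(DASHED_GRAY)))
    elif isinstance(value, dict):
        for item in value.values():
            edges.extend(iterable_edges(item, iterable_id))
    elif isinstance(value, (list, tuple)):
        for item in value:
            edges.extend(iterable_edges(item, iterable_id))
    # Any other value (int, str, tensor, ...) contributes no edge.
    return edges


def graph_edges(node: Node) -> List[Edge]:
    """Collect incoming edges for `node`, including Nodes nested in Iterable args."""
    edges: List[Edge] = []
    for i, arg in enumerate(node.args):
        if isinstance(arg, Node):
            # Direct Node argument: ordinary solid edge to the listener node.
            edges.append((arg.name, node.name, {}))
        elif isinstance(arg, (list, tuple, dict)):
            # Give the Iterable its own vertex, connect it to the listener node,
            # then link every Node found inside it to that vertex.
            iterable_id = f"{node.name}.arg{i}"
            edges.append((iterable_id, node.name, {}))
            edges.extend(iterable_edges(arg, iterable_id))
    return edges


layer1_output = Node("layer1.output")
list_node = Node("list", args=[[layer1_output]])

for edge in graph_edges(list_node):
    print(edge)
# ('list.arg0', 'list', {})
# ('layer1.output', 'list.arg0', {'color': 'gray', 'style': 'dashed'})
```

Giving the Iterable its own vertex keeps a single solid edge into the listener node, while the dashed edges make each nested Node reference visible; a real renderer would pass these attributes on to its graph library.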