ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Unique Bridge Nodes #188

Closed AdamBelfki3 closed 1 month ago

AdamBelfki3 commented 1 month ago

Adding functionality for optimizing bridged Intervention Graphs.

Motivation

Previously, a BridgeProtocol node was created at each reference to an external proxy within the same graph. Ideally, all operations on the same proxy should be linked to a single corresponding BridgeProtocol node.

Example

from collections import OrderedDict
from nnsight import NNsight
import torch

input_size = 5
hidden_dims = 10
output_size = 2

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

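# Wrap the module so that its internals can be traced.
model = NNsight(net)

input = torch.rand((1, input_size))

with model.session() as session:
    with model.trace(input) as tracer:
        l1_out = model.layer1.output

    with model.trace(input) as tracer_2:
        # Both operations reference l1_out, a proxy from the first trace's graph.
        l1_out_double = l1_out * 2

        l1_out_sum = torch.sum(l1_out)

        tracer_2.graph.vis("Tracer_2")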

[Image: Tracer_2 graph visualization]

Implementation

The Bridge structure tracks the bridge nodes created for a given proxy in every context from which it was externally referenced.

Preprocessing a Node's arguments during instantiation no longer automatically creates a BridgeProtocol node when an argument is an external reference. Instead, the BridgeProtocol bridges the node directly (creating a local reference): it first checks whether the external node has already been bridged, and creates a new BridgeProtocol node only if it has not.
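The check-before-create behaviour described above can be illustrated with a small standalone sketch. It does not use nnsight's internals: `BridgeNode`, `BridgeRegistry`, and the `(graph_id, external_name)` key are hypothetical names chosen for illustration, and the actual Bridge and BridgeProtocol code may be organised differently.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class BridgeNode:
    """Hypothetical stand-in for a local BridgeProtocol node."""

    external_name: str  # name of the node in the graph it originates from
    graph_id: int       # id of the graph holding this local reference


@dataclass
class BridgeRegistry:
    """Hypothetical cache mapping (local graph, external node) to one bridge node."""

    _bridged: Dict[Tuple[int, str], BridgeNode] = field(default_factory=dict)
    created: int = 0  # number of bridge nodes actually created

    def bridge(self, graph_id: int, external_name: str) -> BridgeNode:
        key = (graph_id, external_name)
        node = self._bridged.get(key)
        if node is None:
            # First external reference from this graph: create the bridge node.
            node = BridgeNode(external_name, graph_id)
            self._bridged[key] = node
            self.created += 1
        # Later references from the same graph reuse the existing node.
        return node


registry = BridgeRegistry()
first = registry.bridge(graph_id=1, external_name="l1_out")   # e.g. l1_out * 2
second = registry.bridge(graph_id=1, external_name="l1_out")  # e.g. torch.sum(l1_out)
other = registry.bridge(graph_id=2, external_name="l1_out")   # a different graph
```

In this sketch, the two references made from graph 1 share one bridge node, while a reference from a different graph gets its own.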

When the UpdateProtocol is called on an external proxy node, the value of its local BridgeProtocol reference is also updated to reflect the change.
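As a rough illustration of that propagation, the sketch below keeps a list of local references on each node and pushes new values to them. `SketchNode`, `add_bridge`, and `update` are hypothetical names invented for this example and are not part of nnsight's API; how the library actually locates the local reference is not shown in this description.

```python
from typing import Any, List


class SketchNode:
    """Hypothetical value-holding node; not an nnsight class."""

    def __init__(self, name: str, value: Any = None) -> None:
        self.name = name
        self.value = value
        # Local references to this node that live in other graphs.
        self.bridges: List["SketchNode"] = []

    def add_bridge(self, graph_name: str) -> "SketchNode":
        # Create a local reference that starts with the current value.
        bridge = SketchNode(f"{graph_name}/bridge({self.name})", self.value)
        self.bridges.append(bridge)
        return bridge


def update(node: SketchNode, new_value: Any) -> None:
    """Set a new value on the node and mirror it onto its local references."""
    node.value = new_value
    for bridge in node.bridges:
        bridge.value = new_value


external = SketchNode("l1_out", value=1.0)
local_ref = external.add_bridge("tracer_2")
update(external, 3.0)
```

After the call to `update`, both `external.value` and `local_ref.value` hold the new value, so operations in the second graph that read through the local reference see the change.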

Results

[Image: Tracer_2 graph visualization]