Adding functionality for optimizing bridged Intervention Graphs.
Motivation
Previously, a BridgeProtocol node was created for every reference to an external proxy within the same graph. Ideally, all operations on the same external proxy should be linked to a single corresponding BridgeProtocol node.
Example
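```python
from collections import OrderedDict

import torch
from nnsight import NNsight

input_size = 5
hidden_dims = 10
output_size = 2

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

# Wrap the PyTorch module so it can be traced with nnsight.
model = NNsight(net)

input = torch.rand((1, input_size))

with model.session() as session:
    with model.trace(input) as tracer:
        l1_out = model.layer1.output

    with model.trace(input) as tracer_2:
        # Both operations reference l1_out, a proxy from the first trace's graph.
        # They should now share a single BridgeProtocol node in tracer_2's graph.
        l1_out_double = l1_out * 2
        l1_out_sum = torch.sum(l1_out)

        # Visualize tracer_2's intervention graph.
        tracer_2.graph.vis("Tracer_2")
```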
Implementation
The Bridge structure tracks, for each proxy, the bridge nodes created for it in every context from which it is externally referenced.
Preprocessing a Node's arguments during instantiation no longer automatically creates a BridgeProtocol node when an argument references an external node. Instead, the BridgeProtocol bridges the node directly (creating a local reference): it first checks whether the external node has already been bridged, and only creates a new BridgeProtocol node if it has not.
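The lookup-before-create behavior can be illustrated with a minimal, self-contained sketch. It is not nnsight's implementation: `Node`, `Graph`, `Bridge`, and `preprocess_args` are hypothetical names, and keying the registry on the external node's identity plus the referencing graph's id is an assumed design choice. It mirrors the example above, where `l1_out` is referenced twice from a second graph.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a graph node (not nnsight's Node class)."""
    name: str
    graph_id: int
    target: str  # e.g. "bridge", "mul", "sum"
    args: Tuple[Any, ...] = ()


@dataclass
class Graph:
    """Hypothetical graph that owns a flat list of nodes."""
    graph_id: int
    nodes: List[Node] = field(default_factory=list)

    def add(self, target: str, args: Tuple[Any, ...] = ()) -> Node:
        node = Node(f"{target}_{len(self.nodes)}", self.graph_id, target, args)
        self.nodes.append(node)
        return node


class Bridge:
    """Hypothetical registry: for each external node, the bridge node created
    in each graph (context) that references it."""

    def __init__(self) -> None:
        # (id of external node, referencing graph id) -> local bridge node
        self.bridged: Dict[Tuple[int, int], Node] = {}

    def bridge(self, external: Node, graph: Graph) -> Node:
        key = (id(external), graph.graph_id)
        # Only create a bridge node if this external node has not yet
        # been bridged into this graph; otherwise reuse the existing one.
        if key not in self.bridged:
            self.bridged[key] = graph.add("bridge", (external,))
        return self.bridged[key]


def preprocess_args(args: Tuple[Any, ...], graph: Graph, bridge: Bridge) -> Tuple[Any, ...]:
    # Replace each argument that belongs to another graph with its shared local bridge node.
    return tuple(
        bridge.bridge(a, graph) if isinstance(a, Node) and a.graph_id != graph.graph_id else a
        for a in args
    )


# Usage: reference one external node twice from a second graph.
bridge = Bridge()
g1, g2 = Graph(1), Graph(2)
l1_out = g1.add("layer1.output")
double = g2.add("mul", preprocess_args((l1_out, 2), g2, bridge))
total = g2.add("sum", preprocess_args((l1_out,), g2, bridge))

bridge_nodes = [n for n in g2.nodes if n.target == "bridge"]
print(len(bridge_nodes))  # 1
```

Both `mul` and `sum` point at the same bridge node, so `g2` contains one bridge node instead of one per reference.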
When the UpdateProtocol is applied to an external proxy node, the value of its local BridgeProtocol reference is updated as well, so the bridged copy reflects the change.
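A minimal sketch of the update propagation, under the assumption that each external node keeps a list of its local bridge references. `Node`, `bridge_to`, and `update` are hypothetical names chosen for illustration; they are not nnsight's UpdateProtocol or BridgeProtocol API.

```python
from typing import Any, List


class Node:
    """Hypothetical node holding a value (not nnsight's Node class)."""

    def __init__(self, name: str, value: Any = None) -> None:
        self.name = name
        self.value = value
        # Local bridge nodes in other graphs that mirror this node's value.
        self.bridge_refs: List["Node"] = []


def bridge_to(external: Node, local_name: str) -> Node:
    """Create a local reference to an external node and register it on that node."""
    local = Node(local_name, external.value)
    external.bridge_refs.append(local)
    return local


def update(node: Node, new_value: Any) -> None:
    """Hypothetical update step: set the node's value, then push the same
    value to every local bridge reference so they stay in sync."""
    node.value = new_value
    for ref in node.bridge_refs:
        ref.value = new_value


# Usage: updating the external node is visible through its local reference.
ext = Node("l1_out", 1.0)
local = bridge_to(ext, "bridge_0")
update(ext, 5.0)
print(local.value)  # 5.0
```

Pushing the value at update time keeps a bridged graph from reading a stale copy of the external value.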