ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Intervention Graph Visualization Summary #171

Closed: AdamBelfki3 closed this 3 months ago

AdamBelfki3 commented 4 months ago

Overview

New feature! Users can save graph visualizations of the interventions defined in their experiments, at any stage of defining them. This helps with debugging interventions and with creating visual summaries of experiments to share with other researchers.

Usage

To use this feature, call .vis(…) on a Tracer, Iterator, or Session context at any point to visualize all the operations added up to that point. The call accepts the following arguments:

"""
title (str): File name and displayed title of the graph. Defaults to "graph".
path (Optional[str]): Directory path to save the graphic in. Defaults to current directory.
recursive (bool): If True, recursively visualizes all inner sub intervention graphs into the same visual. Defaults to False.
"""

Here's how this feature can be used to summarize interventions at different context levels:

from collections import OrderedDict
from nnsight import NNsight
import torch

input_size = 5
hidden_dims = 10
output_size = 2

torch.manual_seed(423)

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size))
        ]
    )
).requires_grad_(False)

input = torch.rand((1, input_size))
input2 = torch.rand((1, input_size))

model = NNsight(net)

with model.session() as sesh:
    with model.trace(input) as tracer_1:
        l1_out_t1 = model.layer1.output.save()
        l1_out_t1_cp = model.layer1.output.clone().save()
        model.layer1.output.stop()

    with sesh.iter([0, 1, 2]) as (item, iterator):
        with model.trace(input2) as tracer_2:
            l1_out_t2 = model.layer1.output.save()

        l1_out_t1_cp[:, item] = l1_out_t2[:, item] + l1_out_t1[:, item]

    # Each call renders the operations recorded so far in that context.
    sesh.vis(title="session", path="intervention-graphs/visual-example")
    tracer_1.vis(title="tracer_1", path="intervention-graphs/visual-example")
    tracer_2.vis(title="tracer_2", path="intervention-graphs/visual-example")
    iterator.vis(title="iterator", path="intervention-graphs/visual-example")

# Saved values are populated once the session has exited and executed.
print("Tracer_1 L1_Out: ", l1_out_t1)
print("Tracer_1 L1_Out Copy: ", l1_out_t1_cp)
print("Tracer_2 L1_Out: ", l1_out_t2)

Visualizations

Rendered graphs (images not included): session, tracer_1, iterator, tracer_2.

Legend

| Node | Shape color |
|---|---|
| InterventionProtocol | $${\color{green}\textsf{green}}$$ |
| GradProtocol | $${\color{green}\textsf{green}}$$ |
| BridgeProtocol | $${\color{brown}\textsf{brown}}$$ |
| ValueProtocol | $${\color{blue}\textsf{blue}}$$ |
| Value | $${\color{grey}\textsf{grey}}$$ |
| SwapProtocol | $${\color{green}\textsf{green}}$$ |
| LockProtocol | $${\color{brown}\textsf{brown}}$$ |
| Operation | $${\color{black}\textsf{black}}$$ |
| ApplyModuleProtocol | $${\color{blue}\textsf{blue}}$$ |
| LocalBackendExecuteProtocol | $${\color{purple}\textsf{purple}}$$ |
| EarlyStopProtocol | $${\color{red}\textsf{red}}$$ |
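Several node types share a color (green, brown, and blue each cover more than one protocol), so color identifies a family rather than an exact type. The sketch below transcribes the legend into a lookup table, emits a Graphviz DOT node statement with the matching color, and inverts the table to list which types share each color. `LEGEND_COLORS`, `dot_node`, and `protocols_by_color` are hypothetical names, and the DOT output assumes a Graphviz-style renderer; this is not nnsight's rendering code.

```python
from collections import defaultdict
from typing import Dict, List

# Colors transcribed from the legend table; the helpers below are hypothetical.
LEGEND_COLORS: Dict[str, str] = {
    "InterventionProtocol": "green",
    "GradProtocol": "green",
    "BridgeProtocol": "brown",
    "ValueProtocol": "blue",
    "Value": "grey",
    "SwapProtocol": "green",
    "LockProtocol": "brown",
    "Operation": "black",
    "ApplyModuleProtocol": "blue",
    "LocalBackendExecuteProtocol": "purple",
    "EarlyStopProtocol": "red",
}


def dot_node(node_id: str, kind: str) -> str:
    """Build a Graphviz DOT node statement colored according to the legend."""
    # Kinds missing from the legend fall back to the plain Operation color.
    color = LEGEND_COLORS.get(kind, LEGEND_COLORS["Operation"])
    return f'"{node_id}" [label="{kind}", color={color}];'


def protocols_by_color() -> Dict[str, List[str]]:
    """Invert the legend: which node types share each color."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for kind, color in LEGEND_COLORS.items():
        groups[color].append(kind)
    return dict(groups)


print(dot_node("stop_0", "EarlyStopProtocol"))
print(protocols_by_color()["green"])
```

Because colors are shared, this sketch puts the exact type in the node label and uses color only for grouping.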

Implementation

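Each protocol class defines a styles attribute that holds rendering data for the protocol's own node and for its 'Value' arguments. The styling rules are grouped under four keys:

- node: attributes of the protocol node itself (e.g. color and shape)
- arg: attributes of each argument node, keyed by argument index, with a default for unlisted indices
- arg_kname: an optional keyword name displayed for an argument, keyed by index (None for no name)
- edge: the line style of the edge drawn for each argument, keyed by index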

An example from the BridgeProtocol styling rules:

from collections import defaultdict
from typing import Any, Dict

styles: Dict[str, Any] = {
    "node": {"color": "brown", "shape": "box"},
    "arg": defaultdict(lambda: {"color": "gray", "shape": "box"},
                       {0: {"color": "gray", "shape": "box", "style": "dashed"}}),
    "arg_kname": defaultdict(lambda: None, {0: "graph_id"}),
    "edge": defaultdict(lambda: "solid", {0: "dashed"}),
}