ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Conditional Context #183

AdamBelfki3 closed this 1 month ago

AdamBelfki3 commented 1 month ago

New feature! You can now create a Conditional context: the interventions defined within its body are executed only if its condition evaluates to True.

Simply use a with statement on any Intervention Proxy, and its value will be evaluated as a boolean.

```python
with model.trace(input) as tracer:
    num = tracer.apply(int, 5)
    with num > 0:
        out = model.output.save()
```
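The `with` statement works on a proxy because Python calls the object's `__enter__` and `__exit__` methods. A minimal pure-Python sketch of that idea follows; `Graph`, `Proxy` and `record` are hypothetical names, not nnsight's classes, and the comparison only records a node rather than evaluating anything. Each recorded node remembers the innermost open condition.

```python
# Hypothetical sketch: recording conditional scopes with `with proxy:`.
# Graph/Proxy are illustrative names, not nnsight's classes.
from typing import Any, List, Optional


class Graph:
    """Records nodes and the stack of currently open Conditional scopes."""

    def __init__(self) -> None:
        self.nodes: List["Proxy"] = []
        self.open_conditions: List["Proxy"] = []

    def record(self, label: str) -> "Proxy":
        # A new node is bound to the innermost open condition, if any.
        cond = self.open_conditions[-1] if self.open_conditions else None
        proxy = Proxy(self, label, cond)
        self.nodes.append(proxy)
        return proxy


class Proxy:
    def __init__(self, graph: Graph, label: str, condition: Optional["Proxy"]) -> None:
        self.graph = graph
        self.label = label
        self.condition = condition

    def __gt__(self, other: Any) -> "Proxy":
        # Comparisons are recorded as new nodes, not evaluated eagerly.
        return self.graph.record(f"({self.label} > {other!r})")

    def __lt__(self, other: Any) -> "Proxy":
        return self.graph.record(f"({self.label} < {other!r})")

    def __enter__(self) -> "Proxy":
        # `with proxy:` opens a Conditional scope guarded by this node.
        self.graph.open_conditions.append(self)
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Closing the body restores the enclosing scope.
        self.graph.open_conditions.pop()


g = Graph()
num = g.record("num")
with num > 0:
    out = g.record("out")
    with num < 10:
        inner = g.record("inner")
after = g.record("after")

print(out.condition.label)    # (num > 0)
print(inner.condition.label)  # (num < 10)
print(after.condition)        # None
```

Because the inner comparison is itself recorded inside the outer body, `inner.condition.condition` is the outer condition, which is one simple way nested Conditionals can compose.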


Examples

(Conditional Iteration)

```python
with model.session() as session:
    l = session.apply(list).save()
    with session.iter([0, 1, 2]) as (item, iterator):
        with item % 2 == 0:
            l.append(item)
"""
Result: l -> [0, 2]
"""
```


(Nested Conditionals)

```python
with model.trace(input) as tracer:
    num = tracer.apply(int, 5)
    with num > 0:
        l2_out = model.layer2.output.save()

        with num > 0:
            l2_out[:] = 0

        with num < 0:
            l2_out[:] = 1
"""
Result: l2_out -> tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
"""
```


Implementation

Conditional Execution

The Conditional context applies its condition to the Nodes defined within its body by assigning the ConditionalProtocol node created with the context as each such node's conditional dependency. Assigning a conditional dependency increments the node's total dependency count, so the node only executes if the ConditionalProtocol node evaluates to True.
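The dependency-counting rule can be sketched as follows, assuming a simple push-based executor in which a node runs once its count of unsatisfied dependencies reaches zero. The `Node` class, its fields and the `build` helper are hypothetical and only illustrate the idea: the conditional dependency is counted like any other, but it is only satisfied when the condition's value is truthy.

```python
# Hypothetical sketch: a conditional dependency is counted like any other
# dependency, but only a truthy condition value satisfies it.
from typing import Any, Callable, List, Optional


class Node:
    def __init__(
        self,
        fn: Callable[..., Any],
        args: Optional[List["Node"]] = None,
        condition: Optional["Node"] = None,
    ) -> None:
        self.fn = fn
        self.args = list(args or [])       # argument dependencies
        self.condition = condition         # ConditionalProtocol-like node, if any
        self.listeners: List["Node"] = []  # nodes waiting on this one
        self.value: Any = None
        self.done = False
        deps = self.args + ([condition] if condition is not None else [])
        self.remaining = len(deps)         # the condition adds one to the count
        for dep in deps:
            dep.listeners.append(self)

    def execute(self) -> None:
        self.value = self.fn(*(a.value for a in self.args))
        self.done = True
        for listener in self.listeners:
            # A conditional dependency is only satisfied by a truthy value.
            if listener.condition is self and not self.value:
                continue
            listener.remaining -= 1
            if listener.remaining == 0 and not listener.done:
                listener.execute()


def build(start: int) -> Node:
    """Build num -> (num > 0) -> out and return the conditioned node."""
    num = Node(lambda: start)
    cond = Node(lambda n: n > 0, [num])
    out = Node(lambda n: n * 10, [num], condition=cond)
    num.execute()  # num has no dependencies, so execution starts here
    return out


positive = build(5)
negative = build(-5)
print(positive.done, positive.value)  # True 50
print(negative.done, negative.value)  # False None
```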

Conditional Manager

An Intervention Graph can contain multiple ConditionalProtocol nodes, which are managed by its ConditionalManager attachment. The ConditionalManager keeps track of all ConditionalProtocol nodes as well as all the Nodes they condition, which the graph optimization heuristic discussed below relies on.
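A registry in the spirit of the Conditional Manager might look like this. It is a hypothetical sketch (string node ids, `push`/`pop`/`register` methods) rather than nnsight's actual ConditionalManager API: it keeps a stack of open Conditional contexts and records, for each condition, the nodes defined under it.

```python
# Hypothetical registry in the spirit of a Conditional Manager; the method
# names and string node ids are illustrative, not nnsight's API.
from typing import Dict, List, Optional, Set


class ConditionalManager:
    def __init__(self) -> None:
        self._open: List[str] = []                   # open contexts, innermost last
        self._conditioned: Dict[str, Set[str]] = {}  # condition -> its nodes
        self._condition_of: Dict[str, str] = {}      # node -> its condition

    def push(self, condition: str) -> None:
        """Enter a Conditional context guarded by `condition`."""
        self._open.append(condition)
        self._conditioned.setdefault(condition, set())

    def pop(self) -> str:
        """Leave the innermost Conditional context."""
        return self._open.pop()

    def current(self) -> Optional[str]:
        return self._open[-1] if self._open else None

    def register(self, node: str) -> None:
        """Record that `node` was defined under the current condition, if any."""
        cond = self.current()
        if cond is not None:
            self._conditioned[cond].add(node)
            self._condition_of[node] = cond

    def condition_of(self, node: str) -> Optional[str]:
        return self._condition_of.get(node)

    def conditioned_by(self, condition: str) -> Set[str]:
        return set(self._conditioned.get(condition, set()))


manager = ConditionalManager()
manager.push("cond_0")
manager.register("setitem_0")
manager.push("cond_1")
manager.register("setitem_1")
manager.pop()
manager.pop()
manager.register("save_0")

print(manager.conditioned_by("cond_0"))  # {'setitem_0'}
print(manager.condition_of("setitem_1"))  # cond_1
print(manager.condition_of("save_0"))     # None
```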

InterventionProtocol and BridgeProtocol

Because InterventionProtocol and BridgeProtocol nodes are root nodes (owing to their get property), they are not bound to any conditional context; instead, the interventions that depend on them are.

If a BridgeProtocol node were conditioned, the locks it creates on the external value would never be removed.

If an InterventionProtocol node were conditioned, every operation defined on it after the conditional body would also become conditioned, unintentionally.
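The effect described above can be reproduced with a toy model. In this hypothetical sketch the condition's evaluated value is stored directly on each node for brevity, a node runs only if its own condition holds and all its arguments ran, and the made-up `exempt_roots` flag toggles whether argument-less root nodes (standing in for InterventionProtocol getters) are bound to the open condition.

```python
# Hypothetical toy model: why a root node (e.g. a getter) should not be bound
# to the Conditional it is defined in. The condition's value is stored on the
# node directly for brevity; None means "unconditioned".
from typing import List, Optional


class Node:
    def __init__(self, name: str, args: List["Node"], condition: Optional[bool]) -> None:
        self.name = name
        self.args = args
        self.condition = condition


def record(name: str, args: List[Node], open_condition: Optional[bool],
           exempt_roots: bool) -> Node:
    # Root nodes have no arguments; optionally leave them unconditioned.
    if exempt_roots and not args:
        return Node(name, args, None)
    return Node(name, args, open_condition)


def runs(node: Node) -> bool:
    # A node runs only if its own condition holds and every argument ran.
    if node.condition is False:
        return False
    return all(runs(arg) for arg in node.args)


def later_use_runs(exempt_roots: bool) -> bool:
    condition_value = False  # the Conditional evaluates to False
    # Inside the Conditional body: a getter and an edit that uses it.
    getter = record("layer1.output", [], condition_value, exempt_roots)
    record("setitem", [getter], condition_value, exempt_roots)
    # After the body: no condition is open, so this use should always run.
    later = record("save", [getter], None, exempt_roots)
    return runs(later)


print(later_use_runs(exempt_roots=False))  # False: conditioned unintentionally
print(later_use_runs(exempt_roots=True))   # True
```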

False Condition

If the condition of a Conditional evaluates to False, the graph must still account for all of that Conditional's interventions as if they had been visited, propagating that effect through the graph without executing them.

To do so, all conditioned nodes are updated recursively so that the listener counts of their dependencies are decremented.
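A minimal sketch of this propagation, assuming each node tracks how many dependents still need its value (`remaining_listeners` is a hypothetical name here) and assuming the recursion walks from each skipped node to its dependents. Skipping a node decrements the count on each of its dependencies, as if it had executed, and then skips the nodes built on it, which can no longer run either.

```python
# Hypothetical sketch: skip nodes under a False condition while keeping the
# listener bookkeeping consistent, as if they had been visited.
from typing import List


class Node:
    def __init__(self, name: str, deps: List["Node"]) -> None:
        self.name = name
        self.deps = deps
        self.dependents: List["Node"] = []
        self.remaining_listeners = 0  # dependents that still need this value
        self.skipped = False
        for dep in deps:
            dep.dependents.append(self)
            dep.remaining_listeners += 1


def skip(node: Node) -> None:
    """Visit `node` without executing it, then skip everything built on it."""
    if node.skipped:
        return  # already visited through another path
    node.skipped = True
    for dep in node.deps:
        # The dependency loses a listener exactly as if `node` had run.
        dep.remaining_listeners -= 1
    for child in node.dependents:
        skip(child)


hidden = Node("layer2.output", [])
doubled = Node("mul", [hidden])      # conditioned node
saved = Node("save", [doubled])      # built on the conditioned node
other = Node("other_use", [hidden])  # unconditioned consumer of the same value

skip(doubled)  # the Conditional evaluated to False
print(hidden.remaining_listeners)    # 1: only other_use still needs it
print(doubled.remaining_listeners)   # 0
print(saved.skipped, other.skipped)  # True False
```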

Node and Arg Dependency don't share the same conditional argument

We apply a heuristic that simplifies the graph by skipping unnecessary conditional dependencies. If two nodes, A and B, are defined within the same Conditional context and B depends on A, only A needs the conditional dependency; B inherits it implicitly, since it cannot execute until A has.


```python
with model.trace(input) as tracer:
    num = tracer.apply(int, 5)
    with num > 0:
        l1_out = model.layer1.output
        l1_out = l1_out * 2
        l1_out[:, 0] = 0
        l1_out.save()
```
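The heuristic reduces to a single check made when a node is recorded inside a Conditional context. The sketch below is hypothetical (plain string names, a made-up `condition_of` map and `needs_conditional_dependency` helper) and replays the example above: the root getter stays unconditioned, the multiplication gets the explicit conditional dependency, and the later nodes inherit it through their argument.

```python
# Hypothetical sketch of the heuristic: a node only gets an explicit
# conditional dependency if none of its arguments already carries the same one.
from typing import Dict, List, Optional, Tuple


def needs_conditional_dependency(
    node_args: List[str],
    current_condition: Optional[str],
    condition_of: Dict[str, Optional[str]],
) -> bool:
    if current_condition is None or not node_args:
        return False  # outside any Conditional, or a root node
    # An argument bound to the same condition already gates this node,
    # because the node cannot run before that argument has run.
    return not any(condition_of.get(arg) == current_condition for arg in node_args)


# Nodes recorded inside `with num > 0:` in the example (names illustrative).
body: List[Tuple[str, List[str]]] = [
    ("layer1_output", []),       # root getter
    ("mul", ["layer1_output"]),  # l1_out * 2
    ("setitem", ["mul"]),        # l1_out[:, 0] = 0
    ("save", ["mul"]),           # l1_out.save()
]

condition_of: Dict[str, Optional[str]] = {}
explicit: List[str] = []
for name, args in body:
    if needs_conditional_dependency(args, "cond", condition_of):
        explicit.append(name)
    # Non-root nodes are conditioned, either explicitly or by inheritance.
    condition_of[name] = "cond" if args else None

print(explicit)  # ['mul']
```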


Limitations

Use of Boolean Operators - AND, OR, NOT, ...
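One likely reason for this limitation is that Python does not let objects overload the `and`, `or` and `not` keywords: they always call `__bool__`, which would force a traced value into a concrete boolean before the graph runs. The hypothetical `Proxy` below contrasts them with the overloadable bitwise operators `&`, `|` and `~`; whether nnsight proxies accept those operators inside a Conditional is not stated in this PR.

```python
# Hypothetical Proxy showing which boolean-like operators can be traced.
class Proxy:
    def __init__(self, expr: str) -> None:
        self.expr = expr  # textual record of the traced expression

    # Bitwise operators can be overloaded, so they can build new graph nodes.
    def __and__(self, other: "Proxy") -> "Proxy":
        return Proxy(f"({self.expr} & {other.expr})")

    def __or__(self, other: "Proxy") -> "Proxy":
        return Proxy(f"({self.expr} | {other.expr})")

    def __invert__(self) -> "Proxy":
        return Proxy(f"(~{self.expr})")

    # `and`, `or` and `not` cannot be overloaded; they call __bool__, which a
    # tracer cannot answer before execution, so this sketch refuses.
    def __bool__(self) -> bool:
        raise TypeError(f"cannot convert traced value {self.expr!r} to bool")


a, b = Proxy("a"), Proxy("b")
print((a & ~b).expr)  # (a & (~b))
try:
    both = a and b
except TypeError as err:
    print(err)  # cannot convert traced value 'a' to bool
```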