dtch1997 / sae-eap

Edge attribution patching with SAEs

[Proposal] Refactor Node to be based on hooks #3

Closed: dtch1997 closed this issue 1 week ago

dtch1997 commented 1 week ago

Currently, we’ve implemented an edge as connecting two nodes, but an edge really connects a parent's output hook to a child's input hook. If we redesign `Edge` this way, implementing edge indexing should be fairly straightforward.
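Hook-based edges make indexing straightforward for a simple reason. If every parent out hook and every child in hook gets its own integer index, an edge becomes a `(src_index, dest_index)` cell in a dense `[n_src, n_dest]` matrix of edge scores. The sketch below assumes that hooks are identified by name strings. `HookEdgeIndexer` is a hypothetical class, not the repo's `GraphIndexer`, and the hook names are only illustrative.

```python
HookName = str  # assumption: hooks are identified by name strings


class HookEdgeIndexer:
    """Hypothetical indexer where an edge is an (out hook, in hook) pair."""

    def __init__(self) -> None:
        self.src_hooks: list[HookName] = []
        self.dest_hooks: list[HookName] = []
        self._src_index: dict[HookName, int] = {}
        self._dest_index: dict[HookName, int] = {}

    def add_src(self, hook: HookName) -> int:
        # Register a parent's output hook; reuse its index if already present.
        if hook not in self._src_index:
            self._src_index[hook] = len(self.src_hooks)
            self.src_hooks.append(hook)
        return self._src_index[hook]

    def add_dest(self, hook: HookName) -> int:
        # Register a child's input hook; reuse its index if already present.
        if hook not in self._dest_index:
            self._dest_index[hook] = len(self.dest_hooks)
            self.dest_hooks.append(hook)
        return self._dest_index[hook]

    def edge_index(self, src: HookName, dest: HookName) -> tuple[int, int]:
        # An edge is just a (row, column) cell in an [n_src, n_dest] matrix.
        return self._src_index[src], self._dest_index[dest]

    def flat_edge_index(self, src: HookName, dest: HookName) -> int:
        # Row-major flattening; only stable once all dest hooks are registered.
        row, col = self.edge_index(src, dest)
        return row * len(self.dest_hooks) + col


indexer = HookEdgeIndexer()
for hook in ["hook_embed", "blocks.0.hook_mlp_out"]:
    indexer.add_src(hook)
for hook in ["blocks.0.hook_mlp_in", "blocks.1.hook_mlp_in"]:
    indexer.add_dest(hook)

print(indexer.edge_index("blocks.0.hook_mlp_out", "blocks.1.hook_mlp_in"))  # (1, 1)
print(indexer.flat_edge_index("blocks.0.hook_mlp_out", "blocks.1.hook_mlp_in"))  # 3
```

Because the indexer only ever sees hook names, it needs no per-node-type logic for MLPs or attention. That is the simplification the proposal is after.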

Another benefit of this redesign is that we can greatly simplify the logic of `GraphIndexer`.

Proposed changes:

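```python
from dataclasses import dataclass

# Name of a hook point in the model (assumed to be a plain string here).
HookName = str


@dataclass(frozen=True)
class Node:
    """A single hook point in the graph."""

    hook: HookName

    @property
    def requires_act(self) -> bool:
        # Whether we need to cache activations at this hook.
        return False

    @property
    def requires_grad(self) -> bool:
        # Whether we need gradients at this hook.
        return False


class SrcNode(Node):
    """A parent's out hook: the upstream end of an edge."""

    @property
    def requires_act(self) -> bool:
        return True


class DestNode(Node):
    """A child's in hook: the downstream end of an edge."""

    @property
    def requires_grad(self) -> bool:
        return True


@dataclass(frozen=True)
class Edge:
    parent: SrcNode
    child: DestNode


# Functions to go in build.py

def add_input_node(graph, layer):
    ...


# Replaces MLPNode
def add_mlp_node(graph, layer):
    ...


# Replaces AttentionNode
def add_attn_node(graph, layer):
    ...
```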
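The `requires_act` / `requires_grad` split matches what edge attribution patching needs. In one common form of the EAP approximation, an edge is scored by the change in the parent's out-hook activation (corrupted minus clean), dotted with the metric's gradient at the child's in hook on the clean run. The sketch below shows how a list of edges determines which hooks to cache and how each edge could be scored. It repeats the node classes so it runs standalone, with `DestNode.requires_grad` returning `True`, which is what the src/dest split implies. It also assumes toy activations stored as plain lists keyed by hook name. `hooks_to_cache` and `eap_score` are hypothetical helpers, not functions from the repo.

```python
from dataclasses import dataclass

HookName = str  # assumption: hooks are identified by name strings


@dataclass(frozen=True)
class Node:
    hook: HookName

    @property
    def requires_act(self) -> bool:
        return False

    @property
    def requires_grad(self) -> bool:
        return False


class SrcNode(Node):
    @property
    def requires_act(self) -> bool:
        return True


class DestNode(Node):
    @property
    def requires_grad(self) -> bool:
        return True


@dataclass(frozen=True)
class Edge:
    parent: SrcNode
    child: DestNode


def hooks_to_cache(edges: list[Edge]) -> tuple[set[HookName], set[HookName]]:
    """Hypothetical helper: collect hooks needing activations vs. gradients."""
    act_hooks: set[HookName] = set()
    grad_hooks: set[HookName] = set()
    for edge in edges:
        for node in (edge.parent, edge.child):
            if node.requires_act:
                act_hooks.add(node.hook)
            if node.requires_grad:
                grad_hooks.add(node.hook)
    return act_hooks, grad_hooks


def eap_score(
    edge: Edge,
    clean_acts: dict[HookName, list[float]],
    corrupt_acts: dict[HookName, list[float]],
    clean_grads: dict[HookName, list[float]],
) -> float:
    """Hypothetical helper: (corrupt - clean) activation at the parent's
    out hook, dotted with the clean-run gradient at the child's in hook."""
    delta = [
        x_corr - x_clean
        for x_corr, x_clean in zip(corrupt_acts[edge.parent.hook], clean_acts[edge.parent.hook])
    ]
    grad = clean_grads[edge.child.hook]
    return sum(d * g for d, g in zip(delta, grad))


edge = Edge(SrcNode("blocks.0.hook_mlp_out"), DestNode("blocks.1.hook_attn_in"))
act_hooks, grad_hooks = hooks_to_cache([edge])
print(act_hooks, grad_hooks)  # {'blocks.0.hook_mlp_out'} {'blocks.1.hook_attn_in'}

clean = {"blocks.0.hook_mlp_out": [1.0, 2.0]}
corrupt = {"blocks.0.hook_mlp_out": [0.5, 3.0]}
grads = {"blocks.1.hook_attn_in": [2.0, -1.0]}
print(eap_score(edge, clean, corrupt, grads))  # -2.0
```

Since each edge only touches one out hook and one in hook, the set of hooks to cache falls directly out of the node types. No per-component special cases are needed.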
dtch1997 commented 1 week ago

closed by #4