Currently, we’ve implemented an edge as going between two nodes, but an edge really goes from a parent’s out hook to a child’s in hook. If we redesign `Edge` this way, implementing edge indexing should be fairly straightforward.
Proposed changes:

- Refactor `Node` to be based on hook names rather than tied to specific blocks.
- Refactor `Edge` to use the new hooks.
- Make the corresponding changes in `build_graph`.

A benefit of this is that we can greatly simplify the logic of `GraphIndexer`.
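A minimal sketch of why indexing could get simpler, assuming hooks are identified by plain name strings: if every edge is an (out hook, in hook) pair, an edge's index is just a row/column lookup into a matrix with one row per out hook and one column per in hook. `EdgeIndexer` and the example hook names are hypothetical, and bare strings stand in for the proposed `Node` classes to keep it short.

```python
from dataclasses import dataclass

HookName = str  # assumption: hooks are identified by plain name strings


@dataclass(frozen=True)
class Edge:
    parent: HookName  # out hook of the upstream component
    child: HookName   # in hook of the downstream component


class EdgeIndexer:
    """Hypothetical indexer: each edge maps to one cell of a
    (num_out_hooks x num_in_hooks) matrix."""

    def __init__(self, out_hooks: list[HookName], in_hooks: list[HookName]):
        self.out_index = {h: i for i, h in enumerate(out_hooks)}
        self.in_index = {h: j for j, h in enumerate(in_hooks)}
        self.shape = (len(out_hooks), len(in_hooks))

    def index(self, edge: Edge) -> tuple[int, int]:
        # Row comes from the parent's out hook, column from the child's in hook.
        return self.out_index[edge.parent], self.in_index[edge.child]

    def flat_index(self, edge: Edge) -> int:
        # Row-major position of the edge in the flattened matrix.
        row, col = self.index(edge)
        return row * self.shape[1] + col


indexer = EdgeIndexer(
    ["blocks.0.hook_mlp_out", "blocks.1.hook_mlp_out"],
    ["blocks.1.hook_mlp_in", "blocks.2.hook_mlp_in"],
)
edge = Edge("blocks.0.hook_mlp_out", "blocks.2.hook_mlp_in")
print(indexer.index(edge), indexer.flat_index(edge))
```

The dense matrix also has cells for pairs that are not real edges (for example, an out hook paired with an in hook from an earlier layer); deciding which cells are valid would still be the job of `build_graph`.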
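```python
from dataclasses import dataclass

HookName = str  # assumed alias; e.g. "blocks.0.hook_mlp_out"

@dataclass(frozen=True)  # frozen so nodes and edges are hashable (usable as index keys)
class Node:
    hook: HookName

    @property
    def requires_act(self) -> bool:
        return False

    @property
    def requires_grad(self) -> bool:
        return False

class SrcNode(Node):  # an out hook: we need its activation
    @property
    def requires_act(self) -> bool:
        return True

class DestNode(Node):  # an in hook: we need the gradient at it
    @property
    def requires_grad(self) -> bool:
        return True

@dataclass(frozen=True)
class Edge:
    parent: Node  # expected to be a SrcNode (out hook)
    child: Node   # expected to be a DestNode (in hook)
```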
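As a sketch of how the two properties might be consumed, assuming `DestNode.requires_grad` is meant to return `True`, the helper below walks a list of edges and collects which hooks need their activations cached and which need gradients. `hooks_to_register` is a hypothetical name, hook names are plain strings by assumption, and the proposed classes are repeated so the snippet runs on its own.

```python
from dataclasses import dataclass

HookName = str  # assumption: hooks are identified by plain name strings


@dataclass(frozen=True)
class Node:
    hook: HookName

    @property
    def requires_act(self) -> bool:
        return False

    @property
    def requires_grad(self) -> bool:
        return False


class SrcNode(Node):
    @property
    def requires_act(self) -> bool:
        return True


class DestNode(Node):
    @property
    def requires_grad(self) -> bool:
        return True  # assumption: in hooks need gradients


@dataclass(frozen=True)
class Edge:
    parent: Node
    child: Node


def hooks_to_register(edges: list[Edge]) -> tuple[set[HookName], set[HookName]]:
    """Hypothetical helper: (hooks to cache activations at, hooks to cache grads at)."""
    act_hooks: set[HookName] = set()
    grad_hooks: set[HookName] = set()
    for edge in edges:
        for node in (edge.parent, edge.child):
            if node.requires_act:
                act_hooks.add(node.hook)
            if node.requires_grad:
                grad_hooks.add(node.hook)
    return act_hooks, grad_hooks


src = SrcNode("blocks.0.hook_mlp_out")
dst = DestNode("blocks.1.hook_mlp_in")
acts, grads = hooks_to_register([Edge(src, dst)])
print(acts, grads)
```

Because the requirements live on the node type rather than on block-specific classes, the caller never has to know whether a hook belongs to an MLP or an attention layer.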
## Functions to go in build.py
```python
def add_input_node(graph, layer):
    ...

# Replaces MLPNode
def add_mlp_node(graph, layer):
    ...

# Replaces AttentionNode
def add_attn_node(graph, layer):
    ...
```
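One way the three builder functions could fit together, as a sketch only: the `Graph` container, the `_add_src`/`_add_dest` helpers, and the TransformerLens-style hook names (such as `blocks.{layer}.hook_mlp_in` and `blocks.{layer}.hook_q_input`) are all assumptions, and the `Node` subclasses are stripped down to just the hook name. Reading `add_input_node`'s `layer` argument as "the layer whose residual input is the graph input" is also a guess. Each in hook is connected to every out hook added before it, which yields all edges from earlier components without any block-specific node classes.

```python
from dataclasses import dataclass, field

HookName = str  # assumption: TransformerLens-style hook name strings


@dataclass(frozen=True)
class Node:
    hook: HookName


class SrcNode(Node): ...   # out hook (requires_act omitted for brevity)
class DestNode(Node): ...  # in hook (requires_grad omitted for brevity)


@dataclass(frozen=True)
class Edge:
    parent: Node
    child: Node


@dataclass
class Graph:  # hypothetical container, not the project's real class
    src_nodes: list[Node] = field(default_factory=list)
    dest_nodes: list[Node] = field(default_factory=list)
    edges: list[Edge] = field(default_factory=list)


def _add_dest(graph: Graph, hook: HookName) -> None:
    # Every out hook added so far can write into this in hook.
    child = DestNode(hook)
    graph.dest_nodes.append(child)
    graph.edges.extend(Edge(parent, child) for parent in graph.src_nodes)


def _add_src(graph: Graph, hook: HookName) -> None:
    graph.src_nodes.append(SrcNode(hook))


def add_input_node(graph: Graph, layer: int) -> None:
    _add_src(graph, f"blocks.{layer}.hook_resid_pre")


def add_attn_node(graph: Graph, layer: int) -> None:
    for qkv in ("q", "k", "v"):
        _add_dest(graph, f"blocks.{layer}.hook_{qkv}_input")
    _add_src(graph, f"blocks.{layer}.hook_attn_out")


def add_mlp_node(graph: Graph, layer: int) -> None:
    _add_dest(graph, f"blocks.{layer}.hook_mlp_in")
    _add_src(graph, f"blocks.{layer}.hook_mlp_out")


g = Graph()
add_input_node(g, 0)
add_attn_node(g, 0)
add_mlp_node(g, 0)
print(len(g.edges))
```

Adding a block's in hooks before its out hook is what keeps, for example, `blocks.0.hook_mlp_out` from being wired into `blocks.0.hook_mlp_in`.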