dtch1997 / sae-eap

Edge attribution patching with SAEs

[Proposal] Support multiple nodes per hook point #5

Closed: dtch1997 closed this issue 1 week ago

dtch1997 commented 1 week ago

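Currently we implicitly assume one node per hook point; however, this is not always the case. For example, an attention hook point such as `hook_z` holds one activation slice per attention head, and each head can be a separate node.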
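To make the distinction concrete, here is a minimal sketch that assumes TransformerLens-style hook names. The `Node` class and the `nodes_for_hook` helper are hypothetical illustrations, not part of sae-eap. They only show that one hook name can index several nodes, and that the old behaviour is the special case of a single node per hook.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Node:
    """Hypothetical graph node: a hook point plus an optional index into it."""
    hook_name: str
    index: Optional[int] = None  # None means the whole hook output is one node


def nodes_for_hook(hook_name: str, n_nodes: int) -> List[Node]:
    """Enumerate the nodes that live at one hook point.

    With n_nodes == 1 this reproduces the old one-node-per-hook assumption.
    """
    if n_nodes == 1:
        return [Node(hook_name)]
    return [Node(hook_name, i) for i in range(n_nodes)]


# An attention hook with 4 heads yields 4 nodes; an MLP output yields 1.
attn_nodes = nodes_for_hook("blocks.0.attn.hook_z", 4)
mlp_nodes = nodes_for_hook("blocks.0.hook_mlp_out", 1)
print(len(attn_nodes), len(mlp_nodes))  # 4 1
```

Because `Node` is frozen, and therefore hashable, it can serve directly as a dictionary key when storing per-node activations and gradients.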

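Proposed implementation: compute activations and gradients per hook point as before, then convert them into per-node activations and gradients before computing attribution scores. Example code: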

```python
# Compute per-hook activations and gradients.
acts_per_hook, grads_per_hook = compute_activations_and_gradients_simple(
    model, handler
)

# Convert these to per-node activations and gradients
# (placeholder: this is the step the proposal adds).
graph_acts, graph_grads = ...

# Compute attribution scores from the per-node values.
scores = compute_attribution_scores(
    graph_acts, graph_grads, model.cfg, aggregation=aggregation
)
```
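Below is one possible sketch of the elided conversion step, built on several assumptions. Per-hook values are plain nested lists indexed `[node][feature]` and have already been reduced over batch and position. Each hook is split along its first axis. The stored "activations" are the differences (corrupt minus clean). `split_per_hook` and `eap_edge_score` are hypothetical helpers and are not the repo's `compute_attribution_scores`. The score uses one common linear edge-attribution-patching form: the upstream node's activation difference dotted with the metric gradient at the downstream node's input.

```python
from typing import Dict, List, Tuple

Vector = List[float]
NodeKey = Tuple[str, int]  # hypothetical (hook_name, index) node id


def split_per_hook(per_hook: Dict[str, List[Vector]]) -> Dict[NodeKey, Vector]:
    """Turn {hook_name: [vec_0, vec_1, ...]} into {(hook_name, i): vec_i}.

    Assumes axis 0 of each hook's value enumerates the nodes at that hook
    (e.g. attention heads); a single-node hook is just a length-1 list.
    """
    return {
        (hook, i): vec
        for hook, vecs in per_hook.items()
        for i, vec in enumerate(vecs)
    }


def eap_edge_score(act_diff: Vector, grad: Vector) -> float:
    """Linear EAP estimate for one edge: dot(upstream act diff, downstream grad)."""
    return sum(a * g for a, g in zip(act_diff, grad))


# Toy inputs: activation differences (corrupt - clean) at an upstream hook with
# 2 nodes, and metric gradients at a downstream hook with 1 node.
acts_per_hook = {"up": [[1.0, 0.0], [0.5, 2.0]]}
grads_per_hook = {"down": [[3.0, -1.0]]}

graph_acts = split_per_hook(acts_per_hook)
graph_grads = split_per_hook(grads_per_hook)

# Score every (upstream node, downstream node) pair.
scores = {
    (src, dst): eap_edge_score(graph_acts[src], graph_grads[dst])
    for src in graph_acts
    for dst in graph_grads
}
print(scores)
# {(('up', 0), ('down', 0)): 3.0, (('up', 1), ('down', 0)): -0.5}
```

Keying nodes by `(hook_name, index)` keeps the hook point recoverable from every node. This makes it easy to run the forward and backward hooks per hook point, as before, while scoring per node.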
dtch1997 commented 1 week ago

Closed in ae8b6bead66a477bc86f7693835a8042d8df502a