OpenMOSS / Language-Model-SAEs

For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research.

Notation Clarification #55

Closed gboxo closed 1 month ago

gboxo commented 1 month ago

I would like some clarification on the implementation details of the codebase and its relation to the experiments performed in https://arxiv.org/pdf/2405.13868.

My understanding is that:

When performing Hierarchical Attribution for a local circuit (e.g. the target is L1.A11421), the linear computational graph G will be composed of:

Also:

Thanks.

dest1n1s commented 1 month ago

Thank you for following our work! Your understanding of Hierarchical Attribution and Transcoders is correct.

The Attention SAEs are not trained on hook_z (the per-head outputs in TL notation); they are trained on hook_attn_out, the output of the attention block, which differs from hook_z only by a linear transformation (the output projection).
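For concreteness, here is a quick way to check that relationship in TransformerLens (a sketch; the model and prompt are arbitrary choices, the rest is standard HookedTransformer API):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
_, cache = model.run_with_cache("The quick brown fox")

z = cache["blocks.0.attn.hook_z"]           # [batch, pos, n_heads, d_head]
attn_out = cache["blocks.0.hook_attn_out"]  # [batch, pos, d_model]

# hook_attn_out is just hook_z pushed through the output projection W_O (plus b_O),
# i.e. the two differ only by a fixed linear map.
recon = torch.einsum("bphd,hdm->bpm", z, model.W_O[0]) + model.b_O[0]
print(torch.allclose(recon, attn_out, atol=1e-4))  # True (up to numerical error)
```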

If you have any further questions about our work or the codebase, feel free to re-open this issue!

gboxo commented 1 month ago

Hello, thank you for your helpful answer.

I have changed the setup so the Attention SAEs are trained on "hook_attn_out" (specifically the ones trained by OAI), still using the transcoders trained by pchlenski (trained from ln2.normalized --> hook_mlp_out), but a major problem persists:

The attribution of the leaf nodes (transcoder and Attn SAE features) adds up to more than the activation of the target node.

The code looks like this:

```python
candidates = None
all_saes = saes + tcs

with apply_sae(model, all_saes):
    with model.hooks([(f"blocks.{i}.attn.hook_attn_scores", detach_hook) for i in range(12)]):
        attributor = HierachicalAttributor(model=model)

        target = Node("blocks.5.hook_attn_out.sae.hook_sae_acts_post", reduction="0.9.8506")

        if candidates is None:
            candidates = (
                [Node(f"{sae.cfg.hook_name}.sae.hook_sae_acts_post") for sae in saes]
                + [Node(f"{sae.cfg.out_hook_point}.sae.hook_hidden_post") for sae in tcs]
                + [Node(f"blocks.{i}.attn.hook_attn_scores") for i in range(6)]
            )
        circuit = attributor.attribute(toks=toks, target=target, candidates=candidates, threshold=0)
```
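Here `detach_hook` is essentially a hook that detaches the attention scores from the autograd graph, so the attention pattern is treated as a constant during attribution; roughly:

```python
# Rough sketch of the detach hook: returning a detached tensor from the
# hook_attn_scores hook point cuts the gradient through the QK pathway,
# so attribution only flows through the value/OV pathway.
def detach_hook(tensor, hook):
    return tensor.detach()
```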

The cache also looks like this:

```
blocks.0.attn.hook_attn_scores
blocks.0.hook_attn_out.sae.hook_sae_acts_post
blocks.0.hook_mlp_out.sae.hook_hidden_post
blocks.1.attn.hook_attn_scores
blocks.1.hook_attn_out.sae.hook_sae_acts_post
blocks.1.hook_mlp_out.sae.hook_hidden_post
blocks.2.attn.hook_attn_scores
blocks.2.hook_attn_out.sae.hook_sae_acts_post
blocks.2.hook_mlp_out.sae.hook_hidden_post
blocks.3.attn.hook_attn_scores
blocks.3.hook_attn_out.sae.hook_sae_acts_post
blocks.3.hook_mlp_out.sae.hook_hidden_post
blocks.4.attn.hook_attn_scores
blocks.4.hook_attn_out.sae.hook_sae_acts_post
blocks.4.hook_mlp_out.sae.hook_hidden_post
blocks.5.attn.hook_attn_scores
blocks.5.hook_attn_out.sae.hook_sae_acts_post
```
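For reference, this is how I read the target activation off that cache (assuming the reduction string "0.9.8506" indexes batch 0, position 9, feature 8506 of the target hook):

```python
# Assumption: reduction "0.9.8506" selects batch 0, position 9, feature 8506
# of the target hook's activations.
target_act = cache["blocks.5.hook_attn_out.sae.hook_sae_acts_post"][0, 9, 8506]
print(target_act.item())  # compare against the target activation reported below
```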

The attributions of the leaf nodes sum to 19.3, while the activation of the target node is only 6.5.

Is there anything obvious to you that could explain this result?