ArthurConmy / Automatic-Circuit-Discovery

MIT License

Loading model fails in tutorial #97

Open · velezbeltran opened this issue 8 months ago

velezbeltran commented 8 months ago

Hello!

I have been working through the ACDC_Main_Demo.ipynb notebook and am running into an issue: when I try to load the model from a saved subgraph, I get an error. Specifically, I take the following steps.

  1. I run the notebook as-is, adding one extra line that saves the subgraph output with the command below:

        exp.save_subgraph(
            save_path,
            return_it=True,
        )
  2. I run all of the cells except the one containing this code:

        for i in range(args.max_num_epochs):
            exp.step(testing=False)

            show(
                exp.corr,
                f"ims/img_new_{i+1}.png",
                show_full_index=False,
            )

            if IN_COLAB or ipython is not None:
                # so long as we're not running this as a script, show the image!
                display(Image(f"ims/img_new_{i+1}.png"))

            print(i, "-" * 50)
            print(exp.count_no_edges())

            if i == 0:
                exp.save_edges("edges.pkl")

            if exp.current_node is None or SINGLE_STEP:
                break

        exp.save_edges("another_final_edges.pkl")

        if USING_WANDB:
            edges_fname = f"edges.pth"
            exp.save_edges(edges_fname)
            artifact = wandb.Artifact(edges_fname, type="dataset")
            artifact.add_file(edges_fname)
            wandb.log_artifact(artifact)
            os.remove(edges_fname)
            wandb.finish()


  3. I load the subgraph with torch:

        # load using torch
        circuit = t.load(subgraph_path)
        exp.load_subgraph(circuit)


If I do this, I get the following assertion error:

    AssertionError: Ensure that the dictionary includes exactly the correct keys... e.g missing [('blocks.1.hook_q_input', (None, None, 0), 'blocks.0.attn.hook_result', (None, None, 1))] and has excess stuff []



What could be causing this? Am I doing something wrong? Alternatively, what is the standard way of loading circuits?
Also, if I do run the cell containing the `.step()` calls, I don't get this error.

Thank you!
Nicolas
rhaps0dy commented 8 months ago

Possibly the TransformerLens version you're using differs from the one that was used to save the hypothesis, so the hook names are different. What's the list of edges from `exp.corr.all_edges().keys()`?
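Judging from its message ("missing ... and has excess stuff ..."), the assertion appears to compare the experiment's set of edge keys with the keys of the loaded dictionary. A minimal sketch of that comparison, as a hypothetical helper (`diff_edge_keys` is not part of the ACDC codebase), assuming both sides use the key format shown in the error message, `(receiver_name, receiver_index_tuple, sender_name, sender_index_tuple)`:

```python
from typing import Dict, Hashable, Iterable, List, Tuple


def diff_edge_keys(
    expected_keys: Iterable[Hashable],
    saved_subgraph: Dict[Hashable, bool],
) -> Tuple[List[Hashable], List[Hashable]]:
    """Hypothetical helper (not ACDC API): report which edge keys differ
    between an experiment's correspondence and a saved subgraph dict."""
    expected = set(expected_keys)
    saved = set(saved_subgraph)
    # Edges the experiment has but the saved file lacks ("missing" in the error)
    missing = sorted(expected - saved, key=repr)
    # Edges in the saved file that the experiment lacks ("excess stuff")
    excess = sorted(saved - expected, key=repr)
    return missing, excess


# Toy keys in the format printed by the assertion message
expected = [
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 0)),
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 1)),
]
saved = {expected[0]: True}  # the second edge was pruned, so it was never saved
missing, excess = diff_edge_keys(expected, saved)
print(missing)  # only the (None, None, 1) edge
print(excess)   # []
```

In the notebook, the expected keys would presumably come from `exp.corr.all_edges()`. Note that the keys printed there show `TorchIndex` objects such as `[:, :, 7]`, whereas the error message shows tuples such as `(None, None, 0)`, so the indices would need converting to that tuple form first; how best to do that is an assumption about ACDC internals that this thread does not cover.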

velezbeltran commented 8 months ago

Thanks for your lightning-fast response!

Before running the `.step()` block:

dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 7]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 5]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 4]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 2]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 1]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_q', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_k', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_v', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_q', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_k', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 5], 
'blocks.1.attn.hook_v', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_q', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_k', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_v', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_q', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_k', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_v', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_q', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_k', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_v', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_q', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_k', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_v', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_q', [:, :, 0]), ('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_k', [:, :, 0]), ('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_v', [:, :, 0]), ('blocks.1.attn.hook_q', [:, :, 7], 'blocks.1.hook_q_input', [:, :, 7]), ('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]), ('blocks.1.attn.hook_q', [:, :, 5], 'blocks.1.hook_q_input', [:, :, 5]), ('blocks.1.attn.hook_q', [:, :, 4], 'blocks.1.hook_q_input', [:, :, 4]), ('blocks.1.attn.hook_q', [:, :, 3], 'blocks.1.hook_q_input', [:, :, 3]), ('blocks.1.attn.hook_q', [:, :, 2], 'blocks.1.hook_q_input', [:, :, 2]), ('blocks.1.attn.hook_q', [:, :, 1], 'blocks.1.hook_q_input', [:, :, 1]), ('blocks.1.attn.hook_q', [:, :, 0], 'blocks.1.hook_q_input', [:, :, 0]), ('blocks.1.attn.hook_k', [:, :, 7], 'blocks.1.hook_k_input', [:, :, 7]), ('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]), ('blocks.1.attn.hook_k', [:, :, 5], 'blocks.1.hook_k_input', [:, :, 5]), 
('blocks.1.attn.hook_k', [:, :, 4], 'blocks.1.hook_k_input', [:, :, 4]), ('blocks.1.attn.hook_k', [:, :, 3], 'blocks.1.hook_k_input', [:, :, 3]), ('blocks.1.attn.hook_k', [:, :, 2], 'blocks.1.hook_k_input', [:, :, 2]), ('blocks.1.attn.hook_k', [:, :, 1], 'blocks.1.hook_k_input', [:, :, 1]), ('blocks.1.attn.hook_k', [:, :, 0], 'blocks.1.hook_k_input', [:, :, 0]), ('blocks.1.attn.hook_v', [:, :, 7], 'blocks.1.hook_v_input', [:, :, 7]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 5], 'blocks.1.hook_v_input', [:, :, 5]), ('blocks.1.attn.hook_v', [:, :, 4], 'blocks.1.hook_v_input', [:, :, 4]), ('blocks.1.attn.hook_v', [:, :, 3], 'blocks.1.hook_v_input', [:, :, 3]), ('blocks.1.attn.hook_v', [:, :, 2], 'blocks.1.hook_v_input', [:, :, 2]), ('blocks.1.attn.hook_v', [:, :, 1], 'blocks.1.hook_v_input', [:, :, 1]), ('blocks.1.attn.hook_v', [:, :, 0], 'blocks.1.hook_v_input', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 6], 
'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 
3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', 
[:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), 
('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', 
[:]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 0], 
'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 
5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', 
[:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_q', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_k', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_v', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_q', [:, :, 6]), 
('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_k', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_v', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_q', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_k', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_v', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_q', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_k', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_v', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_q', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_k', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_v', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_q', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_k', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_v', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_q', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_k', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_v', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_q', [:, :, 7], 'blocks.0.hook_q_input', [:, :, 7]), ('blocks.0.attn.hook_q', [:, :, 6], 'blocks.0.hook_q_input', [:, :, 6]), ('blocks.0.attn.hook_q', [:, :, 5], 'blocks.0.hook_q_input', [:, :, 5]), ('blocks.0.attn.hook_q', [:, :, 4], 'blocks.0.hook_q_input', [:, :, 4]), ('blocks.0.attn.hook_q', [:, :, 3], 'blocks.0.hook_q_input', [:, :, 3]), ('blocks.0.attn.hook_q', [:, :, 2], 'blocks.0.hook_q_input', [:, :, 2]), 
('blocks.0.attn.hook_q', [:, :, 1], 'blocks.0.hook_q_input', [:, :, 1]), ('blocks.0.attn.hook_q', [:, :, 0], 'blocks.0.hook_q_input', [:, :, 0]), ('blocks.0.attn.hook_k', [:, :, 7], 'blocks.0.hook_k_input', [:, :, 7]), ('blocks.0.attn.hook_k', [:, :, 6], 'blocks.0.hook_k_input', [:, :, 6]), ('blocks.0.attn.hook_k', [:, :, 5], 'blocks.0.hook_k_input', [:, :, 5]), ('blocks.0.attn.hook_k', [:, :, 4], 'blocks.0.hook_k_input', [:, :, 4]), ('blocks.0.attn.hook_k', [:, :, 3], 'blocks.0.hook_k_input', [:, :, 3]), ('blocks.0.attn.hook_k', [:, :, 2], 'blocks.0.hook_k_input', [:, :, 2]), ('blocks.0.attn.hook_k', [:, :, 1], 'blocks.0.hook_k_input', [:, :, 1]), ('blocks.0.attn.hook_k', [:, :, 0], 'blocks.0.hook_k_input', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 7], 'blocks.0.hook_v_input', [:, :, 7]), ('blocks.0.attn.hook_v', [:, :, 6], 'blocks.0.hook_v_input', [:, :, 6]), ('blocks.0.attn.hook_v', [:, :, 5], 'blocks.0.hook_v_input', [:, :, 5]), ('blocks.0.attn.hook_v', [:, :, 4], 'blocks.0.hook_v_input', [:, :, 4]), ('blocks.0.attn.hook_v', [:, :, 3], 'blocks.0.hook_v_input', [:, :, 3]), ('blocks.0.attn.hook_v', [:, :, 2], 'blocks.0.hook_v_input', [:, :, 2]), ('blocks.0.attn.hook_v', [:, :, 1], 'blocks.0.hook_v_input', [:, :, 1]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 6], 
'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])

After running the `.step()` block:

dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]), ('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])

I don't think the issue is a TransformerLens version mismatch, because I can reproduce all of this within the same notebook in Colab.

Thank you

rhaps0dy commented 8 months ago

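Turns out the explanation is this: the ACDC algorithm literally removes edges (i.e. deletes them from the correspondence dictionaries) instead of setting `edge.present = False`. A subgraph saved after pruning therefore has no entries for the removed edges, so the exact key check in `load_subgraph` fails against a freshly built experiment that still contains every edge. That is also why loading works once `.step()` has already pruned `exp.corr`.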

The loading code should be changed to fix this.
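One possible workaround, sketched under assumptions and not the maintainers' actual fix: since pruned edges are simply absent from the saved dictionary, the loader could treat every missing key as an edge that is not present, instead of requiring the key sets to match exactly. The helper below (`complete_subgraph` is hypothetical, not ACDC API) fills in each missing key with `False`, using the key format from the assertion message:

```python
from typing import Dict, Hashable, Iterable


def complete_subgraph(
    all_edge_keys: Iterable[Hashable],
    saved_subgraph: Dict[Hashable, bool],
) -> Dict[Hashable, bool]:
    """Hypothetical workaround (not ACDC API): give every edge an entry,
    marking edges absent from the saved dict as not present (False)."""
    all_keys = list(all_edge_keys)
    known = set(all_keys)
    unknown = [k for k in saved_subgraph if k not in known]
    if unknown:
        # Unknown keys signal a genuine mismatch (e.g. different hook names),
        # which filling in defaults cannot fix
        raise KeyError(f"{len(unknown)} saved edge(s) not in the experiment, e.g. {unknown[0]!r}")
    # Edges deleted during pruning were never saved, so default them to False
    return {k: bool(saved_subgraph.get(k, False)) for k in all_keys}


# Toy example: three edges in the full graph, two survived pruning
all_keys = [
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, h))
    for h in range(3)
]
pruned = {all_keys[0]: True, all_keys[2]: True}
full = complete_subgraph(all_keys, pruned)
print(sum(full.values()), "of", len(full), "edges present")  # 2 of 3 edges present
```

In the notebook, `all_edge_keys` would come from a freshly constructed experiment's `exp.corr.all_edges()`, converted to the tuple format shown in the error message, and the completed dictionary would then be passed to `exp.load_subgraph(...)`. That marking pruned edges as not present reproduces exactly what ACDC's deletion does is an assumption; the thread only establishes that deleting edges, rather than setting `edge.present = False`, is what breaks loading. Raising on unknown keys, rather than silently dropping them, keeps a real mismatch such as the hook-name difference suggested earlier from being masked.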

Iust1n2 commented 3 months ago

@velezbeltran, would you be so kind as to share your working code for loading the saved subgraph (`edges.pth`) for inference? I didn't quite catch from @rhaps0dy's comment what the modification should be, or where. Thanks!