stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0

[Feature Request / Suggestion]: Capturing of the residual stream #154

Closed CoffeeVampir3 closed 2 months ago

CoffeeVampir3 commented 2 months ago

Suggestion / Feature Request

Hi, I'm attempting to implement the following paper in pyvene: https://arxiv.org/pdf/2312.06681

The method requires capturing the residual stream. I'm unsure whether this is possible with the current library features; based on the documentation, I saw no obvious way of doing it.

For comparison, here's TransformerLens' residual cache: https://github.com/neelnanda-io/TransformerLens/blob/be135a01745ab7796e8e560cf0498d5791857a93/transformer_lens/ActivationCache.py#L320

Cheers

aryamanarora commented 2 months ago

You can target the block_output at layer $n$ to capture the residual stream after the application of the $n$-th layer. Most of the papers that use pyvene intervene on the residual stream, e.g. mine.

If you're using pyvene on a custom model architecture (i.e. not supported in the library directly), I can help you set that up too! Only ~5 lines of code.

aryamanarora commented 2 months ago

Oh wait, if you just want to collect the activations at those points, you want to use CollectIntervention as one of the interventions in the config. You can find a minimal code example in pyvene 101, in the "Activation Collection" section.
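
For reference, a minimal sketch along the lines of the pyvene 101 activation-collection example; the gpt2 model, layer index, and prompt are placeholders, and the [0][-1] indexing pulls the collected activations out of the first element of the returned tuple:

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; any architecture pyvene supports works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

# Collect the residual stream (block_output) after layer 10.
pv_gpt2 = pv.IntervenableModel({
    "layer": 10,
    "component": "block_output",
    "intervention_type": pv.CollectIntervention,
}, model=gpt2)

base = tokenizer("The capital of Spain is", return_tensors="pt")
collected_activations = pv_gpt2(base)[0][-1]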

CoffeeVampir3 commented 2 months ago

Ah, then I think I'm misunderstanding something crucial. I did set up a minimal activation collection, but I thought it was missing the residual-stream info:

model_name = "/home/blackroot/Desktop/ortho/Meta-Llama-3-8B-Instruct" 
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
def probing_config(layer):
    config = pv.IntervenableConfig([{
        "component": f"model.layers[{layer}].mlp.output",
        "intervention_type": pv.CollectIntervention,
    }])
    return config

config = probing_config(15)
pv_model = pv.IntervenableModel(config, model=model)

Was I on the right track with

collected_activations = pv_model(base_input)[0][-1]
hidden_state = pv_model(base_input)[1][-1]

At this point it seemed like the collected activations weren't the residual stream, but I'm possibly just not understanding the returned values. That brings me to the question: how can I interpret the values here? Is there some resource I can look into that describes the structure of the values returned by pv_model(base_input)? I apologize if I'm asking a very rudimentary question.

Cheers

frankaging commented 2 months ago

@CoffeeVampir3 could you try setting

"component": f"model.layers[{layer}].output",

instead of mlp.output? The residual stream is the output of the whole layer component, not just the MLP. Thanks!
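
For example, adapting the snippet above (a sketch; model and base_input are as in the earlier code):

def probing_config(layer):
    # Target the whole decoder layer's output, i.e. the residual stream
    # after that layer, rather than just the MLP branch.
    return pv.IntervenableConfig([{
        "component": f"model.layers[{layer}].output",
        "intervention_type": pv.CollectIntervention,
    }])

pv_model = pv.IntervenableModel(probing_config(15), model=model)
residual_at_15 = pv_model(base_input)[0][-1]  # residual stream after layer 15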

CoffeeVampir3 commented 2 months ago

Ah I see.

So given an architecture like this:

    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )

So the value collected from a component's .output (e.g. collected_activations = pv_model(base_input)[0][-1]) is always the terminal activation of that component, and in the case of "component": f"model.layers[{layer}].output" that means the final output of decoder layer X, after the post_attention_layernorm and MLP? Yeah, I had indeed misunderstood how the library worked. Thanks very much for the clarification; I had overcomplicated things in my mind.
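
One way to sanity-check that interpretation is to compare the collected value against Hugging Face's output_hidden_states, where (for 0-indexed layers) hidden_states[n + 1] is the output of decoder layer n. A sketch, assuming the pv_model configured on model.layers[15].output above; the prompt is just a placeholder:

import torch

layer = 15
# Any tokenized batch on the model's device works here.
base_input = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

# Residual stream collected by pyvene at the decoder layer's output.
collected = pv_model(base_input)[0][-1]
# Depending on the pyvene version, this can be a list with one tensor per
# collected location; unwrap if so.
if isinstance(collected, (list, tuple)):
    collected = collected[0]

# The same quantity straight from Hugging Face: hidden_states[layer + 1] is the
# output of decoder layer `layer` (hidden_states[0] is the embedding output).
with torch.no_grad():
    hf_out = model(**base_input, output_hidden_states=True)
reference = hf_out.hidden_states[layer + 1]

if collected.shape == reference.shape:
    print("match:", torch.allclose(collected.float(), reference.float(), atol=1e-2))
else:
    print("shapes differ:", collected.shape, reference.shape)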

Cheers!