jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/

Implement AVEC in the interpretability app #72

Closed. jbloomAus closed this issue 1 year ago.

jbloomAus commented 1 year ago

https://github.com/montemac/algebraic_value_editing/blob/main/scripts/basic_functionality.py

Implement it as part of DTI

jbloomAus commented 1 year ago

Questions:

They run a two-token forward pass. I think I should run a one-token forward pass since I don't have an EOS token. I do have padding tokens, though, so I could pad the inputs with those to do the forward pass. I would then need to keep the activations aligned position-wise as we step forward, which seems doable.
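A rough sketch of that padding idea, assuming a tokens-based variant of the get_resid_pre helper quoted below (PAD_TOKEN and the tensor names here are hypothetical, not DTI's real API):

import torch

PAD_TOKEN = 0  # assumption: stand-in for whatever padding token id DTI actually uses

def pad_to_length(tokens: torch.Tensor, length: int, pad_token: int = PAD_TOKEN) -> torch.Tensor:
    # tokens: (batch, pos); right-pad with the padding token up to `length`
    n_pad = length - tokens.shape[1]
    if n_pad <= 0:
        return tokens
    padding = torch.full((tokens.shape[0], n_pad), pad_token, dtype=tokens.dtype, device=tokens.device)
    return torch.cat([tokens, padding], dim=1)

# Pad both contrast inputs to a common length so the cached residual
# streams subtract position-by-position.
max_len = max(tokens_add.shape[1], tokens_sub.shape[1])
act_diff = (get_resid_pre(pad_to_length(tokens_add, max_len), layer)
            - get_resid_pre(pad_to_length(tokens_sub, max_len), layer))

For reference, the relevant pieces from the linked script: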

def get_resid_pre(prompt: str, layer: int):
    # Cache and return the residual stream just before the given block.
    name = f"blocks.{layer}.hook_resid_pre"
    cache, caching_hooks, _ = model.get_caching_hooks(lambda n: n == name)
    with model.hooks(fwd_hooks=caching_hooks):
        _ = model(prompt)
    return cache[name]

def ave_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return  # caching in model.generate for new tokens

    # We only add to the prompt (first call), not the generated tokens.
    ppos, apos = resid_pre.shape[1], act_diff.shape[1]
    assert apos <= ppos, f"More mod tokens ({apos}) than prompt tokens ({ppos})!"

    # add to the beginning (position-wise) of the activations
    resid_pre[:, :apos, :] += coeff * act_diff
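
For context, a minimal usage sketch in the style of the linked script (the prompts, layer, and coefficient are illustrative only; the two contrast prompts need to tokenize to the same length for the subtraction to line up):

coeff = 5.0
layer = 6

# Steering vector: position-wise difference of the contrast prompts'
# residual streams at the chosen layer.
act_diff = get_resid_pre("Happy", layer) - get_resid_pre("Sad", layer)

# Generate with the addition hook attached at the same hook point.
editing_hooks = [(f"blocks.{layer}.hook_resid_pre", ave_hook)]
with model.hooks(fwd_hooks=editing_hooks):
    text = model.generate("I feel", max_new_tokens=20)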
jbloomAus commented 1 year ago

How do I apply this procedure in DTI?

I think the easiest way for me to do this is to do it with the AVEC code. I could parameterise it so we can get more info on the outcomes (e.g. layer, head, etc.).
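
A minimal sketch of what that parameterisation could look like (hypothetical: for DTI the inputs would be state/RTG/action tokens rather than a prompt string, and act_diff would come from contrasting trajectories rather than contrasting prompts):

def make_ave_hook(act_diff, coeff: float):
    # Bake the steering vector and coefficient into the hook rather than
    # reading globals, so they can be swept over.
    def ave_hook(resid_pre, hook):
        apos = act_diff.shape[1]
        if resid_pre.shape[1] < apos:
            return  # e.g. single-token generation steps
        resid_pre[:, :apos, :] += coeff * act_diff
    return ave_hook

def run_with_avec(model, inputs, layer: int, act_diff, coeff: float):
    # Parameterised entry point: sweep over layer/coeff to collect the
    # per-layer outcome information mentioned above.
    hook_name = f"blocks.{layer}.hook_resid_pre"
    with model.hooks(fwd_hooks=[(hook_name, make_ave_hook(act_diff, coeff))]):
        return model(inputs)

Swapping the hook point name for a per-head one (e.g. blocks.{layer}.attn.hook_z) could then cover the head-level variant.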