stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
Apache License 2.0

[Suggestion]: Support Causal Tracing for LLaMA model #174

Open aryopg opened 1 month ago

aryopg commented 1 month ago

Suggestion / Feature Request

I've tried modifying the embed_to_distrib function in pyvene/models/ so that it also supports LLaMA models, as follows:

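import torch

# Stand-ins for the module-level sm/lsm helpers the snippet relies on
# (softmax / log-softmax over the vocabulary dimension).
sm = torch.nn.Softmax(dim=-1)
lsm = torch.nn.LogSoftmax(dim=-1)

def embed_to_distrib(model, embed, log=False, logits=False):
    """Convert an embedding to a distribution over the vocabulary"""
    arch = model.config.architectures[0].lower()
    with torch.inference_mode():
        if "gpt2" in arch:
            # Bare GPT2Model: project onto the tied input embeddings (wte)
            vocab = torch.matmul(embed, model.wte.weight.t())
        elif "llama" in arch:
            # LlamaForCausalLM: use its output projection (lm_head)
            vocab = model.lm_head(embed)
        else:
            # Previously fell through and silently returned None
            raise ValueError(f"Unsupported architecture: {arch}")
        if logits:
            return vocab
        return lsm(vocab) if log else sm(vocab)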
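One way to sanity-check the llama branch above is to confirm that applying `lm_head` to the model's final hidden state reproduces the logits the model itself produces. The sketch below assumes Hugging Face `transformers` (`LlamaConfig`, `LlamaForCausalLM`) and uses a tiny, randomly initialised config with arbitrary sizes so no checkpoint is downloaded; it tests only the projection, not pyvene itself.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

torch.manual_seed(0)

# Tiny, randomly initialised LLaMA so the check runs offline (sizes are arbitrary).
config = LlamaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    architectures=["LlamaForCausalLM"],
)
model = LlamaForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 5))

with torch.no_grad():
    # Full forward pass: the logits LLaMA itself produces.
    reference_logits = model(input_ids).logits
    # Final hidden state of the bare LlamaModel (already passed through the final RMSNorm).
    final_hidden = model.model(input_ids).last_hidden_state
    # What the proposed llama branch computes: project through lm_head.
    projected_logits = model.lm_head(final_hidden)

print(torch.allclose(projected_logits, reference_logits, atol=1e-5))
```

Note that `last_hidden_state` has already gone through the final RMSNorm (`model.model.norm`); passing an intermediate layer's hidden state straight to `lm_head` would skip that normalisation.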

It seems to work fine for causal tracing (see the attached images).
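For context, causal tracing (the ROME-style recipe) roughly works like this: corrupt the subject tokens' embeddings with noise, then restore one clean hidden state at a time and measure how much of the target-token probability, read off the vocabulary distribution, comes back. The sketch below illustrates that technique with plain PyTorch forward hooks on a tiny random `transformers` `LlamaForCausalLM`; it is not pyvene's intervention API or implementation. The token ids, subject positions, and noise scale are hypothetical placeholders.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

torch.manual_seed(0)

# Tiny, randomly initialised LLaMA (arbitrary sizes) so the sketch runs offline.
config = LlamaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    architectures=["LlamaForCausalLM"],
)
model = LlamaForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
seq_len = input_ids.shape[1]
subject_positions = [0, 1]  # hypothetical subject tokens to corrupt
target_id = 7               # hypothetical answer token id
noise = torch.zeros(1, seq_len, config.hidden_size)
noise[:, subject_positions] = 3.0 * torch.randn(1, len(subject_positions), config.hidden_size)


def hidden_of(output):
    # Decoder layers return a tuple in some transformers versions, a bare tensor in others.
    return output[0] if isinstance(output, tuple) else output


def target_prob():
    # Softmax over lm_head logits at the last position, read off at target_id.
    with torch.no_grad():
        logits = model(input_ids, use_cache=False).logits
    return torch.softmax(logits[0, -1], dim=-1)[target_id].item()


# 1) Clean run: record every decoder layer's output hidden state.
clean_states = {}


def make_recorder(layer_idx):
    def record(module, inputs, output):
        clean_states[layer_idx] = hidden_of(output).clone()
    return record


handles = [layer.register_forward_hook(make_recorder(i)) for i, layer in enumerate(model.model.layers)]
p_clean = target_prob()
for handle in handles:
    handle.remove()

# 2) Corrupted run: add noise to the subject token embeddings.
corrupt = model.model.embed_tokens.register_forward_hook(lambda module, inputs, output: output + noise)
p_corrupt = target_prob()


# 3) Corrupted run with clean states restored at one layer and the given positions.
def restored_prob(layer_idx, positions):
    def patch(module, inputs, output):
        h = hidden_of(output).clone()
        h[:, positions] = clean_states[layer_idx][:, positions]
        return (h,) + tuple(output[1:]) if isinstance(output, tuple) else h

    handle = model.model.layers[layer_idx].register_forward_hook(patch)
    try:
        return target_prob()
    finally:
        handle.remove()


# Indirect effect of restoring each single (layer, position) hidden state.
effects = [[restored_prob(l, [p]) - p_corrupt for p in range(seq_len)] for l in range(config.num_hidden_layers)]
# Sanity check: restoring every position at layer 0 reproduces the clean run.
p_full_restore = restored_prob(0, list(range(seq_len)))
corrupt.remove()

print(f"clean={p_clean:.4f} corrupt={p_corrupt:.4f}")
for l, row in enumerate(effects):
    print(l, [round(e, 4) for e in row])
```

With a random model the numbers carry no meaning; with a real checkpoint and a tokenised prompt, the (layer, position) grid of effects is roughly what causal-tracing heatmaps show. Restoring all positions at a single layer must recover the clean probability exactly, which is a cheap check that the hooks are wired correctly.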

Is this the correct approach for LLaMA models, and would it be of interest for pyvene?