AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License

Support showing logit/tuned lens predictions in the middle of layers (e.g. after attention/before MLP) #104

Closed williamrs-openai closed 1 year ago

williamrs-openai commented 1 year ago

It would be useful for figuring out when to attribute changes to the MLP layer vs. the attention heads. Ideally one would train a separate tuned lens for the intermediate residual stream, but it may still be helpful to use the same lens as at the start of the layer.

levmckinney commented 1 year ago

Thank you for the suggestion. Early in the project's development we did in fact support this. We dropped it when we switched to studying models with the GPT-NeoX architecture. These models, like the Pythia series, run the attention and MLP layers in parallel to improve throughput during training. Thus, the intermediate residual stream does not actually exist in the model's compute graph.
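A toy sketch of the difference, in plain Python. Everything here is hypothetical: `sequential_block`, `parallel_block`, the parameter-free `layer_norm`, and the callables passed in as `attn`/`mlp` are stand-ins, not the GPT-2 or GPT-NeoX source (real GPT-NeoX blocks use separate learned LayerNorms for each branch, and LLaMA uses RMSNorm). The only point is where the MLP reads its input from.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
Sublayer = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Parameter-free LayerNorm (no learned gain or bias), for illustration only."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def sequential_block(x: Vector, attn: Sublayer, mlp: Sublayer) -> Tuple[Vector, Vector]:
    """GPT-2/OPT-style block: the MLP reads the residual stream *after* attention.

    Returns (output, mid), where `mid` is the intermediate residual stream.
    """
    mid = add(x, attn(layer_norm(x)))     # this state exists in the compute graph
    out = add(mid, mlp(layer_norm(mid)))  # MLP input depends on the attention output
    return out, mid


def parallel_block(x: Vector, attn: Sublayer, mlp: Sublayer) -> Vector:
    """GPT-NeoX/Pythia-style block: attention and MLP both read the same input x."""
    # Both branches are computed from x and summed; no "after attention,
    # before MLP" state is ever fed to the MLP.
    return add(add(x, attn(layer_norm(x))), mlp(layer_norm(x)))
```

In the parallel block you could still compute `add(x, attn(layer_norm(x)))` by hand, but the MLP never reads it, so there is no intermediate state in the graph to attach a lens to.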

While we don't support training separate lenses for each subcomponent, lenses do transfer relatively well between layers (see the paper), i.e. they generally still outperform the logit lens in terms of KL divergence to the model's logits.
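The comparison "in terms of KL" can be made concrete with a small sketch. It assumes the lens output and the model's final output are both logit vectors over the same vocabulary, measures KL(p_model || q_lens) for a single token position (the direction is an assumption of this sketch), and uses made-up toy logits rather than any measured numbers.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def kl_divergence(model_logits: List[float], lens_logits: List[float]) -> float:
    """KL(p_model || q_lens) in nats for a single token position."""
    log_p = log_softmax(model_logits)
    log_q = log_softmax(lens_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Hypothetical toy logits over a 3-token vocabulary (not measured results).
model_logits = [2.0, 1.0, 0.1]
logit_lens_logits = [0.5, 0.4, 0.3]
tuned_lens_logits = [1.9, 1.1, 0.0]

print(f"logit lens KL: {kl_divergence(model_logits, logit_lens_logits):.4f}")
print(f"tuned lens KL: {kl_divergence(model_logits, tuned_lens_logits):.4f}")
```

In practice this would be averaged over many token positions; a lens whose distribution tracks the model's final distribution more closely gets a lower score.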

On GPT-2, OPT, and LLaMA family models, this mid-point in the residual stream does exist. Once #103 merges, it will be very easy to produce tuned lens and logit lens representations at these points.
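To illustrate what decoding at these points involves, here is a minimal sketch that reuses layer l's lens at the mid-point, as suggested in the original request. It assumes a tuned lens is an affine translator of the form `h + A h + b` followed by the logit lens (final norm, then unembedding). `logit_lens`, `tuned_lens`, the parameter-free `layer_norm`, and all numeric values are hypothetical stand-ins, not the tuned-lens library's API.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Stand-in for the model's final norm (no learned parameters)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def logit_lens(h: Vector, w_unembed: Matrix) -> Vector:
    """Decode a hidden state with the final norm followed by the unembedding."""
    return matvec(w_unembed, layer_norm(h))


def tuned_lens(h: Vector, a: Matrix, b: Vector, w_unembed: Matrix) -> Vector:
    """Apply a per-layer affine translator (assumed form h + A h + b), then decode."""
    return logit_lens(add(h, add(matvec(a, h), b)), w_unembed)


# Hypothetical toy values: a 3-d residual stream and a 2-token vocabulary.
w_unembed = [[1.0, 0.0, -1.0], [0.0, 1.0, 0.5]]
a_layer = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
b_layer = [0.0, 0.2, -0.1]

h_start = [1.0, 2.0, 4.0]  # residual stream entering layer l
h_mid = [0.8, 1.7, 4.9]    # after attention, before the MLP (sequential blocks only)

# Reuse layer l's translator at the mid-point.
print(tuned_lens(h_start, a_layer, b_layer, w_unembed))
print(tuned_lens(h_mid, a_layer, b_layer, w_unembed))
print(logit_lens(h_mid, w_unembed))
```

Setting `a_layer` and `b_layer` to zero makes `tuned_lens` reduce to `logit_lens`, which is a quick sanity check for this kind of sketch.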

Here is a quick demo of how you would do this on gpt2-small, though it works better on the OPT models: https://colab.research.google.com/drive/1IHA4OEFsX346468h9Sd5cyCGzrs2WKwY?usp=sharing

williamrs-openai commented 1 year ago

Thanks for putting together the notebook.

levmckinney commented 1 year ago

I'm glad it was helpful.

I don't think we will be adding support for training lenses on individual subcomponents in the near future. So, I'm going to mark this as resolved once #103 merges.