Closed by williamrs-openai 1 year ago
Thank you for the suggestion. Early in the development of the project we did in fact support this, but it was dropped when we switched to studying models that use the GPTNeoX architecture. These models, like the Pythia series, run the attention and MLP layers in parallel to improve throughput during training, so the intermediate residual stream does not actually exist in the model's compute graph.
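To make the difference concrete, here is a minimal numpy sketch of the two block layouts. The linear `attn`/`mlp` stand-ins are hypothetical toys (real blocks apply layer norm and operate on token sequences); the point is only that the parallel formulation never materializes a post-attention residual state.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=d)

# Toy linear stand-ins for the attention and MLP sublayers
# (hypothetical; real blocks apply layer norm and act on sequences).
W_attn = 0.1 * rng.normal(size=(d, d))
W_mlp = 0.1 * rng.normal(size=(d, d))
attn = lambda h: W_attn @ h
mlp = lambda h: W_mlp @ h

def sequential_block(x):
    """GPT-2 style block: the post-attention residual exists explicitly."""
    mid = x + attn(x)      # intermediate residual stream
    out = mid + mlp(mid)   # the MLP reads the post-attention state
    return mid, out

def parallel_block(x):
    """GPTNeoX/Pythia style block: attn and MLP both read the block
    input, so there is no intermediate residual state to hook."""
    return x + attn(x) + mlp(x)

mid, seq_out = sequential_block(x)
par_out = parallel_block(x)
```

The two layouts give different outputs whenever the MLP is sensitive to the attention update, which is why a lens trained at the sequential mid-point has no counterpart in the parallel architecture.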
While we don't support training separate lenses for each subcomponent, lenses do transfer relatively well between layers (see the paper); they generally still outperform the logit lens in terms of KL divergence to the model's logits.
In the GPT-2, OPT, and LLaMA model families this mid-point in the residual stream does exist. Once #103 merges it will be very easy to produce tuned lens and logit lens representations at these points.
Here is a quick demo of how you would do this on GPT-2 small, though it works better on the OPT models: https://colab.research.google.com/drive/1IHA4OEFsX346468h9Sd5cyCGzrs2WKwY?usp=sharing
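Conceptually, decoding the mid-point is just the logit lens applied to the post-attention residual state: normalize it and project through the unembedding. A minimal numpy sketch with random stand-in weights (`W_U`, `h_mid`, and `h_final` are all hypothetical; in practice you would use the model's real unembedding matrix, final layer norm, and hidden states, as in the notebook above):

```python
import numpy as np

rng = np.random.default_rng(1)
d, vocab = 16, 50

def layer_norm(h, eps=1e-5):
    # Simplified final layer norm (no learned scale/bias).
    h = h - h.mean()
    return h / np.sqrt(h.var() + eps)

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

W_U = rng.normal(size=(vocab, d))  # stand-in unembedding matrix

def logit_lens(h):
    """Decode a residual-stream state directly through the final
    layer norm and unembedding -- the logit lens."""
    return W_U @ layer_norm(h)

def kl(p, q, eps=1e-12):
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

h_mid = rng.normal(size=d)    # e.g. the post-attention mid-point state
h_final = rng.normal(size=d)  # e.g. the final-layer residual state
p_model = softmax(logit_lens(h_final))
p_lens = softmax(logit_lens(h_mid))
divergence = kl(p_model, p_lens)  # how far the mid-point decode is
```

A tuned lens would replace `layer_norm` plus `W_U` with a learned affine translator per layer, but the decode-and-compare-KL workflow is the same.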
Thanks for putting together the notebook
I'm glad it was helpful.
I don't think we will be adding support for training lenses on individual subcomponents in the near future. So, I'm going to mark this as resolved once #103 merges.
It would be useful for figuring out when to attribute changes to the MLP layer vs. the attention heads. Ideally one would train a separate tuned lens for the intermediate residual stream, but it may still be helpful to reuse the lens trained at the start of the layer.
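One way to sketch that attribution: decode the residual stream before attention, at the mid-point, and after the MLP with the same lens, and measure how much each sublayer moves the decode toward the block's final prediction. Everything below is hypothetical, with random stand-in states in place of real hidden states:

```python
import numpy as np

rng = np.random.default_rng(2)
d, vocab = 16, 50
W_U = rng.normal(size=(vocab, d))  # stand-in unembedding matrix

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def lens(h):
    # Logit-lens decode of a residual state (final LN omitted for brevity).
    return softmax(W_U @ h)

def kl(p, q, eps=1e-12):
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

# Hypothetical residual states around one block: input, post-attention
# mid-point, and block output.
x = rng.normal(size=d)
mid = x + 0.3 * rng.normal(size=d)   # stand-in for x + attn(x)
out = mid + 0.3 * rng.normal(size=d) # stand-in for mid + mlp(mid)

p_final = lens(out)
# Reduction in KL to the block's prediction across each sublayer:
attn_effect = kl(p_final, lens(x)) - kl(p_final, lens(mid))
mlp_effect = kl(p_final, lens(mid)) - kl(p_final, lens(out))
```

Large `attn_effect` suggests the attention heads did most of the work for that token; large `mlp_effect` points at the MLP. Using the start-of-layer lens at the mid-point, as suggested, introduces some bias, but the comparison is still directional.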