vgel / repeng

A library for making RepE control vectors
https://vgel.me/posts/representation-engineering/
MIT License

Computing the difference vectors for PCA #28

Closed · r3ndd closed this 2 months ago

r3ndd commented 3 months ago

In the original RepE paper, the authors mostly describe an unsupervised version of PCA, in which hidden state vectors are randomly paired and PCA is performed on the differences of these pairs. This is fundamentally different from the supervised implementation used here, so after reading the code I want to know whether I am misunderstanding what your implementation is doing, or whether there is room to improve your PCA methodology.

In the contrastive approach, we assume each hidden state vector belongs to one of two sets, A or B. In unsupervised PCA, items from these sets are randomly paired without regard to their label, so the possible pair types are (Ai, Aj), (Bi, Bj), (Ai, Bj), and (Bi, Aj); for simplicity I'll call these AA, BB, AB, and BA. If we assume vectors cluster closely with other vectors from their own set, then the difference vectors for each pair type are approximately

AA ≈ 0, BB ≈ 0, AB ≈ X, BA ≈ -X,

where X is a vector that points from set B to set A. Doing PCA on these difference vectors should then give you ~X as your first principal component.
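Here is a minimal numpy sketch of that unsupervised setup (my own toy illustration, not repeng's code; X, A, B, and the cluster scales are made-up assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: sets A and B are tight clusters separated along a hidden direction X.
X = np.array([3.0, 0.0, 0.0])                 # direction from B to A that PCA should recover
A = X + rng.normal(scale=0.1, size=(100, 3))  # set A clusters around X
B = rng.normal(scale=0.1, size=(100, 3))      # set B clusters around the origin

# Unsupervised pairing: pool the sets, shuffle, and pair adjacent rows, so the
# differences are a mix of the AA, BB, AB, and BA pair types.
H = np.vstack([A, B])
rng.shuffle(H)
diffs = H[0::2] - H[1::2]

# First principal component of the mean-centered differences.
diffs -= diffs.mean(axis=0)
_, _, vt = np.linalg.svd(diffs, full_matrices=False)
pc1 = vt[0]
print(abs(pc1 @ X) / np.linalg.norm(X))  # ~1.0: pc1 is aligned with ±X
```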

However, it appears that in your supervised method you only have pairs (Ai, Bj), i.e. only AB pairs. Taking the differences of these vectors should give you vectors that all cluster tightly around the single point X, with very low variance. In theory, when these points are projected onto any line, they all project to nearly the same value, so PCA should not work very well. If this is what you are doing, then one explanation for why it still works is that not all hidden states actually encode the representation you are trying to extract, making that mismatch the primary source of variance PCA picks up on.
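To make this concrete, here is the same toy setup with AB-only pairing (again my own illustration, not the library's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.array([3.0, 0.0, 0.0])
A = X + rng.normal(scale=0.1, size=(100, 3))  # set A clusters around X
B = rng.normal(scale=0.1, size=(100, 3))      # set B clusters around the origin

# AB-only pairing: every difference vector sits near the single point X.
diffs_ab = A - B
print(diffs_ab.mean(axis=0))  # ~X
print(diffs_ab.std(axis=0))   # small: little spread in any direction

# PCA mean-centers first, so the shared offset X is subtracted away and the
# components are fit to the leftover noise rather than to X itself.
centered = diffs_ab - diffs_ab.mean(axis=0)
_, s, vt = np.linalg.svd(centered, full_matrices=False)
print(s)                                   # small singular values (noise scale)
print(abs(vt[0] @ X) / np.linalg.norm(X))  # need not be anywhere near 1
```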

This question came up before in the realm of word vectors: https://stackoverflow.com/questions/48019843/pca-on-word2vec-embeddings

The simple fix for supervised PCA is to first compute a center point for each pair, Ck = (Ai + Bj)/2. The difference vectors then become Ai - Ck and Bj - Ck, so that you get two opposing difference vectors for each pair. Doing PCA on all of these difference vectors should work much better, since the variance due to the representation itself is now much larger.
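And a sketch of this midpoint-centering fix on the same toy data (again my own illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.array([3.0, 0.0, 0.0])
A = X + rng.normal(scale=0.1, size=(100, 3))  # set A clusters around X
B = rng.normal(scale=0.1, size=(100, 3))      # set B clusters around the origin

# Midpoint-centering fix: each pair contributes two opposing difference vectors.
C = (A + B) / 2.0
diffs = np.vstack([A - C, B - C])  # equal to +(A - B)/2 and -(A - B)/2

# The stacked differences have exactly zero mean by construction, so the B->A
# direction itself carries the variance that PCA extracts.
_, _, vt = np.linalg.svd(diffs, full_matrices=False)
pc1 = vt[0]
print(abs(pc1 @ X) / np.linalg.norm(X))  # ~1.0: pc1 recovers ±X
```

Note that Ai - Ck = (Ai - Bj)/2 and Bj - Ck = -(Ai - Bj)/2, so the stacked set has exactly zero mean, which is what lets PCA pick up the B-to-A direction rather than subtracting it away.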

vgel commented 2 months ago

Fixed by #34