vgel / repeng

A library for making RepE control vectors
https://vgel.me/posts/representation-engineering/
MIT License

Take hidden states from last non-padding token when batching #38

Closed: ohxh closed this pull request 3 weeks ago

ohxh commented 3 weeks ago

First of all, this is a really neat repo!

I noticed that batched_get_hiddens always takes hidden states from the last token position in each sequence of a batch. Since the sequences are right-padded to the same length, that last position is a padding token for every sequence except the longest one in the batch, so batching changes the extracted hidden states for all the shorter sequences.

After this change there's still some difference between the batched and non-batched hidden states, but I think that might be due to the model itself, since batching changes the order of floating-point operations: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232

I've only tried this on llama-3-8b, so I'm not sure whether it will need changes to work with other models.
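
To illustrate the change, here is a minimal sketch of selecting the last non-padding position via the attention mask (not the exact repeng code, just the indexing idea; the function name is made up and it assumes right-padded batches):

import torch

def last_token_hiddens(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden_dim) for one layer
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    last_positions = attention_mask.sum(dim=1) - 1  # index of the last real token per sequence
    batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_indices, last_positions]  # (batch, hidden_dim)

The outputs below compare the old and new selection at batch_size=4 and batch_size=1: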

4 sequences, batch_size=4, old method:
[[-2.777  -2.205   3.318  ...  1.834   2.014   1.123 ]
 [ 1.21   -2.031   2.41   ...  1.883   0.391   1.652 ]
 [ 1.153  -1.737   2.281  ...  2.236   2.676   2.178 ]
 [ 1.25   -1.308   2.342  ...  0.9683  3.71    2.516 ]]
4 sequences, batch_size=4, new method:
[[-2.777   -2.205    3.318   ...  1.834    2.014    1.123  ]
 [ 0.852   -3.914    1.661   ...  1.693    0.828   -0.0934 ]
 [ 0.10767 -2.484   -1.208   ...  2.771    2.46     0.7217 ]
 [-1.701   -2.082    2.62    ...  1.927    2.334   -0.33   ]]
4 sequences, batch_size=1, old method:
[[-2.79    -2.2      3.314   ...  1.833    2.012    1.125  ]
 [ 0.8516  -3.912    1.659   ...  1.693    0.8306  -0.08746]
 [ 0.1023  -2.49    -1.211   ...  2.775    2.453    0.714  ]
 [-1.699   -2.084    2.61    ...  1.923    2.334   -0.3232 ]]
4 sequences, batch_size=1, new method:
[[-2.79    -2.2      3.314   ...  1.833    2.012    1.125  ]
 [ 0.8516  -3.912    1.659   ...  1.693    0.8306  -0.08746]
 [ 0.1023  -2.49    -1.211   ...  2.775    2.453    0.714  ]
 [-1.699   -2.084    2.61    ...  1.923    2.334   -0.3232 ]]
vgel commented 3 weeks ago

weird! i thought the tokenizers were left-padding by default... ah, mistral does...

>>> llama3_tokenizer(["x", "x x"], padding=True)
{'input_ids': [[128000, 87, 128001], [128000, 87, 865]], 'attention_mask': [[1, 1, 0], [1, 1, 1]]}
>>> mistral_tokenizer(["x", "x x"], padding=True)
{'input_ids': [[2, 1, 1318], [1, 1318, 1318]], 'attention_mask': [[0, 1, 1], [1, 1, 1]]}
ohxh commented 3 weeks ago

Oh huh… maybe an easier fix would be to force the tokenizer to always left pad
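
For reference, a minimal sketch of that alternative using the Hugging Face tokenizer API (the model name is just an example; llama-3 ships without a pad token, so this reuses EOS for padding):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token  # llama-3 has no pad token by default
tokenizer.padding_side = "left"            # pad on the left so the last position is always a real token
batch = tokenizer(["x", "x x"], padding=True)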

vgel commented 3 weeks ago

can you check "allow edits by maintainers" so i can make changes to this PR?

vgel commented 3 weeks ago

Oh huh… maybe an easier fix would be to force the tokenizer to always left pad

yeah i was thinking that, but i think your approach is better because the user might want right-padding for whatever reason; better not to mess with their tokenizer instance if we can avoid it.

ohxh commented 3 weeks ago

I think it is checked already…

vgel commented 3 weeks ago

Glad I checked the PRs too; I was just about to cut the 0.3 release, so you just squeaked in!