vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

hidden-states from final (or middle layers) #5406

Open janphilippfranken opened 3 months ago

janphilippfranken commented 3 months ago

🚀 The feature, motivation and pitch

I am trying to extract hidden states from the final layer of llama3-8b (i.e., the final `(batch_size, seq_length, n_emb)` tensor before the logits are computed). Would it be possible to add this functionality (i.e., access to hidden states, similar to `output_hidden_states` in transformers)? Thank you!

Alternatives

HuggingFace Transformers supports `output_hidden_states`, but it is too slow.

Additional context

I am trying to train an SAE (sparse autoencoder) / linear probe on hidden states from llama3.

jeejeelee commented 3 months ago

Maybe you could consider registering a forward hook on the relevant `nn.Module`.
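
A minimal sketch of that hook approach on a toy stand-in model (the `TinyLM` class and all names in it are illustrative, not vLLM's API). For vLLM you would locate the underlying `nn.Module` through the engine's model runner (the exact attribute path varies by version) and register the hook on its final norm layer instead:

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy stand-in for a decoder LM: embed -> layers -> norm -> lm_head."""
    def __init__(self, vocab=100, n_emb=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, n_emb)
        self.layers = nn.ModuleList(nn.Linear(n_emb, n_emb) for _ in range(4))
        self.norm = nn.LayerNorm(n_emb)
        self.lm_head = nn.Linear(n_emb, vocab)

    def forward(self, ids):
        h = self.embed(ids)
        for layer in self.layers:
            h = torch.relu(layer(h))
        h = self.norm(h)          # final hidden states, pre-logits
        return self.lm_head(h)

captured = {}

def save_hidden(module, inputs, output):
    # Fires every time `norm` runs; output is (batch_size, seq_length, n_emb).
    captured["hidden"] = output.detach()

model = TinyLM()
# Hook the last module before lm_head to grab the pre-logits hidden states.
handle = model.norm.register_forward_hook(save_hidden)

ids = torch.randint(0, 100, (2, 5))   # batch_size=2, seq_length=5
logits = model(ids)
handle.remove()                        # detach the hook when done

print(captured["hidden"].shape)        # torch.Size([2, 5, 16])
```

Because the hook sits on the module itself, it works regardless of how the forward pass is driven, which is why it can reach inside a serving engine without changing its code.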