🚀 The feature, motivation and pitch
I am trying to extract hidden states from the final layer of llama3-8b (i.e., the final tensor of shape (batch_size, seq_length, n_emb) just before the logits are computed). Would it be possible to add this functionality (i.e., access to hidden states, similar to transformers' output_hidden_states)? A sketch of what I mean is below. Thank you!
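For concreteness, this is the tensor I can currently capture from the HuggingFace implementation with a forward hook (a sketch only; it assumes the transformers LlamaForCausalLM layout, where model.model.norm is the final RMSNorm whose output feeds the LM head, and the checkpoint name is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

captured = {}

def hook(module, inputs, output):
    # output has shape (batch_size, seq_length, n_emb): the pre-logits hidden state
    captured["final_hidden"] = output.detach()

# Hook the final norm layer; assumes the LlamaForCausalLM module layout
handle = model.model.norm.register_forward_hook(hook)
batch = tokenizer(["Hello world"], return_tensors="pt")
with torch.no_grad():
    model(**batch)
handle.remove()
```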
Alternatives
HuggingFace Transformers (via output_hidden_states), but inference there is too slow for my use case.
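This is the Transformers path I am using today; output_hidden_states is a real flag, but generation throughput is the bottleneck (checkpoint name again illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer(["Hello world"], return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# out.hidden_states is a tuple of (num_layers + 1) tensors, each of shape
# (batch_size, seq_length, hidden_size); the last entry is the final
# pre-logits hidden state I am after.
final_hidden = out.hidden_states[-1]
```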
Additional context
I am trying to train a sparse autoencoder (SAE) / linear probe on hidden states from llama3, as in the toy sketch below.
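The downstream use looks roughly like this (a toy sketch with random stand-in data; the 4096 hidden size matches llama3-8b, and the feature/label names are placeholders):

```python
import torch

n_emb = 4096  # hidden size of llama3-8b
# Stand-ins for hidden states extracted offline and their binary labels
features = torch.randn(1024, n_emb)
labels = torch.randint(0, 2, (1024,)).float()

# A linear probe: one logistic-regression layer over the hidden states
probe = torch.nn.Linear(n_emb, 1)
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)
loss_fn = torch.nn.BCEWithLogitsLoss()

for step in range(100):
    opt.zero_grad()
    logits = probe(features).squeeze(-1)
    loss = loss_fn(logits, labels)
    loss.backward()
    opt.step()
```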