Closed zhiaos closed 1 year ago

For a BERT-like model: is it possible to return hidden states from all layers, similar to Huggingface BertModel when setting output_hidden_states=True (Huggingface)? I tried returning the intermediates, but I believe intermediates.hiddens is not exactly the hidden states for each layer.
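For reference, here is the Huggingface behavior being asked about -- a minimal sketch, assuming the transformers package is installed (downloads bert-base-uncased on first run):

import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer('hello world', return_tensors = 'pt')

with torch.no_grad():
    outputs = bert(**inputs, output_hidden_states = True)

# tuple of 13 tensors: the embedding output plus one output per encoder layer
print(len(outputs.hidden_states))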
what are the hidden states by their definition? are they the inputs going into each transformer block (or each attention / feedforward block)?
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))

logits, intermediates = model(x, return_intermediates = True) # logits: (1, 1024, 20000)

intermediates.layer_hiddens # list of 25 tensors - the input to each of the 24 blocks (attention and feedforward alternating) plus the output of the last block
@zhiaos let me know if layer_hiddens has what you need. i can clean up the naming in a later version
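A quick way to sanity check what comes back, continuing from the snippet above (a sketch; each entry should have the model dimension of 512):

print(len(intermediates.layer_hiddens)) # 25 for depth = 12

for i, hidden in enumerate(intermediates.layer_hiddens):
    print(i, hidden.shape) # torch.Size([1, 1024, 512])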
Thanks @lucidrains!

Based on their documentation, the hidden states are the outputs of each layer -- 12 of (Attention + Residual&Norm + FF + Residual&Norm) -- so there are 12 of them. I think layer_hiddens returns 12 pairs of [x, y] plus z, where:

x - the input to the Attention block
y - the input to the FF block
z - Residual_Norm(the last y)
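Spelling that out as an index map -- my reading of the layout for the default a-f-a-f ordering at depth 12, not verified against the library internals:

# layer_hiddens has 25 entries for depth = 12:
#   layer_hiddens[2*i]     -> x for layer i+1 (the input to that attention block)
#   layer_hiddens[2*i + 1] -> y for layer i+1 (the input to that feedforward block)
#   layer_hiddens[24]      -> z (the final, normed output)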
@zhiaos got it, yea, i need to return after each block, since this framework supports arbitrary ordering of attention and feedforwards (say a f f a f f)
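For example, I believe such an ordering can be requested through the custom_layers keyword on the attention layers -- treat the kwarg name and values as assumptions and check the x_transformers source if this errors:

from x_transformers import TransformerWrapper, Decoder

# hypothetical: an explicit a f f a f f ordering (kwarg name is my assumption)
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        custom_layers = ('a', 'f', 'f', 'a', 'f', 'f')
    )
)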
@lucidrains so it is not possible to get the output of each block with the current implementation?
@zhiaos the outputs before the prenorm residual is added? are you sure that's what you want?
@zhiaos the layer_hiddens are the inputs to each block, which should be the output of the previous block: out + residual for pre-norm, or norm(out + residual) for post-norm
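Schematically, the bookkeeping being described -- a toy sketch, not the actual library code:

# what gets recorded as the next block's input in each convention
def prenorm_step(x, block, norm):
    out = block(norm(x))
    return x + out # pre-norm: the stored hidden is out + residual

def postnorm_step(x, block, norm):
    out = block(x)
    return norm(x + out) # post-norm: the stored hidden is norm(out + residual)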
@lucidrains I see. I guess I just need to select the inputs to the attention blocks to match hidden_states.
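Concretely, something like this should pull out 13 tensors analogous to Huggingface's hidden_states (embedding output plus one per layer) -- a sketch that assumes the default a-f-a-f ordering at depth 12, so double check the indices for custom layer orders:

hiddens = intermediates.layer_hiddens

# every other entry is the input to an attention block; the final entry is the
# output of the last block, giving 13 tensors for a depth-12 model
hf_style_hidden_states = hiddens[0::2]
assert len(hf_style_hidden_states) == 13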