lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

return hidden states of all layers #177

Closed - zhiaos closed this issue 1 year ago

zhiaos commented 1 year ago

For a BERT-like model:

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()
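
(A side note on the example above: Decoder applies a causal mask, so for a BERT-like bidirectional model the Encoder attn_layers is presumably what is wanted. A minimal sketch, assuming the same config:

import torch
from x_transformers import TransformerWrapper, Encoder

# assumption: BERT-style bidirectional attention uses Encoder rather than the causal Decoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

The question below applies either way.)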

Is it possible to return the hidden states from all layers, similar to the Hugging Face BertModel when output_hidden_states=True is set?

I tried

x = torch.randint(0, 256, (1, 1024)).cuda()
logit, intermediates = model(x, return_intermediates=True)

But I believe intermediates.hiddens are not exactly the hidden states for each layer.
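
For reference, the Hugging Face behavior being compared against looks roughly like this (a sketch from memory of the transformers API, not part of x-transformers; for bert-base it yields 13 tensors, the embedding output plus one per layer):

from transformers import BertModel, BertTokenizer

bert = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

inputs = tokenizer('hello world', return_tensors = 'pt')
outputs = bert(**inputs)

print(len(outputs.hidden_states)) # 13 - embedding output + output of each of the 12 layers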

lucidrains commented 1 year ago

what are the hidden states by their definition? are they the inputs going into each transformer block, or into each attention / feedforward block?

lucidrains commented 1 year ago

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))

logits, intermediates = model(x, return_intermediates = True) # logits shape (1, 1024, 20000)

intermediates.layer_hiddens # list of length 25 - inputs to each of the 24 blocks (attention and feedforward alternating) plus the output of the last block
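
For concreteness, a quick check of what layer_hiddens holds, assuming the depth-12, dim-512 config above:

print(len(intermediates.layer_hiddens))      # 25
print(intermediates.layer_hiddens[0].shape)  # torch.Size([1, 1024, 512]) - (batch, seq_len, dim)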

lucidrains commented 1 year ago

@zhiaos let me know if layer_hiddens has what you need

i can clean up the naming in a later version

zhiaos commented 1 year ago

Thanks @lucidrains!

Based on their documentation, the hidden states are the outputs of each layer (Attention + Residual&Norm + FF + Residual&Norm), so there are 12 of them. I think layer_hiddens returns 12 pairs [x, y] plus z, where x is the input to the attention block, y is the input to the FF block, and z is Residual_Norm(last_y).

lucidrains commented 1 year ago

@zhiaos got it, yea, i need to return after each block, since this framework supports arbitrary ordering of attention and feedforwards (say a f f a f f)
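
(In case it helps to see that ordering spelled out, a sketch using the custom_layers kwarg - I believe this is the knob that overrides the default alternating ('a', 'f') pattern, but treat the exact name and its interaction with depth as assumptions:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,                                      # nominal depth; custom_layers below defines the actual block list (assumption)
        heads = 8,
        custom_layers = ('a', 'f', 'f', 'a', 'f', 'f')  # 'a' = attention, 'f' = feedforward: the "a f f a f f" ordering mentioned above
    )
)

With such orderings there is no single per-layer boundary, which is why the hiddens are recorded per block.)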

zhiaos commented 1 year ago

@lucidrains so it is not possible to get the output of each block with the current implementation?

lucidrains commented 1 year ago

@zhiaos the outputs before the prenorm residual is added? are you sure that's what you want?

lucidrains commented 1 year ago

@zhiaos the layer_hiddens are the inputs to each block, which should be the output of the previous block, i.e. out + residual - or norm(out + residual) for post-norm

zhiaos commented 1 year ago

@lucidrains I see. I guess I just need to select the inputs to the attention blocks to match hidden_states.
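
To make that concrete, a minimal sketch of that selection, assuming the depth-12 config above with the default alternating attention/feedforward ordering (so layer_hiddens holds the inputs to each of the 24 blocks plus the output of the last one, 25 entries in all):

# reusing the model and x from the snippet above
logits, intermediates = model(x, return_intermediates = True)

# inputs to attention blocks 2..12 plus the final block output = the output of each of the 12 layers
layer_outputs = intermediates.layer_hiddens[2::2]
assert len(layer_outputs) == 12

# prepending the very first entry (the embedding output) gives 13 tensors,
# roughly mirroring Hugging Face's output_hidden_states convention
hidden_states = [intermediates.layer_hiddens[0], *layer_outputs]
assert len(hidden_states) == 13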