@eyalmazuz hey Eyal! thanks for bringing this up
do you want to see if the following works for you in the latest version?
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(type = 5),  # extra embedding table named 'type' with 5 possible ids
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))
types = torch.randint(0, 5, (1, 1024))  # one token type id per position

logits = model(x, embed_ids = dict(type = types))  # type embeddings are summed with the token embeddings
logits.shape # (1, 1024, 20000)
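for the BERT-style token type (segment) embeddings from the original question, the same mechanism should carry over with an Encoder and a two-id type vocabulary — a minimal sketch, assuming BERT-base-like hyperparameters (the key name segment is arbitrary):

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 30522,                    # BERT-base vocab size, just for illustration
    max_seq_len = 512,
    embed_num_tokens = dict(segment = 2),  # sentence A / sentence B, as in the BERT paper
    attn_layers = Encoder(
        dim = 768,
        depth = 12,
        heads = 12
    )
)

ids = torch.randint(0, 30522, (1, 512))
segments = torch.randint(0, 2, (1, 512))

logits = model(ids, embed_ids = dict(segment = segments))  # (1, 512, 30522)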
@lucidrains Hey Phil! Thanks for the quick response and fix! I think that solution is great and could work for me.
Thanks again for the feature, I'll close the issue now.
I was looking at the BERT example you give in the code, but unlike the original paper, I didn't see a way in the source code to add token type embeddings
a solution is to manage the embeddings myself by doing something like the sketch below, but it seems weird to handle the word embedding matrix myself when I can just use the regular TransformerWrapper
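a minimal sketch of that manual approach (hypothetical module and names, assuming plain nn.Embedding tables summed before the attention layers; positional embeddings omitted for brevity):

import torch
import torch.nn as nn
from x_transformers import Encoder

class BertWithTokenTypes(nn.Module):
    # handle token and token type embeddings by hand, then feed the
    # summed embeddings directly into the attention layers
    def __init__(self, num_tokens = 20000, num_types = 2, dim = 512):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.type_emb = nn.Embedding(num_types, dim)  # token type (segment) table
        self.attn_layers = Encoder(dim = dim, depth = 12, heads = 8)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, types):
        embed = self.token_emb(x) + self.type_emb(types)  # summed, as in the BERT paper
        return self.to_logits(self.attn_layers(embed))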
is there a way to add token type embeddings into the model, so I could just write something like:
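(hypothetically, with made-up keyword names just to illustrate the kind of interface meant — these are not actual x-transformers arguments:)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    token_type_num_tokens = 2,  # hypothetical kwarg, for illustration only
    attn_layers = Encoder(dim = 512, depth = 12, heads = 8)
)

x = torch.randint(0, 20000, (1, 1024))
types = torch.randint(0, 2, (1, 1024))

logits = model(x, token_types = types)  # hypothetical forward kwarg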