lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Bert token type embedding #213

Closed eyalmazuz closed 8 months ago

eyalmazuz commented 8 months ago

I was looking at the BERT example you give in the code, but unlike the original paper I didn't see a way in the source code to add token type embeddings.

One solution is to manage it myself by doing something like:

import torch
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Encoder

class Bert(nn.Module):
    def __init__(self, num_tokens, num_types, dim):
        super().__init__()
        self.model = ContinuousTransformerWrapper(
            dim_in = dim,
            dim_out = dim,
            max_seq_len = 1024,
            attn_layers = Encoder(
                dim = dim,
                depth = 12,
                heads = 8
            )
        )
        # token and token type (segment) embeddings handled outside the wrapper
        self.type_emb = nn.Embedding(num_types, dim)
        self.token_emb = nn.Embedding(num_tokens, dim)

    def forward(self, tokens, types, mask):
        emb = self.token_emb(tokens)
        type_emb = self.type_emb(types)

        out = self.model(emb + type_emb, mask = mask)

        return out

but it seems weird to handle the word embedding matrix myself when I could just use the regular TransformerWrapper. Is there a way to add token type embeddings to the model, so I could just write something like:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    max_token_type = 3,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
token_type = torch.randint(0, 3, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, token_type = token_type, mask = mask) # (1, 1024, 20000)

lucidrains commented 8 months ago

@eyalmazuz hey Eyal! thanks for bringing this up

do you want to see if the following works for you in the latest version?

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(type = 5),
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 1024))
types = torch.randint(0, 5, (1, 1024))

logits = model(x, embed_ids = dict(type = types))
logits.shape # (1, 1024, 20000)
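
For a BERT-style sentence-pair input, the type ids would simply mark segment A versus segment B. A minimal sketch reusing the model defined above (the split point and tensor shapes here are assumptions for illustration, not from this thread):

# hypothetical sentence-pair setup: segment A gets type id 0, segment B gets type id 1
seq_len = 1024
sent_a_len = 400  # assumed boundary between the two segments

x = torch.randint(0, 256, (1, seq_len))

types = torch.zeros((1, seq_len), dtype = torch.long)  # segment A -> type id 0
types[:, sent_a_len:] = 1                              # segment B -> type id 1

mask = torch.ones((1, seq_len)).bool()

logits = model(x, embed_ids = dict(type = types), mask = mask)  # (1, 1024, 20000)
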
eyalmazuz commented 8 months ago

@lucidrains Hey Phil! Thanks for the quick response and fix! I think that solution is great and could work for me.

Thanks again for the feature. I'll close the issue now.