lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

Classification with x-transformers #264

Open · RyanKim17920 opened this issue 2 months ago

RyanKim17920 commented 2 months ago

Added cls token/pooling option for NLP based full text classification

lucidrains commented 2 months ago

@RyanKim17920 do you want to try the latest changes and see if that's enough?

lucidrains commented 2 months ago

@RyanKim17920 hey Ryan, sorry for hijacking your efforts, just that the project is at a size where things need to be a bit more particular

your example should run now as

import torch
from torch import nn

from x_transformers import (
    TransformerWrapper,
    Encoder
)

# CLS token test
transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
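    # use_cls_token: prepend a learned CLS token and take the classification logits from its final embedding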
    use_cls_token=True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])

print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)

print(loss)

# BCE cls token

transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=1, # single logit for binary (BCE) classification
    use_cls_token=True,
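    # squeeze_out_last_dim: drop the trailing singleton dim so logits come out with shape (batch,)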
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer for the embedding layer
y = torch.tensor([0, 1]).float()

print(x.shape)
logits = transformer(x).squeeze()
loss = nn.BCEWithLogitsLoss()(logits, y)

print(loss)

# pooling test
transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
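    # average_pool_embed: mean-pool the final token embeddings before the projection to logits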
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])

print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)

print(loss)

# pooling BCE test

transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=1, # single logit for binary (BCE) classification
    average_pool_embed = True,
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer for the embedding layer
y = torch.tensor([0, 1]).float()

print(x.shape)
logits = transformer(x).squeeze()
print(logits.shape)
loss = nn.BCEWithLogitsLoss()(logits, y)

print(loss)

# normal test 

transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (1, 10))
y = torch.tensor([0])

print(x.shape)
logits = transformer(x)
print(logits.shape)
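
for real padded batches you would probably also want to pass a boolean mask (True for real tokens) so the pooling ignores padding, roughly like this (reusing the pooling model above; the mask keyword is the same one shown in the readme):

x = torch.randint(0, 5, (2, 10))
mask = torch.ones(2, 10, dtype = torch.bool)
mask[1, 6:] = False  # pretend the second sequence only has 6 real tokens

logits = transformer(x, mask = mask)
print(logits.shape)  # (2, 2)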
RyanKim17920 commented 1 month ago

Thank you for the improvements you've already made to my original additions. I noticed that the test/x_transformers changes are now outdated, so they aren't needed anymore. However, I believe the example I provided could still be valuable: it demonstrates NLP classification on a well-known dataset and reaches validation accuracy in the high 90s, which might help users understand how to implement it.

Would it be possible to add the example to the repository?
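
For reference, here is a rough sketch of what such an example script might look like. The vocabulary size, model dimensions, and the dummy batch below are only placeholders standing in for the real tokenized dataset; the actual example would load and tokenize the text and track validation accuracy.

import torch
from torch import nn
from x_transformers import TransformerWrapper, Encoder

# placeholder sizes -- a real example would take these from the tokenizer / dataset
VOCAB_SIZE = 20000
MAX_SEQ_LEN = 256
NUM_CLASSES = 2

model = TransformerWrapper(
    num_tokens = VOCAB_SIZE,
    max_seq_len = MAX_SEQ_LEN,
    logits_dim = NUM_CLASSES,     # number of classes
    average_pool_embed = True,    # mean-pool token embeddings for classification
    attn_layers = Encoder(
        dim = 256,
        depth = 4,
        heads = 8
    )
)

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
criterion = nn.CrossEntropyLoss()

# dummy batch standing in for a batch of tokenized documents with class labels
token_ids = torch.randint(0, VOCAB_SIZE, (8, MAX_SEQ_LEN))
labels = torch.randint(0, NUM_CLASSES, (8,))

logits = model(token_ids)         # (8, NUM_CLASSES)
loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item())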