Open RyanKim17920 opened 2 months ago
@RyanKim17920 do you want to try the latest changes and see if that's enough?
@RyanKim17920 hey Ryan, sorry for hijacking your efforts, just that the project is at a size where things need to be a bit more particular
your example should run now as
import torch
from torch import nn

from x_transformers import (
    TransformerWrapper,
    Encoder
)

# CLS token test

transformer = TransformerWrapper(
    num_tokens = 6,
    max_seq_len = 10,
    logits_dim = 2,  # num_classes
    use_cls_token = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])
print(x.shape)

logits = transformer(x)
print(logits.shape)

loss = nn.CrossEntropyLoss()(logits, y)
print(loss)
# BCE cls token

transformer = TransformerWrapper(
    num_tokens = 6,
    max_seq_len = 10,
    logits_dim = 1,  # single logit for binary classification
    use_cls_token = True,
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer for the embedding
y = torch.tensor([0, 1]).float()  # BCEWithLogitsLoss expects float targets
print(x.shape)

logits = transformer(x)  # squeeze_out_last_dim already removes the trailing dim -> (2,)
loss = nn.BCEWithLogitsLoss()(logits, y)
print(loss)
# pooling test

transformer = TransformerWrapper(
    num_tokens = 6,
    max_seq_len = 10,
    logits_dim = 2,  # num_classes
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])
print(x.shape)

logits = transformer(x)
print(logits.shape)

loss = nn.CrossEntropyLoss()(logits, y)
print(loss)
# pooling BCE test

transformer = TransformerWrapper(
    num_tokens = 6,
    max_seq_len = 10,
    logits_dim = 1,  # single logit for binary classification
    average_pool_embed = True,
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer for the embedding
y = torch.tensor([0, 1]).float()  # BCEWithLogitsLoss expects float targets
print(x.shape)

logits = transformer(x)  # squeeze_out_last_dim already removes the trailing dim -> (2,)
print(logits.shape)

loss = nn.BCEWithLogitsLoss()(logits, y)
print(loss)
# normal test (single sample)

transformer = TransformerWrapper(
    num_tokens = 6,
    max_seq_len = 10,
    logits_dim = 2,  # num_classes
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (1, 10))
y = torch.tensor([0])
print(x.shape)

logits = transformer(x)
print(logits.shape)
Thank you for the improvements you've already made to my original additions. I noticed that the test/x_transformers changes are outdated, so they aren't needed anymore. However, I believe the example I provided could still be valuable: it demonstrates NLP full-text classification on a well-known dataset and reaches validation accuracy in the high 90% range, which should help users understand how to set this up.
Would it be possible to add the example to the repository?
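For reference, a rough sketch of what such a classification example could look like. The data below is just random placeholder token ids standing in for a real tokenized corpus, and the vocabulary size, model size, and training settings are illustrative assumptions rather than the values from the actual example:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from x_transformers import TransformerWrapper, Encoder

# assumed setup: random token ids stand in for a tokenized text dataset
NUM_TOKENS = 20000   # assumed vocabulary size
MAX_SEQ_LEN = 256    # assumed maximum sequence length
NUM_CLASSES = 2      # e.g. binary sentiment labels

# placeholder data - replace with a real tokenized dataset
train_x = torch.randint(0, NUM_TOKENS, (512, MAX_SEQ_LEN))
train_y = torch.randint(0, NUM_CLASSES, (512,))
loader = DataLoader(TensorDataset(train_x, train_y), batch_size = 32, shuffle = True)

model = TransformerWrapper(
    num_tokens = NUM_TOKENS,
    max_seq_len = MAX_SEQ_LEN,
    logits_dim = NUM_CLASSES,
    use_cls_token = True,        # or average_pool_embed = True
    attn_layers = Encoder(
        dim = 128,
        depth = 4,
        heads = 8
    )
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):
    for tokens, labels in loader:
        logits = model(tokens)            # (batch, NUM_CLASSES)
        loss = criterion(logits, labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

    print(f'epoch {epoch}: loss {loss.item():.4f}')

Either use_cls_token = True or average_pool_embed = True works here; the rest of the loop stays the same.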
Added cls token/pooling option for NLP-based full-text classification