lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Question: problem with xval implementation #248

Closed HarshaSatyavardhan closed 3 months ago

HarshaSatyavardhan commented 3 months ago

I have tried xval with a simple mock example, trying to overfit and then see what the model generates, but it produces weird results.

import torch
import json
from x_transformers import Decoder, XValTransformerWrapper, XValAutoregressiveWrapper

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)
model = XValAutoregressiveWrapper(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 0, 3, 3, 3, 2, 2, 0, 1, 0, 1, 2, 1]])
    nums = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.4426, 1]])
    mask = (nums != 1)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(ids, nums, mask=mask)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

# Train the model for more epochs
train(model, optimizer, epochs=20)

Epoch 20/20, Loss: 0.0900447741150856

# then generate
start_ids = torch.randint(0, 4, (1, 1))
start_nums = torch.randn(1, 1)

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 50)

# ids_out, num_out, is_number_mask shapes: (1, 50) each
ids_out, num_out, is_number_mask

results

(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]]),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan]]),
 tensor([[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False]]))

Can you point out if I am doing anything wrong here?

lucidrains commented 3 months ago

@HarshaSatyavardhan hey Harsha, thanks for your interest. i somehow had the key padding mask included in the readme when it should not be there. the numerical mask is auto-handled based on the numerical_token_id

could you try running the script below? it should work

the nans you see are not an error, just to explicitly remind the researcher which values are not a number

import torch
from x_transformers import (
    Decoder,
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)
from einops import repeat

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)

model = XValAutoregressiveWrapper(model).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 3, 1]]).cuda()
    nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda()

    batched_ids = repeat(ids, '1 n -> b n', b = 32)
    batched_nums = repeat(nums, '1 n -> b n', b = 32)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(batched_ids, batched_nums)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

train(model, optimizer, epochs=50)

start_ids = torch.ones((1, 1), dtype = torch.long).cuda()  # token ids must be integers for the embedding lookup
start_nums = torch.zeros(1, 1).cuda()

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 5)

print(ids_out, num_out, is_number_mask)
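
The three outputs line up position by position, so is_number_mask can be used to splice the predicted numbers back over the token stream. A minimal post-processing sketch (the merge_output helper is illustrative, not part of the library):

def merge_output(ids_out, num_out, is_number_mask):
    # walk the first (and only) sequence in the batch, replacing
    # numerical token positions with their predicted float value
    merged = []
    for token_id, value, is_num in zip(ids_out[0].tolist(), num_out[0].tolist(), is_number_mask[0].tolist()):
        merged.append(value if is_num else token_id)
    return merged

print(merge_output(ids_out, num_out, is_number_mask))
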
lucidrains commented 3 months ago

@HarshaSatyavardhan there was an issue with the numerical loss 🤦‍♂️ should converge better now, even if it sort of worked before

HarshaSatyavardhan commented 2 months ago

@lucidrains which is correct for nums: using 1 or 0 where there is no number? nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda() or nums = torch.tensor([[1., 1., 3.14, 1., 2.72, 1.]]).cuda()?

I think using 1 leads to better results. Is this correct or wrong? According to the paper we are multiplying these values with the embeddings, so don't you think multiplying by zero leads to problems?

lucidrains commented 2 months ago

@HarshaSatyavardhan yes, you are correct! wow you understand the paper well

i protect against that here, so you can actually put any value there (even nans, to be explicit on what is not a number)
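
Concretely, the protection boils down to only applying the scalar multiplier at positions that carry the numerical token, so whatever sits in nums elsewhere (zeros, ones, or nans) is ignored. A rough sketch of the idea, not the library's actual code:

import torch

def scale_token_embeddings(token_emb, ids, nums, numerical_token_id):
    # token_emb: (batch, seq, dim), ids / nums: (batch, seq)
    # only numerical token positions get scaled by their value;
    # all other positions keep a multiplier of 1, so junk or nan values are ignored
    is_number = ids == numerical_token_id
    scale = torch.where(is_number, nums, torch.ones_like(nums))
    return token_emb * scale.unsqueeze(-1)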

lucidrains commented 2 months ago

@HarshaSatyavardhan do let me know if/once you train anything significant with xval

if it works well for you, i may build out a special tokenizer to do this type of numerical encoding

HarshaSatyavardhan commented 5 days ago

@lucidrains can you build the special numerical-encoding tokenizer that xval uses, so that the encoding can be paired with TransformerWrapper rather than XValTransformerWrapper? For some cases there is no need to have a numerical head that predicts the number.
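
For illustration, such a tokenizer would map every literal number in the text to the single numerical token and carry the values in a parallel tensor, matching the (ids, nums) pair the wrappers above consume. A rough sketch with a hypothetical word-level vocab (the names here are not a proposed API):

import torch

NUM_TOKEN_ID = 3  # mirrors numerical_token_id in the toy setup above

def encode_with_numbers(words, vocab, num_token_id = NUM_TOKEN_ID):
    # numbers become the numerical token; their value is kept in a parallel list
    ids, nums = [], []
    for w in words:
        try:
            value = float(w)
            ids.append(num_token_id)
            nums.append(value)
        except ValueError:
            ids.append(vocab[w])
            nums.append(1.)  # placeholder at non-numeric positions
    return torch.tensor([ids]), torch.tensor([nums])

# usage: encode_with_numbers('the value is 3.14'.split(), {'the': 0, 'value': 1, 'is': 2})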