lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

GPT training problem #162

Open phuvinhnguyen opened 1 year ago

phuvinhnguyen commented 1 year ago

I am currently having a problem while training a GPT-1-like model. Even though the model is small (~1.5M parameters) and I train it on a very small dataset (only 5 samples), it is unable to overfit the data.

This is the model I created:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = NUM_VOCAB,
    max_seq_len = MAX_SEQUENCE_LENGTH,
    attn_layers = Decoder(
        dim = EMBEDING_SIZE,
        depth = NUM_LAYER,
        heads = NUM_HEAD,
        alibi_pos_bias = True,   # ALiBi relative position bias
        use_rmsnorm = True,      # RMSNorm instead of LayerNorm
        batch_first = True
    )
)
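
For reference, here is a minimal sketch of how the same model could be wrapped with the library's AutoregressiveWrapper, which handles the next-token target shifting and the loss internally. I have not verified this against my input/target format; ignore_index = 0 is an assumption that token id 0 is padding.

from x_transformers import AutoregressiveWrapper

# Sketch only: the wrapper shifts inputs/targets for next-token prediction
# and computes the cross-entropy loss itself.
lm = AutoregressiveWrapper(model, ignore_index = 0)  # assumes 0 is the padding id

loss = lm(inputs.to(DEVICE))  # inputs: (batch, seq_len) token ids; returns the LM loss
loss.backward()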

I tried to train this model on my own data with this function:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def train_model(model, train_loader: DataLoader):
    global min_loss
    model.to(DEVICE)
    criteria = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), LEARNING_RATE)

    for epoch in range(EPOCH):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)  # (batch, seq_len, num_tokens) logits

            # Position of the third zero in each target row (padded with
            # MAX_SEQUENCE_LENGTH so a third element always exists),
            # then the maximum of those positions over the batch.
            important_index = torch.tensor([([i for i, val in enumerate(seq) if val == 0] + [MAX_SEQUENCE_LENGTH] * 3)[2] for seq in targets]).max().item()

            # Cross-entropy over the first `important_index` positions only;
            # permute to (batch, num_tokens, seq_len) as CrossEntropyLoss expects.
            loss = criteria(outputs[:, :important_index, :].permute(0, 2, 1), targets[:, :important_index])

            loss.backward()
            running_loss += loss.item()
            optimizer.step()

        epoch_loss = running_loss / len(train_loader)
        if min_loss > epoch_loss:
            torch.save(model.state_dict(), MODEL_PATH)
            min_loss = epoch_loss
            with open(LOG_FILE, 'w') as wf:
                wf.write(str(min_loss))

        print(f"Epoch {epoch+1}/{EPOCH} - Loss: {epoch_loss}")

Here is an example of the input tensors:

tensor([[    0,    88,   159,    52,   861, 20093,  4063, 14365, 20868, 21621,
          2731, 21621, 10405, 12979,   130,     0],
        [    0,   171,   155,    52,   104, 20977,    76,   104,   209,    52,
           861,    10, 15781,   130, 11618,     0]]) 

and the corresponding ground-truth tensors:

tensor([[  0, 876,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0, 875,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0]])

After a few training runs, I think the model tends to look only at the start of each sample and ignore the rest. For example, in this tensor

[    0,    88,   159,    52,   861, 20093,  4063, 14365, 20868, 21621, 2731, 21621, 10405, 12979, 130,     0]

it seems like the model only uses the first two tokens (0 and 88) for its prediction, which makes it unable to distinguish between two sentences that start the same way. Is there something I'm missing, and how can I get around this problem? Thanks in advance.
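
One way to check this (a sketch only, not something I have fully verified) is to compare the per-position argmax predictions against the targets for the two samples above:

model.eval()
with torch.no_grad():
    logits = model(inputs.to(DEVICE))   # (batch, seq_len, num_tokens)
    preds = logits.argmax(dim = -1)     # predicted token id at every position
print(preds)                            # compare with the ground-truth tensors above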