lucidrains / tab-transformer-pytorch

Implementation of TabTransformer, attention network for tabular data, in Pytorch

Is there any training example about tabtransformer? #6

Open pancodex opened 3 years ago

pancodex commented 3 years ago

Hi, I want to use this on a tabular dataset for a supervised learning task, but I don't really know how to train the model on my dataset (there doesn't seem to be anything about this in the README). Could you please help me? Thank you.

Alexx776 commented 2 months ago

Hello, here is a simple example of training the model. I hope it helps.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tab_transformer_pytorch import TabTransformer
from sklearn.metrics import accuracy_score, classification_report

# Generate fake data: 5 categorical features (10 categories each),
# 10 continuous features, and a 3-class label built from simple rules
def generate_data(num_samples):
    x_categ = torch.randint(0, 10, (num_samples, 5))
    x_cont = torch.randn(num_samples, 10)

    y = torch.zeros(num_samples)
    for i in range(num_samples):
        if x_categ[i, 0] > 5 and x_cont[i, 0] > 0:
            y[i] = 2  # class 2
        elif x_categ[i, 1] < 3 or x_cont[i, 1] < -1:
            y[i] = 1  # class 1
        else:
            y[i] = 0  # class 0

    return x_categ, x_cont, y.long()

num_samples = 10000
x_categ, x_cont, y = generate_data(num_samples)

# Per-feature mean/std of the continuous columns, stacked into the
# (num_continuous, 2) shape TabTransformer expects. The model normalizes
# internally when given continuous_mean_std, so don't also normalize
# x_cont by hand here (that would normalize twice)
cont_mean = x_cont.mean(dim=0)
cont_std = x_cont.std(dim=0)
cont_mean_std = torch.stack([cont_mean, cont_std], dim=1)

# Model
model = TabTransformer(
    categories = (10, 10, 10, 10, 10),   # number of unique values per categorical feature
    num_continuous = 10,                 # number of continuous features
    dim = 64,                            # embedding dimension
    dim_out = 3,                         # three output classes
    depth = 6,                           # number of transformer layers
    heads = 8,                           # number of attention heads
    attn_dropout = 0.1,                  # post-attention dropout
    ff_dropout = 0.1,                    # feed-forward dropout
    mlp_hidden_mults = (4, 2),           # hidden dims of the final MLP, relative to its input
    mlp_act = nn.ReLU(),                 # activation for the final MLP
    continuous_mean_std = cont_mean_std  # (optional) normalizes continuous values inside the model
)
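
# Optional sanity check: a single forward pass on a small slice should
# return raw logits of shape (batch, dim_out)
with torch.no_grad():
    logits = model(x_categ[:4], x_cont[:4])
print(logits.shape)  # expected: torch.Size([4, 3])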

# 80/20 train/test split
dataset = TensorDataset(x_categ, x_cont, y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20

# Train
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    for batch_categ, batch_cont, batch_y in train_loader:
        optimizer.zero_grad()

        outputs = model(batch_categ, batch_cont)  # raw logits, shape (batch, 3)
        loss = criterion(outputs, batch_y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        predicted = outputs.argmax(dim=1)  # predicted class per sample
        all_preds.extend(predicted.numpy())
        all_labels.extend(batch_y.numpy())

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")