KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Is KAN (pykan) sufficient for classification tasks? #477

Open SaranDS opened 2 weeks ago

SaranDS commented 2 weeks ago

I implemented the following code snippet for binary classification on tabular data, using stratified K-fold cross-validation (K=10). The performance results seem exceptionally good. Can someone review it and suggest improvements to the implementation?

```python
import torch
from kan import KAN
from sklearn.model_selection import train_test_split

# Defined elsewhere (assumed): kf = StratifiedKFold(n_splits=10, ...), X_scaled, y, get_clf_eval
model = KAN(width=[38, 5, 3, 2], grid=5, k=3)  # created once, so weights carry over from fold to fold

for train_idx, test_idx in kf.split(X_scaled, y):
    X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # Split the training fold into train and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=5)

    train_input = torch.tensor(X_train, dtype=torch.float32)
    train_label = torch.tensor(y_train, dtype=torch.long)
    val_input = torch.tensor(X_val, dtype=torch.float32)
    val_label = torch.tensor(y_val, dtype=torch.long)
    test_input = torch.tensor(X_test, dtype=torch.float32)
    test_label = torch.tensor(y_test, dtype=torch.long)

    # fit() reports its test_loss on 'test_input'/'test_label', so the validation split goes under those keys
    results = model.fit({'train_input': train_input, 'train_label': train_label,
                         'test_input': val_input, 'test_label': val_label},
                        opt="LBFGS", steps=10,
                        loss_fn=torch.nn.CrossEntropyLoss(), update_grid=False)

    # Predictions
    test_preds = torch.argmax(model.forward(test_input).detach(), dim=1)

    # Evaluate metrics on test set
    PD, PF, auc, balance, fir, accuracy, precision = get_clf_eval(test_label, test_preds)
```


Dataset description: 39 features, 16,900 data points (32,900 after SMOTE).
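One detail worth checking in the loop above: `model = KAN(...)` is created once, before `for train_idx, test_idx in kf.split(...)`, and pykan's `fit()` continues optimizing the existing parameters rather than re-initializing them. From the second fold on, the model therefore starts from weights trained on data that includes the current test fold. The toy sketch below makes that effect visible. It is an illustration under stated assumptions, not pykan code: `AccumulatingNN` is a hypothetical 1-nearest-neighbour stand-in whose `fit()` keeps everything it saw in earlier calls (an exaggerated version of carried-over weights), folds are plain shuffled K-fold rather than stratified, and the labels are random, so an honest estimate should sit near 50%.

```python
import numpy as np

class AccumulatingNN:
    """Hypothetical stand-in for a model whose fit() continues from its previous
    state: every call adds the new training rows to what it already memorized."""

    def __init__(self):
        self.X = np.empty((0, 0))
        self.y = np.empty(0, dtype=int)

    def fit(self, X, y):
        self.X = X.copy() if self.X.size == 0 else np.vstack([self.X, X])
        self.y = np.concatenate([self.y, y])
        return self

    def predict(self, X):
        # 1-nearest neighbour over every row seen so far
        d = ((X[:, None, :] - self.X[None, :, :]) ** 2).sum(axis=2)
        return self.y[d.argmin(axis=1)]

def cv_accuracy(X, y, make_model, k=10, reuse=False, seed=0):
    """Mean test accuracy over k shuffled folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), k)
    shared = make_model() if reuse else None
    scores = []
    for i, test_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        model = shared if reuse else make_model()  # reuse=False: fresh model per fold
        model.fit(X[train_idx], y[train_idx])
        scores.append(np.mean(model.predict(X[test_idx]) == y[test_idx]))
    return float(np.mean(scores))

rng = np.random.default_rng(42)
X = rng.normal(size=(400, 5))
y = rng.integers(0, 2, size=400)  # random labels: there is nothing real to learn

print("model reused across folds:", cv_accuracy(X, y, AccumulatingNN, reuse=True))
print("fresh model per fold:     ", cv_accuracy(X, y, AccumulatingNN, reuse=False))
```

With the shared instance, every fold after the first is scored on rows the model has already been trained on. In the real loop, the corresponding change would be to move `model = KAN(width=[38, 5, 3, 2], grid=5, k=3)` inside the `for` loop so each fold starts from a freshly initialized model.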

YuriyKabanenko commented 3 days ago

@SaranDS Could you please provide more details on the data you trained the model on? I have been trying hard to increase accuracy, but the best I got is 78%. First I tried BCEWithLogitsLoss for my binary classification task. Then I thought the accuracy boost might come from the CrossEntropy loss function, so I refactored my model to output 2 values and take the argmax as the prediction. But accuracy dropped to 74%.
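For two classes, the two setups describe the same family of predictors: `CrossEntropyLoss` on two logits `(z0, z1)` depends only on the difference `z1 - z0`, and equals `BCEWithLogitsLoss` applied to that difference. Switching the loss alone is therefore not expected to raise accuracy; note that the two runs described here also differ in `width` and `grid`. A minimal pure-Python check of the identity (the helper names are illustrative, not PyTorch functions):

```python
import math

def cross_entropy_2(z0, z1, y):
    """-log softmax([z0, z1])[y]: CrossEntropyLoss for one sample with two logits."""
    m = max(z0, z1)  # log-sum-exp trick for numerical stability
    log_norm = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
    return log_norm - (z1 if y == 1 else z0)

def bce_with_logits(s, y):
    """-[y*log(sigmoid(s)) + (1-y)*log(1-sigmoid(s))]: BCEWithLogitsLoss for one sample."""
    return max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))

for z0, z1 in [(0.0, 0.0), (1.5, -0.3), (-2.0, 4.0), (10.0, 9.0)]:
    for y in (0, 1):
        ce = cross_entropy_2(z0, z1, y)
        bce = bce_with_logits(z1 - z0, y)
        print(f"z=({z0}, {z1}) y={y}  CE={ce:.6f}  BCE(z1-z0)={bce:.6f}")
```

The printed columns agree for every case, which is why a 1-output sigmoid head and a 2-output softmax head can reach the same accuracy when everything else is held fixed.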

Here is the code with one output:

```python
import time

import torch
from kan import KAN

# custom_dataset and x_test are defined elsewhere
dtype = torch.get_default_dtype()  # assumed: dtype was not defined in the original snippet

model = KAN(width=[16, 3, 1], grid=3, k=3)

def train_acc():
    return torch.mean((torch.round(torch.sigmoid(model(custom_dataset['train_input']))[:, 0])
                       == custom_dataset['train_label'][:, 0]).type(dtype))

def test_acc():
    return torch.mean((torch.round(torch.sigmoid(model(custom_dataset['test_input']))[:, 0])
                       == custom_dataset['test_label'][:, 0]).type(dtype))

start_time = time.time()

print(custom_dataset['train_input'].dtype)
print(custom_dataset['test_input'].dtype)

results = model.fit(custom_dataset, opt="LBFGS", steps=100, batch=x_test.shape[0],
                    metrics=(train_acc, test_acc), loss_fn=torch.nn.BCEWithLogitsLoss())
end_time = time.time()

results['train_acc'][-1], results['test_acc'][-1]
```
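Since `sigmoid(z) > 0.5` exactly when `z > 0`, the `train_acc`/`test_acc` metrics above could equally threshold the raw logits; `torch.round` rounds halves to even, so a logit of exactly 0 counts as class 0 either way. A NumPy sketch of both forms (function names are illustrative; labels are assumed to have shape (N, 1), as the `[:, 0]` indexing above implies):

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    # logits: (N, 1) raw outputs; labels: (N, 1) holding 0.0 / 1.0
    preds = (logits[:, 0] > 0).astype(labels.dtype)  # sigmoid(z) > 0.5  <=>  z > 0
    return float(np.mean(preds == labels[:, 0]))

def accuracy_via_sigmoid(logits, labels):
    # Same computation as train_acc/test_acc, written in NumPy (np.round also rounds half to even)
    probs = 1.0 / (1.0 + np.exp(-logits[:, 0]))
    return float(np.mean(np.round(probs) == labels[:, 0]))

logits = np.array([[2.0], [-1.0], [0.3], [-3.0]])
labels = np.array([[1.0], [0.0], [0.0], [0.0]])
print(accuracy_from_logits(logits, labels), accuracy_via_sigmoid(logits, labels))  # 0.75 0.75
```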

Code with two outputs:

```python
import time

import torch
from kan import KAN

# custom_dataset and x_test are defined elsewhere
model = KAN(width=[16, 5, 3, 2], grid=5, k=3)

start_time = time.time()

results = model.fit(custom_dataset, opt="LBFGS", steps=100, batch=x_test.shape[0],
                    loss_fn=torch.nn.CrossEntropyLoss(), update_grid=False)
end_time = time.time()
```
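One thing to check if the same `custom_dataset` is reused from the one-output run: `torch.nn.CrossEntropyLoss` expects class-index targets of shape (N,) with dtype `torch.long`, whereas the `[:, 0]` indexing in the BCE metrics implies labels of shape (N, 1). Also, without `metrics=`, `results` contains no accuracy entries, so accuracy for the two-output model has to be computed from the argmax. A NumPy sketch (helper names are hypothetical):

```python
import numpy as np

def to_class_indices(labels):
    """(N, 1) float labels holding 0.0/1.0 -> (N,) int64 class indices."""
    return np.asarray(labels).reshape(-1).astype(np.int64)

def accuracy_two_logits(logits, class_idx):
    """logits: (N, 2) raw model outputs; the predicted class is the argmax over the last axis."""
    return float(np.mean(np.argmax(logits, axis=1) == class_idx))

labels = np.array([[1.0], [0.0], [1.0]])
idx = to_class_indices(labels)
logits = np.array([[0.1, 2.0], [3.0, -1.0], [1.0, 0.5]])
print(idx, idx.dtype)                    # [1 0 1] int64
print(accuracy_two_logits(logits, idx))  # 0.6666666666666666
```

`torch.from_numpy(to_class_indices(...))` then yields a `torch.int64` (long) tensor that can be used as `train_label`/`test_label` for the two-output model.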

SaranDS commented 1 day ago

@YuriyKabanenko The dataset comes from the software domain and is used for a binary classification task. It comprises 38+1 features, with a total of 16,962 data points, all numerical. Before training, I applied MinMax scaling for normalization and SMOTE to address the class imbalance. To estimate how well the model generalizes, I used stratified 10-fold cross-validation to split the dataset.
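The description suggests that MinMax scaling and SMOTE were applied to the whole dataset before the stratified split (the code splits `X_scaled`, and the reported 32,900 rows are post-SMOTE). If so, synthetic training rows are interpolated from rows that later land in the test folds, which can inflate test scores. A minimal NumPy sketch of doing both steps inside each training fold instead; `smote_like` is a simplified SMOTE-style oversampler written for illustration, not the imbalanced-learn implementation, and all function names are hypothetical:

```python
import numpy as np

def minmax_fit(X_train):
    """Per-column min and range, learned from the training fold only."""
    lo, hi = X_train.min(axis=0), X_train.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # constant columns: avoid division by zero
    return lo, span

def minmax_apply(X, lo, span):
    # Test rows may fall slightly outside [0, 1]; that is expected
    return (X - lo) / span

def smote_like(X_min, n_new, k=5, seed=0):
    """Simplified SMOTE-style oversampling: each synthetic row lies on the segment
    between a random minority row and one of its k nearest minority neighbours."""
    rng = np.random.default_rng(seed)
    k = min(k, len(X_min) - 1)
    d = ((X_min[:, None, :] - X_min[None, :, :]) ** 2).sum(axis=2)
    np.fill_diagonal(d, np.inf)  # a row is not its own neighbour
    nbrs = np.argsort(d, axis=1)[:, :k]
    base = rng.integers(0, len(X_min), size=n_new)
    nbr = nbrs[base, rng.integers(0, k, size=n_new)]
    gap = rng.random((n_new, 1))
    return X_min[base] + gap * (X_min[nbr] - X_min[base])

def prepare_fold(X_train, y_train, X_test, minority=1, seed=0):
    """Scale and oversample using the training fold only; the test fold is only transformed."""
    lo, span = minmax_fit(X_train)
    X_tr, X_te = minmax_apply(X_train, lo, span), minmax_apply(X_test, lo, span)
    X_min = X_tr[y_train == minority]
    n_new = int((y_train != minority).sum()) - len(X_min)
    if n_new > 0:
        X_tr = np.vstack([X_tr, smote_like(X_min, n_new, seed=seed)])
        y_train = np.concatenate([y_train, np.full(n_new, minority, dtype=y_train.dtype)])
    return X_tr, y_train, X_te

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 4)) * 10
y = (np.arange(100) < 20).astype(int)  # toy data: rows 0-19 are the minority class
X_tr, y_tr, X_te = prepare_fold(X[:80], y[:80], X[80:])
print(X_tr.shape, np.bincount(y_tr))  # (120, 4) [60 60]
```

With scikit-learn and imbalanced-learn, the same idea would be `MinMaxScaler().fit(...)` on the raw training fold followed by `SMOTE().fit_resample(...)` on the scaled training fold, both inside the cross-validation loop, leaving the test fold untouched apart from `transform`.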

The results on this dataset are exceptionally good, and the same holds for 3 other datasets I used from the same domain (each with different features and numbers of data points). This raises concerns about possible test-set leakage during training. To check for it, I added print statements showing the sizes of the training, validation, and test sets, and the output confirmed that the data was split into the expected sizes. However, I do not know of any further ways to validate these results.
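Matching set sizes do not rule out leakage: the sizes come out right even when a row, or a SMOTE-generated near-copy of it, sits on both sides of the split. A small sketch of further per-fold checks (helper name hypothetical): index overlap, which is zero by construction with `kf.split`; test rows that also occur in the training data after rounding; and the distance from each test row to its nearest training row, which is the check that can reveal near-copies created before splitting. Another common sanity check is to retrain with the training labels randomly shuffled: test accuracy should then drop to about the majority-class rate, and if it stays high, something is leaking.

```python
import numpy as np

def leakage_report(X_train, X_test, train_idx=None, test_idx=None, decimals=8, chunk=256):
    """Count obvious overlaps between a training fold and a test fold."""
    report = {}
    if train_idx is not None and test_idx is not None:
        # The same original row index on both sides of the split
        report["shared_indices"] = int(np.intersect1d(train_idx, test_idx).size)
    # Test rows whose rounded feature vector also appears in the training data
    train_rows = {tuple(row) for row in np.round(X_train, decimals)}
    report["test_rows_in_train"] = int(sum(tuple(row) in train_rows
                                           for row in np.round(X_test, decimals)))
    # Nearest-training-row distance for each test row, computed in chunks to bound memory
    train_sq = (X_train ** 2).sum(axis=1)
    nn = np.empty(len(X_test))
    for start in range(0, len(X_test), chunk):
        block = X_test[start:start + chunk]
        d2 = (block ** 2).sum(axis=1)[:, None] + train_sq[None, :] - 2.0 * block @ X_train.T
        nn[start:start + chunk] = np.sqrt(np.clip(d2, 0.0, None).min(axis=1))
    report["min_nn_distance"] = float(nn.min())
    report["median_nn_distance"] = float(np.median(nn))
    return report

X_train = np.array([[0.0, 1.0], [2.0, 3.0], [5.0, 5.0]])
X_test = np.array([[2.0, 3.0], [4.0, 4.0]])
print(leakage_report(X_train, X_test, np.array([0, 1, 2]), np.array([2, 3])))
```

In the original loop this would be called right after `X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]`, before the train/validation split. Comparing the distance figures with and without SMOTE applied before splitting can show whether synthetic rows sit unusually close to test rows.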

SaranDS commented 14 hours ago

@KindXiaoming Is this code a valid way to predict test samples on tabular data with KAN?