KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
14.43k stars 1.32k forks source link

Any step-by-step guide on KAN with code? #376

Open apavlo89 opened 1 month ago

apavlo89 commented 1 month ago

Are there any step-by-step guides on how to apply KAN on tabular binary classification datasets??

apavlo89 commented 1 month ago

I used this guide https://www.kaggle.com/code/seyidcemkarakas/kan-tabular-data-binary-classification for my dataset and got this error:

checkpoint directory created: ./model
saving model version 0.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-16-bff360054870>](https://localhost:8080/#) in <cell line: 41>()
     39 
     40 # KAN model training
---> 41 results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
     42                       metrics=(train_acc, test_acc),
     43                       opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss()) 

TypeError: Module.train() got an unexpected keyword argument 'metrics'

`import torch
from kan import KAN

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Turn every split into torch tensors: features as float32, labels as
# torch.long.  X_* / y_* must expose .to_numpy() (presumably pandas objects
# created earlier in the notebook — TODO confirm).
train_input, val_input, test_input = (
    torch.tensor(features.to_numpy(), dtype=torch.float32)
    for features in (X_train, X_val, X_test)
)
train_label, val_label, test_label = (
    torch.tensor(targets.to_numpy(), dtype=torch.long)
    for targets in (y_train, y_val, y_test)
)

# Bundle all six tensors under the key names the rest of the script reads.
dataset = dict(
    train_input=train_input,
    train_label=train_label,
    val_input=val_input,
    val_label=val_label,
    test_input=test_input,
    test_label=test_label,
)

# Create model.
# NOTE(review): width=[11, 2] presumably means 11 input features mapped to
# 2 outputs (one score per class) — confirm against the actual feature count.
model = KAN(k=2, grid=10, width=[11, 2])

# Functions for getting accuracy scores while training
def train_acc():
    """Return the training accuracy as a scalar float tensor.

    Runs the module-level ``model`` on ``dataset['train_input']``, takes the
    arg-max along dim 1 as the predicted class and averages the element-wise
    matches against ``dataset['train_label']``.
    """
    scores = model(dataset['train_input'])
    is_correct = scores.argmax(dim=1) == dataset['train_label']
    return is_correct.float().mean()

def test_acc():
    """Return the accuracy on ``dataset['test_input']`` as a scalar float tensor.

    Same computation as the training metric: arg-max along dim 1 of the
    module-level ``model`` output, compared element-wise with
    ``dataset['test_label']`` and averaged.
    """
    scores = model(dataset['test_input'])
    is_correct = scores.argmax(dim=1) == dataset['test_label']
    return is_correct.float().mean()

# KAN model training
# Requests 100 LBFGS steps with a cross-entropy loss.  The validation tensors
# are passed under the 'test_input' / 'test_label' keys, and train_acc /
# test_acc are handed over as extra metrics.
# NOTE(review): this is the statement that raises the TypeError quoted above.
# The message names `Module.train()`, i.e. the call resolves to
# torch.nn.Module.train, which has no `metrics` parameter.  Per the replies in
# this thread, pykan >= 0.2.0 expects `model.fit(...)` here instead.
results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
                      metrics=(train_acc, test_acc),
                      opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss()) 

# Hard class predictions (arg-max along dim 1) for every split.
# The forward passes are issued in the original order (test, train, val).
# NOTE(review): model.plot() below may depend on state left behind by the most
# recent forward pass — TODO confirm before reordering these calls.
test_labels = test_label
test_preds = model.forward(test_input).detach().argmax(dim=1)

train_labels = train_label
train_preds = model.forward(train_input).detach().argmax(dim=1)

val_labels = val_label
val_preds = model.forward(val_input).detach().argmax(dim=1)

# Accuracy per split, computed by scikit-learn on the tensors' NumPy form.
print("Train ACC:", accuracy_score(y_true=train_labels.numpy(), y_pred=train_preds.numpy()))

print("Val ACC:", accuracy_score(y_true=val_labels.numpy(), y_pred=val_preds.numpy()))

print("Test ACC:", accuracy_score(y_true=test_labels.numpy(), y_pred=test_preds.numpy()))

# Draw the KAN network.
model.plot(scale=10)

# Learning curves: the four series stored in `results`, on a single axis.
# NOTE(review): test_acc() reads dataset['test_input'], so the series
# labelled 'Val Accuracy' looks like test-split accuracy, while 'test_loss'
# was computed on the validation tensors passed to training — confirm.
plt.figure(figsize=(10, 5))
for _series_key, _series_label in (
    ("train_acc", 'Training Accuracy'),
    ("test_acc", 'Val Accuracy'),
    ("train_loss", 'Training Loss'),
    ("test_loss", 'Val Loss'),
):
    plt.plot(results[_series_key], label=_series_label)
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.title('Training and Val Accuracy over Iterations')
plt.legend()
plt.grid(True)
plt.show()`
seyidcemkarakas commented 1 month ago

Hi @apavlo89

I saw your comments on Kaggle on my notebook. Firstly thanks for looking.

Secondly, I ran my code again but I don't get any errors. Can you share the versions of your libraries?

zozi0406 commented 1 month ago

Hi @apavlo89,

Since 0.2.0, you will need to use the fit method instead of train.

seyidcemkarakas commented 1 month ago

Hi @apavlo89,

Since 0.2.0, you will need to use the fit method instead of train.

Yes, @apavlo89 Could you try it?

apavlo89 commented 1 month ago

fixed with this:

import torch
from kan import KAN

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Report the shape of each feature split before building tensors.
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")

# Features become float32 tensors, labels become torch.long tensors.
# X_* / y_* must expose .to_numpy() (presumably pandas objects — TODO confirm).
train_input, val_input, test_input = (
    torch.tensor(features.to_numpy(), dtype=torch.float32)
    for features in (X_train, X_val, X_test)
)
train_label, val_label, test_label = (
    torch.tensor(targets.to_numpy(), dtype=torch.long)
    for targets in (y_train, y_val, y_test)
)

# One dict holding every split under the key names used further down.
dataset = dict(
    train_input=train_input,
    train_label=train_label,
    val_input=val_input,
    val_label=val_label,
    test_input=test_input,
    test_label=test_label,
)

# Size the first KAN layer from the data (second axis of X_train) instead of
# hard-coding it; the second width entry stays at 2.
input_width = X_train.shape[1]

# Create model
model = KAN(k=2, grid=10, width=[input_width, 2])

# Functions for getting accuracy scores while training
def train_acc():
    """Return the share of training rows classified correctly (scalar tensor).

    The predicted class is the arg-max along dim 1 of the module-level
    ``model`` applied to ``dataset['train_input']``; matches against
    ``dataset['train_label']`` are cast to float and averaged.
    """
    predicted = model(dataset['train_input']).argmax(dim=1)
    matches = (predicted == dataset['train_label']).float()
    return matches.mean()

def test_acc():
    """Return the share of ``dataset['test_input']`` rows classified correctly.

    The predicted class is the arg-max along dim 1 of the module-level
    ``model`` output; matches against ``dataset['test_label']`` are cast to
    float and averaged into a scalar tensor.
    """
    predicted = model(dataset['test_input']).argmax(dim=1)
    matches = (predicted == dataset['test_label']).float()
    return matches.mean()

# KAN model training using the fit method
# Requests 100 LBFGS steps with a cross-entropy loss and passes train_acc /
# test_acc as extra metrics.
# NOTE(review): `dataset` is passed whole, so its 'test_*' entries here are the
# held-out test split; the progress bar below only reports train_loss and
# test_loss, so the 'val_*' entries are presumably not read by fit — confirm.
# NOTE(review): the TypeError in the traceback below
# (`recover_save_act_in_fit() missing ... 'old_save_act'`) is raised inside
# kan/MultKAN.py on the final step, not by the arguments given here.
results = model.fit(dataset,
                    metrics=(train_acc, test_acc),
                    opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())

# Hard class predictions (arg-max along dim 1) for the test, train and val
# splits, in that order.
# NOTE(review): model.plot() below may depend on state left behind by the most
# recent forward pass — TODO confirm before reordering these calls.
test_labels = test_label
test_preds = model.forward(test_input).detach().argmax(dim=1)

train_labels = train_label
train_preds = model.forward(train_input).detach().argmax(dim=1)

val_labels = val_label
val_preds = model.forward(val_input).detach().argmax(dim=1)

# Accuracy per split via scikit-learn, on the tensors' NumPy form.
print("Train ACC:", accuracy_score(y_true=train_labels.numpy(), y_pred=train_preds.numpy()))
print("Val ACC:", accuracy_score(y_true=val_labels.numpy(), y_pred=val_preds.numpy()))
print("Test ACC:", accuracy_score(y_true=test_labels.numpy(), y_pred=test_preds.numpy()))

# Draw the KAN network.
model.plot(scale=10)

# Learning curve based on ACC and LOSS
# fit() was called with `dataset`, whose 'test_*' entries hold the held-out
# test split, and test_acc() reads dataset['test_input'] as well, so the
# non-training curves are test curves and are labelled as such.
# The history is read under 'test_acc' / 'test_loss': the progress bar only
# reports train_loss / test_loss and the metrics are named train_acc /
# test_acc, so the former 'val_acc' / 'val_loss' lookups had no matching
# series.
# NOTE(review): assumes fit() returns a history keyed by 'train_loss',
# 'test_loss' and each metric function's name, as the first snippet in this
# thread does — confirm against the installed pykan version.
plt.figure(figsize=(10, 5))
plt.plot(results["train_acc"], label='Training Accuracy')
plt.plot(results["test_acc"], label='Test Accuracy')
plt.plot(results["train_loss"], label='Training Loss')
plt.plot(results["test_loss"], label='Test Loss')
plt.title('Training and Test Accuracy over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.legend()
plt.grid(True)
plt.show()

Any idea how to tune model? I am getting absolutely abysmal performance and also get an error when all steps finish. Is it just a case of increasing steps? Anything else?:


 Shape of X_train: (1393, 369)
Shape of X_val: (725, 369)
Shape of X_test: (249, 369)
checkpoint directory created: ./model
saving model version 0.0

description:   0%|                                                          | 0/100 [00:00<?, ?it/s]
| train_loss: 1.92e+00 | test_loss: 1.86e+00 | reg: 0.00e+00 | :   0%|      | 0/100 [00:02<?, ?it/s]
| train_loss: 1.92e+00 | test_loss: 1.86e+00 | reg: 0.00e+00 | :   1%| | 1/100 [00:02<04:33,  2.76s/
| train_loss: 1.55e+00 | test_loss: 1.46e+00 | reg: 0.00e+00 | :   1%| | 1/100 [00:04<04:33,  2.76s/
| train_loss: 1.55e+00 | test_loss: 1.46e+00 | reg: 0.00e+00 | :   2%| | 2/100 [00:04<03:24,  2.09s/
| train_loss: 1.37e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | :   2%| | 2/100 [00:05<03:24,  2.09s/
| train_loss: 1.37e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | :   3%| | 3/100 [00:05<02:58,  1.84s/
| train_loss: 1.23e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | :   3%| | 3/100 [00:07<02:58,  1.84s/
| train_loss: 1.23e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | :   4%| | 4/100 [00:07<02:51,  1.79s/
| train_loss: 1.16e+00 | test_loss: 1.18e+00 | reg: 0.00e+00 | :   4%| | 4/100 [00:09<02:51,  1.79s/
| train_loss: 1.16e+00 | test_loss: 1.18e+00 | reg: 0.00e+00 | :   5%| | 5/100 [00:09<02:45,  1.74s/
| train_loss: 1.11e+00 | test_loss: 1.15e+00 | reg: 0.00e+00 | :   5%| | 5/100 [00:11<02:45,  1.74s/
| train_loss: 1.11e+00 | test_loss: 1.15e+00 | reg: 0.00e+00 | :   6%| | 6/100 [00:11<02:45,  1.76s/
| train_loss: 1.03e+00 | test_loss: 1.11e+00 | reg: 0.00e+00 | :   6%| | 6/100 [00:12<02:45,  1.76s/
| train_loss: 1.03e+00 | test_loss: 1.11e+00 | reg: 0.00e+00 | :   7%| | 7/100 [00:12<02:37,  1.69s/
| train_loss: 9.89e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :   7%| | 7/100 [00:14<02:37,  1.69s/
| train_loss: 9.89e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :   8%| | 8/100 [00:14<02:42,  1.77s/
| train_loss: 9.64e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :   8%| | 8/100 [00:16<02:42,  1.77s/
| train_loss: 9.64e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :   9%| | 9/100 [00:16<02:44,  1.80s/
| train_loss: 9.17e-01 | test_loss: 1.06e+00 | reg: 0.00e+00 | :   9%| | 9/100 [00:17<02:44,  1.80s/
| train_loss: 9.17e-01 | test_loss: 1.06e+00 | reg: 0.00e+00 | :  10%| | 10/100 [00:17<02:34,  1.71s
| train_loss: 1.04e+00 | test_loss: 1.17e+00 | reg: 0.00e+00 | :  10%| | 10/100 [00:19<02:34,  1.71s
| train_loss: 1.04e+00 | test_loss: 1.17e+00 | reg: 0.00e+00 | :  11%| | 11/100 [00:19<02:31,  1.70s
| train_loss: 9.94e-01 | test_loss: 1.14e+00 | reg: 0.00e+00 | :  11%| | 11/100 [00:21<02:31,  1.70s
| train_loss: 9.94e-01 | test_loss: 1.14e+00 | reg: 0.00e+00 | :  12%| | 12/100 [00:21<02:26,  1.67s
| train_loss: 9.54e-01 | test_loss: 1.10e+00 | reg: 0.00e+00 | :  12%| | 12/100 [00:22<02:26,  1.67s
| train_loss: 9.54e-01 | test_loss: 1.10e+00 | reg: 0.00e+00 | :  13%|▏| 13/100 [00:22<02:24,  1.66s
| train_loss: 9.22e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :  13%|▏| 13/100 [00:24<02:24,  1.66s
| train_loss: 9.22e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :  14%|▏| 14/100 [00:24<02:21,  1.65s
| train_loss: 8.94e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :  14%|▏| 14/100 [00:26<02:21,  1.65s
| train_loss: 8.94e-01 | test_loss: 1.07e+00 | reg: 0.00e+00 | :  15%|▏| 15/100 [00:26<02:26,  1.73s
| train_loss: 1.15e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | :  15%|▏| 15/100 [00:28<02:26,  1.73s
| train_loss: 1.15e+00 | test_loss: 1.28e+00 | reg: 0.00e+00 | :  16%|▏| 16/100 [00:28<02:43,  1.95s
| train_loss: 1.06e+00 | test_loss: 1.25e+00 | reg: 0.00e+00 | :  16%|▏| 16/100 [00:30<02:43,  1.95s
| train_loss: 1.06e+00 | test_loss: 1.25e+00 | reg: 0.00e+00 | :  17%|▏| 17/100 [00:30<02:34,  1.86s
| train_loss: 9.92e-01 | test_loss: 1.24e+00 | reg: 0.00e+00 | :  17%|▏| 17/100 [00:32<02:34,  1.86s
| train_loss: 9.92e-01 | test_loss: 1.24e+00 | reg: 0.00e+00 | :  18%|▏| 18/100 [00:32<02:27,  1.79s
| train_loss: 9.44e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  18%|▏| 18/100 [00:33<02:27,  1.79s
| train_loss: 9.44e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  19%|▏| 19/100 [00:33<02:21,  1.75s
| train_loss: 9.32e-01 | test_loss: 1.20e+00 | reg: 0.00e+00 | :  19%|▏| 19/100 [00:35<02:21,  1.75s
| train_loss: 9.32e-01 | test_loss: 1.20e+00 | reg: 0.00e+00 | :  20%|▏| 20/100 [00:35<02:16,  1.70s
| train_loss: 9.57e-01 | test_loss: 1.28e+00 | reg: 0.00e+00 | :  20%|▏| 20/100 [00:37<02:16,  1.70s
| train_loss: 9.57e-01 | test_loss: 1.28e+00 | reg: 0.00e+00 | :  21%|▏| 21/100 [00:37<02:18,  1.76s
| train_loss: 9.35e-01 | test_loss: 1.25e+00 | reg: 0.00e+00 | :  21%|▏| 21/100 [00:38<02:18,  1.76s
| train_loss: 9.35e-01 | test_loss: 1.25e+00 | reg: 0.00e+00 | :  22%|▏| 22/100 [00:38<02:16,  1.75s
| train_loss: 9.16e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  22%|▏| 22/100 [00:41<02:16,  1.75s
| train_loss: 9.16e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  23%|▏| 23/100 [00:41<02:23,  1.86s
| train_loss: 9.05e-01 | test_loss: 1.21e+00 | reg: 0.00e+00 | :  23%|▏| 23/100 [00:42<02:23,  1.86s
| train_loss: 9.05e-01 | test_loss: 1.21e+00 | reg: 0.00e+00 | :  24%|▏| 24/100 [00:42<02:15,  1.78s
| train_loss: 8.93e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  24%|▏| 24/100 [00:44<02:15,  1.78s
| train_loss: 8.93e-01 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  25%|▎| 25/100 [00:44<02:05,  1.68s
| train_loss: 1.60e+00 | test_loss: 1.54e+00 | reg: 0.00e+00 | :  25%|▎| 25/100 [00:45<02:05,  1.68s
| train_loss: 1.60e+00 | test_loss: 1.54e+00 | reg: 0.00e+00 | :  26%|▎| 26/100 [00:45<02:07,  1.73s
| train_loss: 1.36e+00 | test_loss: 1.26e+00 | reg: 0.00e+00 | :  26%|▎| 26/100 [00:47<02:07,  1.73s
| train_loss: 1.36e+00 | test_loss: 1.26e+00 | reg: 0.00e+00 | :  27%|▎| 27/100 [00:47<01:59,  1.64s
| train_loss: 1.35e+00 | test_loss: 1.24e+00 | reg: 0.00e+00 | :  27%|▎| 27/100 [00:48<01:59,  1.64s
| train_loss: 1.35e+00 | test_loss: 1.24e+00 | reg: 0.00e+00 | :  28%|▎| 28/100 [00:48<01:52,  1.56s
| train_loss: 1.32e+00 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  28%|▎| 28/100 [00:50<01:52,  1.56s
| train_loss: 1.32e+00 | test_loss: 1.22e+00 | reg: 0.00e+00 | :  29%|▎| 29/100 [00:50<01:52,  1.58s
| train_loss: 1.30e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | :  29%|▎| 29/100 [00:52<01:52,  1.58s
| train_loss: 1.30e+00 | test_loss: 1.21e+00 | reg: 0.00e+00 | :  30%|▎| 30/100 [00:52<01:55,  1.65s
| train_loss: 2.35e+00 | test_loss: 2.64e+00 | reg: 0.00e+00 | :  30%|▎| 30/100 [00:54<01:55,  1.65s
| train_loss: 2.35e+00 | test_loss: 2.64e+00 | reg: 0.00e+00 | :  31%|▎| 31/100 [00:54<02:07,  1.84s
| train_loss: 2.19e+00 | test_loss: 2.15e+00 | reg: 0.00e+00 | :  31%|▎| 31/100 [00:56<02:07,  1.84s
| train_loss: 2.19e+00 | test_loss: 2.15e+00 | reg: 0.00e+00 | :  32%|▎| 32/100 [00:56<02:00,  1.77s
| train_loss: 2.01e+00 | test_loss: 2.02e+00 | reg: 0.00e+00 | :  32%|▎| 32/100 [00:57<02:00,  1.77s
| train_loss: 2.01e+00 | test_loss: 2.02e+00 | reg: 0.00e+00 | :  33%|▎| 33/100 [00:57<01:54,  1.70s
| train_loss: 1.78e+00 | test_loss: 1.71e+00 | reg: 0.00e+00 | :  33%|▎| 33/100 [00:59<01:54,  1.70s
| train_loss: 1.78e+00 | test_loss: 1.71e+00 | reg: 0.00e+00 | :  34%|▎| 34/100 [00:59<01:51,  1.68s
| train_loss: 1.66e+00 | test_loss: 1.57e+00 | reg: 0.00e+00 | :  34%|▎| 34/100 [01:00<01:51,  1.68s
| train_loss: 1.66e+00 | test_loss: 1.57e+00 | reg: 0.00e+00 | :  35%|▎| 35/100 [01:00<01:48,  1.66s
| train_loss: 7.41e+00 | test_loss: 7.15e+00 | reg: 0.00e+00 | :  35%|▎| 35/100 [01:02<01:48,  1.66s
| train_loss: 7.41e+00 | test_loss: 7.15e+00 | reg: 0.00e+00 | :  36%|▎| 36/100 [01:02<01:51,  1.74s
| train_loss: 3.13e+00 | test_loss: 3.25e+00 | reg: 0.00e+00 | :  36%|▎| 36/100 [01:05<01:51,  1.74s
| train_loss: 3.13e+00 | test_loss: 3.25e+00 | reg: 0.00e+00 | :  37%|▎| 37/100 [01:05<01:57,  1.87s
| train_loss: 2.09e+00 | test_loss: 2.12e+00 | reg: 0.00e+00 | :  37%|▎| 37/100 [01:06<01:57,  1.87s
| train_loss: 2.09e+00 | test_loss: 2.12e+00 | reg: 0.00e+00 | :  38%|▍| 38/100 [01:06<01:57,  1.90s
| train_loss: 1.80e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | :  38%|▍| 38/100 [01:08<01:57,  1.90s
| train_loss: 1.80e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | :  39%|▍| 39/100 [01:08<01:48,  1.79s
| train_loss: 1.70e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | :  39%|▍| 39/100 [01:10<01:48,  1.79s
| train_loss: 1.70e+00 | test_loss: 1.81e+00 | reg: 0.00e+00 | :  40%|▍| 40/100 [01:10<01:44,  1.74s
| train_loss: 9.58e+00 | test_loss: 1.87e+00 | reg: 0.00e+00 | :  40%|▍| 40/100 [01:12<01:44,  1.74s
| train_loss: 9.58e+00 | test_loss: 1.87e+00 | reg: 0.00e+00 | :  41%|▍| 41/100 [01:12<01:46,  1.81s
| train_loss: 9.53e+00 | test_loss: 2.42e+00 | reg: 0.00e+00 | :  41%|▍| 41/100 [01:13<01:46,  1.81s
| train_loss: 9.53e+00 | test_loss: 2.42e+00 | reg: 0.00e+00 | :  42%|▍| 42/100 [01:13<01:42,  1.77s
| train_loss: 9.51e+00 | test_loss: 3.70e+00 | reg: 0.00e+00 | :  42%|▍| 42/100 [01:15<01:42,  1.77s
| train_loss: 9.51e+00 | test_loss: 3.70e+00 | reg: 0.00e+00 | :  43%|▍| 43/100 [01:15<01:35,  1.68s
| train_loss: 8.86e+00 | test_loss: 1.03e+01 | reg: 0.00e+00 | :  43%|▍| 43/100 [01:17<01:35,  1.68s
| train_loss: 8.86e+00 | test_loss: 1.03e+01 | reg: 0.00e+00 | :  44%|▍| 44/100 [01:17<01:37,  1.75s
| train_loss: 7.43e+00 | test_loss: 7.65e+00 | reg: 0.00e+00 | :  44%|▍| 44/100 [01:19<01:37,  1.75s
| train_loss: 7.43e+00 | test_loss: 7.65e+00 | reg: 0.00e+00 | :  45%|▍| 45/100 [01:19<01:42,  1.86s
| train_loss: 7.11e+00 | test_loss: 6.82e+00 | reg: 0.00e+00 | :  45%|▍| 45/100 [01:21<01:42,  1.86s
| train_loss: 7.11e+00 | test_loss: 6.82e+00 | reg: 0.00e+00 | :  46%|▍| 46/100 [01:21<01:41,  1.88s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  46%|▍| 46/100 [01:22<01:41,  1.88s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  47%|▍| 47/100 [01:22<01:34,  1.78s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  47%|▍| 47/100 [01:23<01:34,  1.78s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  48%|▍| 48/100 [01:23<01:19,  1.52s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  48%|▍| 48/100 [01:24<01:19,  1.52s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  49%|▍| 49/100 [01:24<01:07,  1.32s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  49%|▍| 49/100 [01:25<01:07,  1.32s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  50%|▌| 50/100 [01:25<00:57,  1.16s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  50%|▌| 50/100 [01:25<00:57,  1.16s
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  51%|▌| 51/100 [01:25<00:47,  1.03i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  51%|▌| 51/100 [01:26<00:47,  1.03i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  52%|▌| 52/100 [01:26<00:41,  1.17i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  52%|▌| 52/100 [01:27<00:41,  1.17i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  53%|▌| 53/100 [01:27<00:37,  1.26i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  53%|▌| 53/100 [01:27<00:37,  1.26i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  54%|▌| 54/100 [01:27<00:33,  1.36i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  54%|▌| 54/100 [01:28<00:33,  1.36i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  55%|▌| 55/100 [01:28<00:31,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  55%|▌| 55/100 [01:28<00:31,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  56%|▌| 56/100 [01:28<00:30,  1.43i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  56%|▌| 56/100 [01:29<00:30,  1.43i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  57%|▌| 57/100 [01:29<00:31,  1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  57%|▌| 57/100 [01:30<00:31,  1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  58%|▌| 58/100 [01:30<00:31,  1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  58%|▌| 58/100 [01:31<00:31,  1.35i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  59%|▌| 59/100 [01:31<00:29,  1.40i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  59%|▌| 59/100 [01:31<00:29,  1.40i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  60%|▌| 60/100 [01:31<00:26,  1.53i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  60%|▌| 60/100 [01:32<00:26,  1.53i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  61%|▌| 61/100 [01:32<00:24,  1.62i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  61%|▌| 61/100 [01:32<00:24,  1.62i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  62%|▌| 62/100 [01:32<00:23,  1.64i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  62%|▌| 62/100 [01:33<00:23,  1.64i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  63%|▋| 63/100 [01:33<00:21,  1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  63%|▋| 63/100 [01:33<00:21,  1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  64%|▋| 64/100 [01:33<00:20,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  64%|▋| 64/100 [01:34<00:20,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  65%|▋| 65/100 [01:34<00:19,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  65%|▋| 65/100 [01:35<00:19,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  66%|▋| 66/100 [01:35<00:19,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  66%|▋| 66/100 [01:35<00:19,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  67%|▋| 67/100 [01:35<00:18,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  67%|▋| 67/100 [01:36<00:18,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  68%|▋| 68/100 [01:36<00:18,  1.77i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  68%|▋| 68/100 [01:36<00:18,  1.77i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  69%|▋| 69/100 [01:36<00:17,  1.74i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  69%|▋| 69/100 [01:37<00:17,  1.74i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  70%|▋| 70/100 [01:37<00:17,  1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  70%|▋| 70/100 [01:37<00:17,  1.70i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  71%|▋| 71/100 [01:37<00:17,  1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  71%|▋| 71/100 [01:38<00:17,  1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  72%|▋| 72/100 [01:38<00:16,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  72%|▋| 72/100 [01:39<00:16,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  73%|▋| 73/100 [01:39<00:15,  1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  73%|▋| 73/100 [01:39<00:15,  1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  74%|▋| 74/100 [01:39<00:15,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  74%|▋| 74/100 [01:40<00:15,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  75%|▊| 75/100 [01:40<00:14,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  75%|▊| 75/100 [01:40<00:14,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  76%|▊| 76/100 [01:40<00:14,  1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  76%|▊| 76/100 [01:41<00:14,  1.68i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  77%|▊| 77/100 [01:41<00:14,  1.63i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  77%|▊| 77/100 [01:42<00:14,  1.63i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  78%|▊| 78/100 [01:42<00:14,  1.54i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  78%|▊| 78/100 [01:43<00:14,  1.54i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  79%|▊| 79/100 [01:43<00:14,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  79%|▊| 79/100 [01:43<00:14,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  80%|▊| 80/100 [01:43<00:13,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  80%|▊| 80/100 [01:44<00:13,  1.44i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  81%|▊| 81/100 [01:44<00:12,  1.49i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  81%|▊| 81/100 [01:44<00:12,  1.49i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  82%|▊| 82/100 [01:44<00:11,  1.55i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  82%|▊| 82/100 [01:45<00:11,  1.55i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  83%|▊| 83/100 [01:45<00:10,  1.58i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  83%|▊| 83/100 [01:46<00:10,  1.58i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  84%|▊| 84/100 [01:46<00:10,  1.56i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  84%|▊| 84/100 [01:46<00:10,  1.56i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  85%|▊| 85/100 [01:46<00:09,  1.65i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  85%|▊| 85/100 [01:47<00:09,  1.65i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  86%|▊| 86/100 [01:47<00:08,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  86%|▊| 86/100 [01:47<00:08,  1.73i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  87%|▊| 87/100 [01:47<00:07,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  87%|▊| 87/100 [01:48<00:07,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  88%|▉| 88/100 [01:48<00:06,  1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  88%|▉| 88/100 [01:48<00:06,  1.75i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  89%|▉| 89/100 [01:48<00:06,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  89%|▉| 89/100 [01:49<00:06,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  90%|▉| 90/100 [01:49<00:05,  1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  90%|▉| 90/100 [01:50<00:05,  1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  91%|▉| 91/100 [01:50<00:05,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  91%|▉| 91/100 [01:50<00:05,  1.80i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  92%|▉| 92/100 [01:50<00:04,  1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  92%|▉| 92/100 [01:51<00:04,  1.82i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  93%|▉| 93/100 [01:51<00:03,  1.85i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  93%|▉| 93/100 [01:51<00:03,  1.85i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  94%|▉| 94/100 [01:51<00:03,  1.79i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  94%|▉| 94/100 [01:52<00:03,  1.79i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  95%|▉| 95/100 [01:52<00:02,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  95%|▉| 95/100 [01:52<00:02,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  96%|▉| 96/100 [01:52<00:02,  1.72i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  96%|▉| 96/100 [01:53<00:02,  1.72i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  97%|▉| 97/100 [01:53<00:01,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  97%|▉| 97/100 [01:53<00:01,  1.78i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  98%|▉| 98/100 [01:53<00:01,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  98%|▉| 98/100 [01:54<00:01,  1.76i
| train_loss: 6.94e+00 | test_loss: 6.59e+00 | reg: 0.00e+00 | :  99%|▉| 99/100 [01:54<00:01,  1.16s
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-15-7af2f0c35dea>](https://localhost:8080/#) in <cell line: 48>()
     46 
     47 # KAN model training using the fit method
---> 48 results = model.fit(dataset,
     49                     metrics=(train_acc, test_acc),
     50                     opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())

[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in fit(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)
    938             if _ == steps-1 and old_save_act:
    939                 #self.save_act = True
--> 940                 self.recover_save_act_in_fit()
    941 
    942             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)

TypeError: MultKAN.recover_save_act_in_fit() missing 1 required positional argument: 'old_save_act'
apavlo89 commented 1 month ago

Also how would one go about doing feature selection with KAN?

seyidcemkarakas commented 1 month ago

Also how would one go about doing feature selection with KAN?

I have written code for feature selection; I will share it.

apavlo89 commented 1 month ago

Exciting! Where will you post it?

KindXiaoming commented 1 month ago

an example of how to do feature selection using KANs: https://github.com/KindXiaoming/pykan/blob/master/tutorials/Interp_4_feature_attribution.ipynb

apavlo89 commented 1 month ago

It just doesn't like my dataset for some reason. I've tried all kinds of things, but I get this error. Any idea why?

from kan import *
import numpy as np

# Repro script for the ValueError shown in the traceback below: merges the
# train and validation splits, casts everything to float, prints shapes for
# debugging and trains a KAN directly on the resulting NumPy arrays.

# Assuming X_train, y_train, X_val, y_val, X_test, y_test are already defined

# Combine training and validation sets
X_train_combined = np.concatenate((X_train, X_val), axis=0)
y_train_combined = np.concatenate((y_train, y_val), axis=0)

# Ensure the data is of type float
X_train_combined = X_train_combined.astype(float)
y_train_combined = y_train_combined.astype(float)
X_test = X_test.astype(float)
y_test = y_test.astype(float)

# Print shapes and types to debug
print(f"train_input shape: {X_train_combined.shape}, dtype: {X_train_combined.dtype}")
print(f"train_output shape: {y_train_combined.shape}, dtype: {y_train_combined.dtype}")
print(f"test_input shape: {X_test.shape}, dtype: {X_test.dtype}")
print(f"test_output shape: {y_test.shape}, dtype: {y_test.dtype}")

# Create the dataset in the expected format by mimicking the example
# NOTE(review): the values are still NumPy arrays (float64 per the output
# below), not torch tensors.  The traceback ends in `if x == None` inside
# kan/MultKAN.py, which is ambiguous for a multi-element array.
# NOTE(review): the other snippets in this thread use 'train_label' /
# 'test_label' rather than '*_output' — confirm which keys fit() reads.
dataset = {
    'train_input': X_train_combined,
    'train_output': y_train_combined,
    'test_input': X_test,
    'test_output': y_test
}

# Ensure the structure matches the expected format
print(f"dataset keys: {dataset.keys()}")
for key in dataset:
    print(f"{key} shape: {dataset[key].shape}, dtype: {dataset[key].dtype}")

# Create and train the KAN model, same as the example
# The first width entry is taken from the feature axis of the combined
# training array; 50 steps with lamb=1e-3 and reg_metric='edge_forward_n'.
model = KAN(width=[X_train_combined.shape[1], 10, 10, 1], seed=2)
model.fit(dataset, steps=50, lamb=1e-3, reg_metric='edge_forward_n')

train_input shape: (2134, 3854), dtype: float64
train_output shape: (2134,), dtype: float64
test_input shape: (251, 3854), dtype: float64
test_output shape: (251,), dtype: float64
dataset keys: dict_keys(['train_input', 'train_output', 'test_input', 'test_output'])
train_input shape: (2134, 3854), dtype: float64
train_output shape: (2134,), dtype: float64
test_input shape: (251, 3854), dtype: float64
test_output shape: (251,), dtype: float64
checkpoint directory created: ./model
saving model version 0.0

description:   0%|                                                           | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-39-8d89e215e034>](https://localhost:8080/#) in <cell line: 37>()
     35 # Create and train the KAN model, same as the example
     36 model = KAN(width=[X_train_combined.shape[1], 10, 10, 1], seed=2)
---> 37 model.fit(dataset, steps=50, lamb=1e-3, reg_metric='edge_forward_n')

3 frames
[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in fit(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)
    944 
    945             if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
--> 946                 self.update_grid(dataset['train_input'][train_id])
    947 
    948             if opt == "LBFGS":

[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in update_grid(self, x)
    356 
    357     def update_grid(self, x):
--> 358         self.update_grid_from_samples(x)
    359 
    360     def initialize_grid_from_another_model(self, model, x):

[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in update_grid_from_samples(self, x)
    352     def update_grid_from_samples(self, x):
    353         for l in range(self.depth):
--> 354             self.get_act(x)
    355             self.act_fun[l].update_grid_from_samples(self.acts[l])
    356 

[/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py](https://localhost:8080/#) in get_act(self, x)
   1730         if isinstance(x, dict):
   1731             x = x['train_input']
-> 1732         if x == None:
   1733             if self.cache_data != None:
   1734                 x = self.cache_data

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
seyidcemkarakas commented 1 month ago

@apavlo89 Here is my code for feature importance:

# Create the model
model = KAN(width=[4,7,1], grid=3, k=11)

# Train the model
results = model.train({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
                      opt="LBFGS", steps=150, loss_fn=torch.nn.MSELoss())

# We pass the validation data here even though the key reads 'test_input': val_input, because the dataset dict for KAN training expects an entry named test_input

layer_1 = model.acts_scale[0].detach().numpy()
layer_2 = model.acts_scale[1].detach().numpy()

columns = list(X.columns)
importance_values = np.dot(layer_1.T, layer_2.T).flatten()  # product of the two transposed scale matrices, flattened to 1-D; presumably one value per input column — confirm shapes

sorted_indices = np.argsort(importance_values)[::1]  # [::1] is a no-op slice, so the argsort order is kept unchanged
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]

# Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()

You will see the barh plot

KindXiaoming commented 1 month ago

Hi @apavlo89, it looks like your data are numpy arrays; could you please try converting them to torch tensors?

apavlo89 commented 2 weeks ago

I converted the numpy arrays to torch tensors, but now I get this:

@KindXiaoming @seyidcemkarakas

import torch
from kan import KAN

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Check the shape of your data
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")

# Converting data to NumPy arrays
X_train_np = X_train.to_numpy()
y_train_np = y_train.to_numpy()
X_val_np = X_val.to_numpy()
y_val_np = y_val.to_numpy()
X_test_np = X_test.to_numpy()
y_test_np = y_test.to_numpy()

# Convert the NumPy arrays to torch tensors: float32 for inputs, long for labels
train_input = torch.tensor(X_train_np, dtype=torch.float32)
train_label = torch.tensor(y_train_np, dtype=torch.long)
val_input = torch.tensor(X_val_np, dtype=torch.float32)
val_label = torch.tensor(y_val_np, dtype=torch.long)
test_input = torch.tensor(X_test_np, dtype=torch.float32)
test_label = torch.tensor(y_test_np, dtype=torch.long)

dataset = {  # read by train_acc/test_acc further down; fit() is given its own inline dict
    'train_input': train_input,
    'train_label': train_label,
    'val_input': val_input,
    'val_label': val_label,
    'test_input': test_input,
    'test_label': test_label
}

# Ensure the input width matches the number of features
input_width = X_train_np.shape[1]

# Create model
model = KAN(width=[input_width, 2], grid=10, k=2)  # NOTE(review): width has only two entries — presumably a single layer, so acts_scale[1] may not exist; confirm

# Workaround for the issue by setting save_act to False
model.save_act = False  # NOTE(review): possibly why model.acts_scale ends up empty after training — confirm

# Functions for getting accuracy scores while training (defined here, but not passed to fit() below)
def train_acc():  # mean of the elementwise match between argmax predictions and the training labels
    preds = torch.argmax(model(dataset['train_input']), dim=1)  # index of the largest model output along dim 1
    return torch.mean((preds == dataset['train_label']).float())

def test_acc():  # same computation as train_acc, but on dataset['test_input'] / dataset['test_label']
    preds = torch.argmax(model(dataset['test_input']), dim=1)  # index of the largest model output along dim 1
    return torch.mean((preds == dataset['test_label']).float())

# KAN model training using the fit method; the validation split is supplied under the 'test_input'/'test_label' keys
results = model.fit({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
                    opt="LBFGS", steps=150, loss_fn=torch.nn.CrossEntropyLoss())

layer_1 = model.acts_scale[0].detach().numpy()  # raises IndexError when acts_scale holds no entries
layer_2 = model.acts_scale[1].detach().numpy()

columns = list(X_train.columns)
importance_values = np.dot(layer_1.T, layer_2.T).flatten()  # product of the two transposed scale matrices, flattened to 1-D; presumably one value per input column — confirm shapes

sorted_indices = np.argsort(importance_values)[::1]  # [::1] is a no-op slice, so the argsort order is kept unchanged
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]

# Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()

Output and error :/


Shape of X_train: (1886, 171)
Shape of X_val: (251, 171)
Shape of X_test: (253, 171)
checkpoint directory created: ./model
saving model version 0.0
| train_loss: nan | test_loss: nan | reg: 0.00e+00 | : 100%|██████| 150/150 [01:42<00:00,  1.46it/s]saving model version 0.1

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
[<ipython-input-12-674e38bca30b>](https://localhost:8080/#) in <cell line: 63>()
     61                     opt="LBFGS", steps=150, loss_fn=torch.nn.CrossEntropyLoss())
     62 
---> 63 layer_1 = model.acts_scale[0].detach().numpy()
     64 layer_2 = model.acts_scale[1].detach().numpy()
     65 

IndexError: list index out of range
KindXiaoming commented 1 week ago

Try `print(model.acts_scale)` — is it None or empty? If so, what is `model.save_act` (True or False)? And if you try a dataset that does not produce a nan loss, does the problem persist?

apavlo89 commented 1 week ago

If I remove the feature selection code, then it runs just fine. The issue appears in the feature selection part:


import torch
from kan import KAN

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Check the shape of your data
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
print(f"Shape of X_test: {X_test.shape}")

# Converting data to NumPy arrays
X_train_np = X_train.to_numpy()
y_train_np = y_train.to_numpy()
X_val_np = X_val.to_numpy()
y_val_np = y_val.to_numpy()
X_test_np = X_test.to_numpy()
y_test_np = y_test.to_numpy()

# Convert the NumPy arrays to torch tensors: float32 for inputs, long for labels
train_input = torch.tensor(X_train_np, dtype=torch.float32)
train_label = torch.tensor(y_train_np, dtype=torch.long)
val_input = torch.tensor(X_val_np, dtype=torch.float32)
val_label = torch.tensor(y_val_np, dtype=torch.long)
test_input = torch.tensor(X_test_np, dtype=torch.float32)
test_label = torch.tensor(y_test_np, dtype=torch.long)

dataset = {  # read by train_acc/test_acc further down; fit() is given its own inline dict
    'train_input': train_input,
    'train_label': train_label,
    'val_input': val_input,
    'val_label': val_label,
    'test_input': test_input,
    'test_label': test_label
}

# Ensure the input width matches the number of features
input_width = X_train_np.shape[1]

# Create model
model = KAN(width=[input_width, 2], grid=4, k=1)  # NOTE(review): width has only two entries — presumably a single layer, so acts_scale[1] may not exist; confirm

# Workaround for the issue by setting save_act to False
model.save_act = False  # NOTE(review): possibly why model.acts_scale prints as [] after training — confirm

# Functions for getting accuracy scores while training (defined here, but not passed to fit() below)
def train_acc():  # mean of the elementwise match between argmax predictions and the training labels
    preds = torch.argmax(model(dataset['train_input']), dim=1)  # index of the largest model output along dim 1
    return torch.mean((preds == dataset['train_label']).float())

def test_acc():  # same computation as train_acc, but on dataset['test_input'] / dataset['test_label']
    preds = torch.argmax(model(dataset['test_input']), dim=1)  # index of the largest model output along dim 1
    return torch.mean((preds == dataset['test_label']).float())

# KAN model training using the fit method; the validation split is supplied under the 'test_input'/'test_label' keys
results = model.fit({'train_input': train_input, 'train_label': train_label, 'test_input': val_input, 'test_label': val_label},
                    opt="LBFGS", steps=100, loss_fn=torch.nn.CrossEntropyLoss())
# opt can be Adam or LBFGS
print(model.acts_scale)  # debug output: inspect acts_scale before it is indexed below
#####################################feature selection###########################################
layer_1 = model.acts_scale[0].detach().numpy()  # raises IndexError when acts_scale holds no entries
layer_2 = model.acts_scale[1].detach().numpy()

columns = list(X_train.columns)
importance_values = np.dot(layer_1.T, layer_2.T).flatten()  # product of the two transposed scale matrices, flattened to 1-D; presumably one value per input column — confirm shapes

sorted_indices = np.argsort(importance_values)[::1]  # [::1] is a no-op slice, so the argsort order is kept unchanged
sorted_importance_values = importance_values[sorted_indices]
sorted_columns = np.array(columns)[sorted_indices]

## Create the bar plot
plt.figure(figsize=(12, 6))
plt.barh(sorted_columns, sorted_importance_values, color='skyblue')
plt.xlabel('Feature Importance')
plt.ylabel('Features')
plt.title('Feature Importance of KAN')
plt.show()
##################################################################################################

# Predicted class (argmax over dim 1) for the train, val and test splits
test_preds = torch.argmax(model.forward(test_input).detach(),dim=1)
test_labels = test_label

train_preds = torch.argmax(model.forward(train_input).detach(),dim=1)
train_labels = train_label

val_preds = torch.argmax(model.forward(val_input).detach(),dim=1)
val_labels = val_label

# Evaluate metrics: accuracy_score(labels, predictions) for each split

print("Train ACC:", accuracy_score(train_labels.numpy(), train_preds.numpy()))

print("Val ACC:", accuracy_score(val_labels.numpy(), val_preds.numpy()))

print("Test ACC:", accuracy_score(test_labels.numpy(), test_preds.numpy()))

# Plotting KAN network
#model.plot(scale=input_width)
print(results)

# Learning curve based on ACC and LOSS
plt.figure(figsize=(10, 5))
plt.plot(results["train_acc"], label='Training Accuracy')  # NOTE(review): fit() was called without metrics=, so the accuracy keys may be missing from results — confirm
plt.plot(results["test_acc"], label='Val Accuracy')
plt.plot(results["train_loss"], label='Training Loss')
plt.plot(results["test_loss"], label='Val Loss')  # labelled "Val" because the validation split was passed to fit() under the test_* keys
plt.title('Training and Val Accuracy over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Accuracy & Loss')
plt.legend()
plt.grid(True)
plt.show()

Shape of X_train: (1886, 171)
Shape of X_val: (251, 171)
Shape of X_test: (253, 171)
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 6.57e-01 | test_loss: 9.09e-01 | reg: 0.00e+00 | : 100%|█| 100/100 [01:40<00:00, 1.01saving model version 0.1
[]

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
in <cell line: 65>()
     63 print(model.acts_scale)
     64 #####################################feature selection###########################################
---> 65 layer_1 = model.acts_scale[0].detach().numpy()
     66 layer_2 = model.acts_scale[1].detach().numpy()
     67

IndexError: list index out of range