KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

How to make OOT predictions? #403

Open seyidcemkarakas opened 2 months ago

seyidcemkarakas commented 2 months ago

Hi

I have asked several questions about this topic.

When model training is done:

import torch
from kan import KAN

# Create KAN
model = KAN(width=[len(X.columns), 1], grid=9, k=3)

# Train KAN (the validation split is passed under the 'test_*' keys)
results = model.train(
    {'train_input': train_input, 'train_label': train_label,
     'test_input': val_input, 'test_label': val_label},
    opt="LBFGS",
    steps=100,
    loss_fn=torch.nn.MSELoss(),
)

I can then plot the model:

model.plot(scale=1.3)

[image: output of model.plot(scale=1.3)]

How can I then make OOT predictions, i.e. predictions on a third, held-out part of the data?

I tried model.forward() and got outputs:

test_preds = model.forward(test_input).detach()
test_mse = torch.nn.functional.mse_loss(test_preds, test_label)  # compare predictions with labels

But when I call model.plot() afterwards, the plot has changed. Why? Is there a method that makes OOT predictions without changing the model architecture?

seyidcemkarakas commented 2 months ago

@KindXiaoming Did you see this question?

KindXiaoming commented 2 months ago
  1. plot() visualizations depend on input data. That's why you see two different plots for train/test data.
  2. KAN, like other foundation models, cannot handle OOD extrapolation by default. More inductive biases are required.
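
One way to check point 1 is to confirm that a forward pass leaves the learned weights untouched. Below is a minimal sketch, assuming the model is an ordinary `torch.nn.Module` whose learned state lives in `state_dict()`. The helper name `predict_without_side_effects` and the stand-in `torch.nn.Linear` model are hypothetical, used only so the sketch runs without pykan installed.

```python
import copy

import torch


def predict_without_side_effects(model, x):
    """Run inference and report whether any parameter or buffer changed."""
    # Snapshot every parameter and buffer before the forward pass.
    before = copy.deepcopy(model.state_dict())
    with torch.no_grad():  # inference only, no gradient tracking
        preds = model(x)
    after = model.state_dict()
    # True only if every tensor is element-wise identical to its snapshot.
    unchanged = all(torch.equal(before[name], after[name]) for name in before)
    return preds, unchanged


# Stand-in model (an assumption for illustration); a trained KAN would be
# passed in exactly the same way.
demo_model = torch.nn.Linear(3, 1)
demo_input = torch.randn(5, 3)
demo_preds, demo_unchanged = predict_without_side_effects(demo_model, demo_input)
print(demo_preds.shape, demo_unchanged)
```

With the trained KAN from the question, `predict_without_side_effects(model, test_input)` would be expected to report `unchanged` as true, in which case the different plot only reflects the different input data. Assuming plot() draws from the most recent forward pass, calling `model(train_input)` again before `model.plot()` should bring back the training-data plot.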