KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Out of sample generalisation #19

Closed · CrispyCrafter closed 5 months ago

CrispyCrafter commented 5 months ago

What is the expected stability of these methods for out-of-sample predictions? I ran a simple example using a power-curve function (see below) with inputs limited to [1, 100]. The model performs quite well on the training set, but out of sample it quickly diverges.

from kan import *
import numpy as np
import torch
import matplotlib.pyplot as plt

model = KAN(width=[1,2,1], grid=3, k=2, symbolic_enabled=False)
a, b, c = 0.7, 0.2, 0.1
f = lambda x: a * x**-b + c

# train on inputs in [1, 100]
dataset = create_dataset(f, n_var=1, ranges=[1, 100])
model.train(dataset, opt="LBFGS", steps=20)

# evaluate well beyond the training range; float32 matches the model's default dtype
x = torch.tensor(np.arange(1, 1000, 0.1), dtype=torch.float32).reshape(-1, 1)[1:]
y_actual = f(x).numpy()
y_pred = model(x)[:, 0].detach().numpy()

plt.scatter(x.numpy(), y_actual, color="red")   # ground truth
plt.scatter(x.numpy(), y_pred, color="blue")    # KAN predictions

[Plot: predictions (blue) match the true curve (red) within the training range [1, 100] but diverge quickly beyond it]

KindXiaoming commented 5 months ago

I'm afraid that KANs don't guarantee out-of-distribution generalization (just like other models by default), unless you try symbolically snapping them.
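
For context: "symbolic snapping" in pykan means replacing a learned spline activation with a closed-form primitive, so that extrapolation follows the formula rather than the spline. A minimal sketch using the suggest_symbolic and fix_symbolic methods (as in the pykan tutorials; treat exact signatures as version-dependent):

from kan import *

# fit a single-activation KAN to a simple target on a narrow range
model = KAN(width=[1,1], grid=3, k=2)
f = lambda x: 1 / x + 0.1
dataset = create_dataset(f, n_var=1, ranges=[1, 2])
model.train(dataset, opt="LBFGS", steps=50)

# rank closed-form candidates for the activation at layer 0, input 0, output 0
model.suggest_symbolic(0, 0, 0)

# snap it to 1/x; pykan fits an affine wrapper c*f(a*x + b) + d around the primitive
model.fix_symbolic(0, 0, 0, '1/x')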

CrispyCrafter commented 5 months ago

Thanks for the feedback! Can you elaborate on symbolic-snapping?

CrispyCrafter commented 5 months ago

I guess the key idea here is to use symbolic regression to extract a parametric function that describes the data. From there you should be able to generalise; pykan can automate this step, as sketched below.
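
pykan can automate this loop: auto_symbolic tries to snap every activation to its best-matching primitive from a candidate library (a sketch continuing from a trained model as above; the library names follow pykan's symbolic library and should be treated as assumptions):

# continuing from a trained model: snap all activations automatically
model.auto_symbolic(lib=['x', '1/x', 'exp', 'log', 'sqrt'])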

KindXiaoming commented 5 months ago

Here is an example where KANs can generalize out-of-distribution. Frankly, any symbolic method can achieve this if the correct symbolic formula is found. I cheated a little bit by changing a few things:

from kan import *
import numpy as np
import torch
import matplotlib.pyplot as plt

model = KAN(width=[1,1], grid=3, k=2)
a, b, c = 1, 1, 0.1
f = lambda x: a * x**-b + c

# train on a narrow range, [1, 2]
dataset = create_dataset(f, n_var=1, ranges=[1, 2])
model.train(dataset, opt="LBFGS", steps=100)

# fix the activation to be 1/x (unfortunately we don't support fractional order yet)
model.fix_symbolic(0, 0, 0, '1/x')

# evaluate far outside the training range; float32 matches the model's default dtype
x = torch.tensor(np.arange(1, 100, 0.1), dtype=torch.float32).reshape(-1, 1)[1:]
y_actual = f(x).numpy()
y_pred = model(x)[:, 0].detach().numpy()

plt.scatter(x.numpy(), y_actual, color="red")   # ground truth
plt.scatter(x.numpy(), y_pred, color="blue")    # KAN predictions

[Plot: with the activation snapped to 1/x, predictions (blue) track the true curve (red) across [1, 100], far beyond the training range]

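After snapping, the fitted closed form can be read back as a sympy expression. A sketch continuing from the code above, assuming symbolic_formula behaves as in the pykan tutorials:

# briefly retrain so the affine parameters of the snapped 1/x are refined
model.train(dataset, opt="LBFGS", steps=20)

# read off the learned formula as a sympy expression (expected close to 1/x_1 + 0.1)
formula = model.symbolic_formula()[0][0]
print(formula)
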
CrispyCrafter commented 5 months ago

Amazing! Perhaps going the sympy route might make sense? See, for instance, the PySR implementation: https://github.com/MilesCranmer/PySR/blob/116eee19568e9cc0801e8d5996771ab75bb95998/pysr/export_sympy.py#L8
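
On the sympy route: symbolic_formula already returns sympy expressions, so one lightweight bridge is sympy.lambdify, which compiles the extracted formula into a vectorised numpy callable for cheap out-of-sample evaluation (a sketch; the x_1 symbol name follows pykan's convention and is an assumption):

import numpy as np
from sympy import lambdify, symbols

# expr would come from the snapped model above
expr = model.symbolic_formula()[0][0]
x_1 = symbols('x_1')  # pykan names inputs x_1, x_2, ... (assumption)

f_np = lambdify(x_1, expr, 'numpy')  # compile the formula to a numpy function
xs = np.arange(1, 100, 0.1)
print(f_np(xs)[:5])                  # evaluate far beyond the training range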