lingzhic opened this issue 2 months ago
These numbers are complexities; one can choose their favorite numbers. For example, I think x^3 and x^4 are equally complex, so I assign complexity 3 to both functions. However, if you think x^4 is more complicated than x^3, you can assign 4 to x^4 and 3 to x^3.
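For illustration, here is a minimal sketch of how custom complexities could be assigned when registering functions with `add_symbolic` (the `c` argument is the complexity, as used later in this thread; the `my_x^3` / `my_x^4` names are just hypothetical examples, not existing library entries):

```python
from kan.utils import add_symbolic

# Assumption: add_symbolic(name, fun, c=...) registers a symbolic function
# under `name` with complexity `c` (the same call pattern used later in this thread).
# Giving x^3 and x^4 the same complexity treats them as equally complex.
add_symbolic('my_x^3', lambda x: x**3, c=3)
add_symbolic('my_x^4', lambda x: x**4, c=3)

# If you consider x^4 more complicated, assign it a larger complexity instead:
# add_symbolic('my_x^4', lambda x: x**4, c=4)
```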
Thanks for the explanation!
Hi Ziming,
I was trying to recover the L-J potential: $V_{LJ}(r) = 4 \varepsilon \left[ \left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6 \right]$,
where $r$ is the distance between two points and $\sigma$ and $\varepsilon$ are arbitrary positive floats. It seems a width=[1, 2, 1] KAN is not able to recover the accurate formula, though I tried to fix the symbolic form to x^6 and x^12 or 1/x^6 and 1/x^12 (and this is why I noticed the formula details in utils.py :)).
Could you please give any suggestions on this kind of formula?
> It seems a width=[1, 2, 1] KAN is not able to recover the accurate formula though I tried to fix the symbolic form to x^6 and x^12 or 1/x^6 and 1/x^12
This sounds about right, but to make sure: fix the two activation functions (0,0,0) and (0,0,1) to be linear, (1,0,0) to 1/x^6 and (1,1,0) to 1/x^12, right? And probably you want fit_params=False.
I think this should work if done right. :)
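As a concrete sketch of that suggestion (illustrative only, not taken from the thread): it assumes a width=[1, 2, 1] model and that '1/x^6' and '1/x^12' have already been registered with `add_symbolic`, as in the code further down; the hyperparameters are placeholders.

```python
from kan import MultKAN as KAN

model = KAN(width=[1, 2, 1], grid=5, k=3, seed=0)

# Layer 0: input -> the two hidden nodes, fixed to linear (identity) activations.
model.fix_symbolic(0, 0, 0, 'x', fit_params_bool=False)
model.fix_symbolic(0, 0, 1, 'x', fit_params_bool=False)

# Layer 1: hidden -> output, fixed to the two power-law terms.
# '1/x^6' and '1/x^12' are not in the default symbolic library and must be
# added via add_symbolic first (see the add_symbolic calls in the code below).
model.fix_symbolic(1, 0, 0, '1/x^6', fit_params_bool=False)
model.fix_symbolic(1, 1, 0, '1/x^12', fit_params_bool=False)
```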
> fix the two activation functions (0,0,0) and (0,0,1) to be linear, (1,0,0) to 1/x^6 and (1,1,0) to 1/x^12, right? And probably you want fit_params=False.
Hi Ziming,
Thanks for your prompt reply.
Yes, I set fit_params=False. But I didn't constrain all the nodes to be symbolic (I wonder if doing so would be identical to scipy.optimize.curve_fit; see the curve_fit sketch after the code below).
To be more specific, here is my scratch code:
```python
from kan import MultKAN as KAN
from kan.utils import create_dataset, add_symbolic
import numpy as np
import matplotlib.pyplot as plt
import torch

torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Target: (reduced) L-J form with eps = sigma = 1
eps = 1.0
sigma = 1
f = lambda x: eps*((sigma/x)**12 - (sigma/x)**6)

dataset = create_dataset(f, n_var=1, train_num=1000, test_num=1000, ranges=[0.8, 6], device=device)
print(dataset['train_input'].shape, dataset['train_label'].shape)

# Visualize the target function and the sampled data
xi = np.linspace(1, 6, 1000)
plt.plot(xi, f(xi), label='true')
plt.scatter(dataset['train_input'].cpu(), dataset['train_label'].cpu(), label='train_set')
plt.scatter(dataset['test_input'].cpu(), dataset['test_label'].cpu(), label='test_set')
plt.legend()
plt.show()

# Register 1/x^6 and 1/x^12 as symbolic functions, with singularity protection near x = 0
f_inv6 = lambda x, y_th: ((x_th := 1/y_th**(1/6)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**6) * (torch.abs(x) >= x_th))
add_symbolic('1/x^6', lambda x: 1/x**6, c=6, fun_singularity=f_inv6)
f_inv12 = lambda x, y_th: ((x_th := 1/y_th**(1/12)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**12) * (torch.abs(x) >= x_th))
add_symbolic('1/x^12', lambda x: 1/x**12, c=12, fun_singularity=f_inv12)

model = KAN(width=[1, 2, 1], grid=50, k=12, seed=12345, device=device)

# Fix the first-layer activations to the two power-law terms;
# the second-layer activations are left free (not fixed to 'x')
model.fix_symbolic(0, 0, 0, '1/x^12', fit_params_bool=False);
model.fix_symbolic(0, 0, 1, '1/x^6', fit_params_bool=False);
# model.fix_symbolic(1, 0, 0, 'x', fit_params_bool=False);
# model.fix_symbolic(1, 1, 0, 'x', fit_params_bool=False);

# Train the model
history = model.fit(dataset, opt="LBFGS", steps=100, lamb=0.001, lamb_entropy=10, lr=0.1)

# Compare the prediction against the true function
xi = np.linspace(0.8, 6, 1000)
xi_torch = torch.tensor(xi.reshape(1000, 1)).to(device)  # use `device` rather than hard-coded 'cuda'
pred = model(xi_torch)
plt.plot(xi, f(xi), label='true')
plt.plot(xi, pred.detach().cpu().numpy(), label='pred')
plt.legend()
plt.tight_layout()
plt.show()
```
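Regarding the scipy.optimize.curve_fit question: for comparison, here is a rough baseline that fits the assumed functional form directly to the same training data. This is an illustrative sketch (not part of the original code); it reuses the `dataset` dictionary produced by `create_dataset` above and assumes the target has the form a/r^12 - b/r^6.

```python
import numpy as np
from scipy.optimize import curve_fit

# Fit the assumed functional form a/r^12 - b/r^6 directly to the training data.
def lj_form(r, a, b):
    return a / r**12 - b / r**6

x_train = dataset['train_input'].cpu().numpy().ravel()
y_train = dataset['train_label'].cpu().numpy().ravel()

popt, pcov = curve_fit(lj_form, x_train, y_train, p0=[1.0, 1.0])
print("fitted (a, b):", popt)  # roughly (1, 1) for eps = sigma = 1
```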
Lines 31 and 32 in kan/utils.py have the wrong complexity values: they should be 4 and 5, respectively, instead of 3.