KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Trying a coef2curve example #273

Open drdozer opened 4 months ago

drdozer commented 4 months ago

Hi. Thanks for making the KAN code available. I've been trying to understand how it all works, and was trying to run coef2curve to get an idea of how the splines are calculated. However, I couldn't get it to run.

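import math
import numpy as np
import torch

# coef2curve comes from pykan; in this session it was pasted into a notebook cell

def exp_sin(x):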
  return np.exp(np.sin(math.pi * x))

train_xs = np.arange(-5,5,0.25)
train_ys = exp_sin(train_xs)

t_xs = torch.reshape(torch.from_numpy(train_xs), (1,len(train_xs)))

num_spline = 1
num_sample = len(train_xs)
num_grid_interval = 10
k = 3
grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k))

print(f"Shape of t_xs: {t_xs.shape}")
print(f"Shape of grids: {grids.shape}")
print(f"Shape of coef: {coef.shape}")

t_ys = coef2curve(t_xs, grids, coef, k=k)

This barfs with an error related to shapes:

Shape of t_xs: torch.Size([1, 40])
Shape of grids: torch.Size([1, 11])
Shape of coef: torch.Size([1, 13])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-2b0233d6ec51> in <cell line: 16>()
     14 
     15 
---> 16 t_ys = coef2curve(t_xs, grids, coef, k=k)

1 frames
<ipython-input-27-aae4645f44ba> in coef2curve(x_eval, grid, coef, k, device)
     90     if coef.dtype != x_eval.dtype:
     91         coef = coef.to(x_eval.dtype)
---> 92     y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
     93     return y_eval
     94 

/usr/local/lib/python3.10/dist-packages/torch/functional.py in einsum(*args)
    383         # the path for contracting 0 or 1 time(s) is already optimized
    384         # or the user has disabled using opt_einsum
--> 385         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    386 
    387     path = None

RuntimeError: einsum(): subscript j has size 7 for operand 1 which does not broadcast with previously seen size 13

I don't understand the code well enough to grasp what's going wrong here. It seems like some mismatch between the dimensionality of the grid and coefficients? I think I've stuck close to a copy-paste of your example, so unsure why it doesn't work.
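One way to read the two sizes in the error message: a degree-k B-spline basis defined on n grid points (knots) has n - k - 1 basis functions. The sketch below only does that arithmetic. The helper names `num_basis` and `num_extended_knots` are made up for illustration, and whether a particular pykan version pads the grid inside `coef2curve` is an assumption here, not something the traceback states; the size 7 suggests the grid was used unpadded in this run.

```python
def num_basis(num_knots, k):
    """Number of degree-k B-spline basis functions on `num_knots` knots."""
    return num_knots - k - 1


def num_extended_knots(num_intervals, k):
    """Knot count after padding a grid of `num_intervals` intervals
    with k extra intervals on each side (hypothetical padding scheme)."""
    return num_intervals + 1 + 2 * k


k = 3
num_grid_interval = 10

# The grid in the post has num_grid_interval + 1 = 11 points.
print(num_basis(num_grid_interval + 1, k))  # 7, the size reported for operand 1

# num_grid_interval + k = 13 coefficients would fit a grid padded to 17 knots.
padded = num_extended_knots(num_grid_interval, k)
print(padded, num_basis(padded, k))  # 17 13
```

Under this reading, the 13 columns in `coef` correspond to a padded grid, while the basis tensor was built from the 11 points as given.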

ChrisZhangJin commented 2 weeks ago

This is a dimension mismatch error. I fixed it in your code.

There is also something I don't quite understand here:

the parameters of B_batch.

The docstring says the following, but the grid's actual shape is not what it says. It looks like (number of samples, number of grid points), which doesn't seem to make much sense.

x : 2D torch.tensor inputs, shape (number of splines, number of samples)
grid : 2D torch.tensor grids, shape (number of splines, number of grid points)

Your x has shape (1,40), so the grid's shape cannot be (1,11). I changed it to shape (40,11).
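For reference, the einsum equation in the traceback above, 'ij,ijk->ik', can be read with the docstring's convention: i indexes splines, j indexes basis functions (one coefficient each), and k indexes samples. The sketch below checks operand shapes against such an equation. `einsum_output_shape` is a hypothetical helper written for this illustration, not part of torch or pykan, and the basis tensor shape (1, 7, 40) is an assumption inferred from the size 7 in the error message.

```python
def einsum_output_shape(equation, *shapes):
    """Hypothetical helper: output shape of an explicit einsum equation.

    Simplified on purpose: no ellipsis and no size-1 broadcasting.
    Raises ValueError when one subscript is bound to two different sizes.
    """
    inputs, output = equation.split("->")
    sizes = {}
    for operand, (labels, shape) in enumerate(zip(inputs.split(","), shapes)):
        if len(labels) != len(shape):
            raise ValueError(
                f"operand {operand} has {len(shape)} dims, expected {len(labels)}"
            )
        for label, size in zip(labels, shape):
            seen = sizes.setdefault(label, size)
            if seen != size:
                raise ValueError(
                    f"subscript {label} has size {size} for operand {operand}, "
                    f"previously seen size {seen}"
                )
    return tuple(sizes[label] for label in output)


# Coefficients (splines, basis) against an assumed basis tensor (splines, basis, samples).
print(einsum_output_shape("ij,ijk->ik", (1, 7), (1, 7, 40)))  # (1, 40)

# The shapes from the original post conflict on subscript j only.
try:
    einsum_output_shape("ij,ijk->ik", (1, 13), (1, 7, 40))
except ValueError as err:
    print(err)
```

In this reading the spline axis i and the sample axis k already agree for the shapes in the original post; only the coefficient count j differs.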


Here is the code. It now runs without error:

import torch
import numpy as np
import math

from kan.spline import coef2curve

def exp_sin(x):
  return np.exp(np.sin(math.pi * x))

train_xs = np.arange(-5.,5.,0.25)
train_ys = exp_sin(train_xs)

t_xs = torch.reshape(torch.from_numpy(train_xs).float(), (1,len(train_xs)))

num_spline = 1
num_sample = len(train_xs)
num_grid_interval = 10
k = 3
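# One grid row per column of t_xs: shape (40, 11)
grids = torch.einsum('i,j->ij', torch.ones(t_xs.shape[1],), torch.linspace(-1.,1.,steps=num_grid_interval+1))
# Shape (40, 13, 7); the last size, num_grid_interval - k = 7, matches the size 7 from the einsum error
coef = torch.normal(0.,1.,size=(grids.shape[0], num_grid_interval+k, num_grid_interval-k))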

print(f"Shape of t_xs: {t_xs.shape}")
print(f"Shape of grids: {grids.shape}")
print(f"Shape of coef: {coef.shape}")

t_ys = coef2curve(t_xs, grids, coef, k=k)
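To get a feel for how a curve is computed from spline coefficients, independent of any pykan version, here is a minimal pure-Python sketch of the Cox-de Boor recursion for a single spline. It is not pykan's implementation: `bspline_basis` and `curve_from_coef` are names invented for this example, and the grid is used exactly as given, with no padding.

```python
def bspline_basis(x, knots, k):
    """All degree-k B-spline basis values at scalar x (Cox-de Boor recursion).

    Returns a list of len(knots) - k - 1 floats.
    """
    # Degree 0: indicator of each half-open knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            # Guard against repeated knots (zero-width spans).
            left = (x - knots[i]) / left_den * basis[i] if left_den > 0 else 0.0
            right = ((knots[i + d + 1] - x) / right_den * basis[i + 1]
                     if right_den > 0 else 0.0)
            nxt.append(left + right)
        basis = nxt
    return basis


def curve_from_coef(xs, knots, coef, k):
    """Evaluate sum_j coef[j] * B_j(x) for each x in xs."""
    n_basis = len(knots) - k - 1
    if len(coef) != n_basis:
        raise ValueError(f"expected {n_basis} coefficients, got {len(coef)}")
    return [sum(c * b for c, b in zip(coef, bspline_basis(x, knots, k)))
            for x in xs]


knots = [i / 5 - 1 for i in range(11)]  # 11 points on [-1, 1], as in the thread
k = 3
print(len(bspline_basis(0.1, knots, k)))  # 7

# Samples outside the grid range hit no basis function in this sketch.
print(curve_from_coef([2.0], knots, [1.0] * 7, k))  # [0.0]
```

Note that in this sketch any sample outside [-1, 1) evaluates to zero, so most of the inputs from np.arange(-5, 5, 0.25) would contribute nothing unless the grid covers the same range as the data.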