SynodicMonth / ChebyKAN

Kolmogorov-Arnold Networks (KAN) using Chebyshev polynomials instead of B-splines.

Having trouble solving dynamical systems #4

Open mccrow2018 opened 4 months ago

mccrow2018 commented 4 months ago

Hi, ChebyKAN is indeed simple, elegant, and powerful; I believe it can do more.

So I tried it on solving some models with dynamical systems, economic models to be precise, where the dynamical equations are quite similar to PDEs.

The main problem I encountered is that ChebyKAN is more prone to getting "stuck", preventing training from going any further. Here are two illustrations, with KAN structure

valuefunction_KAN = KAN(width=[2,5,5,1], grid=5, k=3, grid_eps=1.0, noise_scale_base=0.25)

and ChebyKAN structure

class ChebyKAN(nn.Module):
    def __init__(self):
        super(ChebyKAN, self).__init__()
        # ChebyKANLayer(input_dim, output_dim, degree)
        self.chebykan1 = ChebyKANLayer(2, 8, 8)
        self.chebykan2 = ChebyKANLayer(8, 16, 5)
        self.chebykan3 = ChebyKANLayer(16, 1, 5)

    def forward(self, x):
        x = self.chebykan1(x)
        x = self.chebykan2(x)
        x = self.chebykan3(x)
        return x

valuefunction_cheb = ChebyKAN()

Results in the first figure were obtained with LBFGS, and in the second with Adam at a learning rate of 1e-2.

[two figures: LBFGS and Adam training results]

I have tested it multiple times, with different input/output dimensions and degrees ranging from 4 to 12; the issue remains.
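
The training loop isn't shown above; here is a minimal sketch of the likely drivers, where model, inputs, targets, and the MSE loss are placeholder assumptions standing in for the value-function setup:

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # placeholder for the value-function network
inputs, targets = torch.randn(64, 2), torch.randn(64, 1)
loss_fn = nn.MSELoss()

# LBFGS requires a closure that re-evaluates the loss (first figure):
opt = torch.optim.LBFGS(model.parameters())

def closure():
    opt.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    return loss

opt.step(closure)

# Adam at lr 1e-2 would replace it for the second figure:
# opt = torch.optim.Adam(model.parameters(), lr=1e-2)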

JanRocketMan commented 4 months ago

I believe that's expected due to the use of the tanh nonlinearity, since Chebyshev polynomials are defined on the [-1, 1] range.
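
To make the saturation concrete, a small added illustration (values are arbitrary): once pre-activations grow, tanh pins them near ±1 and its gradient collapses, so stacked ChebyKAN layers can stall.

import torch

x = torch.linspace(-4, 4, 9, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()
print(y)       # values crowd toward -1 and 1
print(x.grad)  # 1 - tanh(x)^2: near zero at the edges, where gradients vanish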

Have you tried adding residual connections to the layers, or is that something you wish to avoid? E.g. this way:

class ResidualChebyKAN(nn.Module):
    def __init__(self):
        super().__init__()
        # Identity skips require matching widths, so the hidden
        # layers are kept at a constant width of 8 here.
        self.chebykan1 = ChebyKANLayer(2, 8, 8)
        self.chebykan2 = ChebyKANLayer(8, 8, 5)
        self.chebykan3 = ChebyKANLayer(8, 1, 5)

    def forward(self, x):
        x = self.chebykan1(x)      # 2 -> 8: widths differ, no skip
        x = self.chebykan2(x) + x  # 8 -> 8: identity skip is valid
        x = self.chebykan3(x)      # 8 -> 1: widths differ, no skip
        return x

valuefunction_cheb = ResidualChebyKAN()
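
To keep the original widths (2 -> 8 -> 16 -> 1) while still having a skip path, a hedged variant (an editorial sketch, not from the thread) carries the residual through learnable linear projections instead of the identity:

class ProjectedResidualChebyKAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.chebykan1 = ChebyKANLayer(2, 8, 8)
        self.chebykan2 = ChebyKANLayer(8, 16, 5)
        self.chebykan3 = ChebyKANLayer(16, 1, 5)
        # Bias-free linear maps carry the skip across width changes
        self.proj1 = nn.Linear(2, 8, bias=False)
        self.proj2 = nn.Linear(8, 16, bias=False)
        self.proj3 = nn.Linear(16, 1, bias=False)

    def forward(self, x):
        x = self.chebykan1(x) + self.proj1(x)
        x = self.chebykan2(x) + self.proj2(x)
        x = self.chebykan3(x) + self.proj3(x)
        return x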
mccrow2018 commented 4 months ago

Thanks for the advice @JanRocketMan. I tried this method and implemented it like the structure in the paper: a SiLU residual branch added to the polynomial approximation, with learnable scales on both. This did not solve the issue; it was even worse.

The modified ChebyKAN:

import torch
import torch.nn as nn

# Inspired by Kolmogorov-Arnold Networks, but using Chebyshev polynomials
# instead of spline coefficients; adds a SiLU residual branch and learnable
# scales, following the structure in the KAN paper.
class ChebyKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree):
        super(ChebyKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        self.scale_base = nn.Parameter(torch.ones(input_dim * output_dim))
        self.scale_cheb = nn.Parameter(torch.ones(input_dim * output_dim))

        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))

    def forward(self, x):
        x = torch.reshape(x, (-1, self.inputdim))  # shape = (batch_size, inputdim)

        # Residual branch: SiLU(x) with a learnable scale per (input, output) pair
        xb = torch.einsum('bi,o->boi', x, torch.ones(self.outdim, device=x.device)).reshape(-1, self.inputdim * self.outdim)
        xb = (xb / (1 + torch.exp(-xb))) * self.scale_base.unsqueeze(dim=0)
        xb = torch.sum(xb.reshape(x.shape[0], self.outdim, self.inputdim), dim=2)

        # Squash inputs into [-1, 1], the domain of the Chebyshev polynomials
        x = torch.tanh(x)

        # Build the Chebyshev basis via the recurrence T_n(x) = 2x T_{n-1}(x) - T_{n-2}(x)
        cheby = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)
        if self.degree > 0:
            cheby[:, :, 1] = x
        for i in range(2, self.degree + 1):
            cheby[:, :, i] = 2 * x * cheby[:, :, i - 1].clone() - cheby[:, :, i - 2].clone()

        # Chebyshev interpolation with a learnable scale, summed over inputs,
        # plus the residual branch
        y = torch.einsum('bid,iod->bio', cheby, self.cheby_coeffs).reshape(-1, self.inputdim * self.outdim)
        y = y * self.scale_cheb.unsqueeze(dim=0)
        y = torch.sum(y.reshape(x.shape[0], self.inputdim, self.outdim), dim=1)
        y = y.view(-1, self.outdim) + xb

        return y

and the results when trained with LBFGS:

[figure: LBFGS training results for the modified ChebyKAN]
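
As a quick sanity check on the modified layer's shapes (an added illustration, values arbitrary):

layer = ChebyKANLayer(input_dim=2, output_dim=8, degree=5)
out = layer(torch.randn(32, 2))
print(out.shape)  # torch.Size([32, 8])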

SynodicMonth commented 4 months ago

It's probably caused by tanh and/or the [-1, 1] domain. Is it possible to provide the equation of the system or any code snippets? I'm excited to test it.

mccrow2018 commented 4 months ago

It would be an honor. Here is a Colab link to the snippet: [KAN-sto-test].

SynodicMonth commented 4 months ago

I made several tests and tuning runs, and it seems ChebyKAN is truly a disaster on this problem. A ChebyKAN [1,5,5,1] at degree 5 with @JanRocketMan's advice indeed improves the performance somewhat, but it's nowhere near the original KAN. I assume there are some fundamental flaws in the Chebyshev approach. I think I need to test other polynomials and find a way to replace tanh. My apologies for the buggy implementation. I'll be back when I've found a solution.
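
One candidate replacement for tanh (an editorial sketch, not something tested in the thread): softsign also maps onto (-1, 1) but saturates far more slowly, so gradients decay polynomially rather than exponentially.

import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, 7)
print(torch.tanh(x))  # pins to ±1 almost immediately
print(F.softsign(x))  # x / (1 + |x|): approaches ±1 slowly; gradient is 1/(1+|x|)^2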

Boris-73-TA commented 4 months ago

I used this alternative:

class OrthogPolyKAN(nn.Module):
    def __init__(self):
        super(OrthogPolyKAN, self).__init__()
        self.orthogpolykan1 = OrthogPolyKANLayer(2, 5, 5, poly_type='Jacobi', alpha=3, beta=3)
        self.orthogpolykan2 = OrthogPolyKANLayer(5, 8, 4, poly_type='Cheby2')
        self.orthogpolykan3 = OrthogPolyKANLayer(8, 1, 3)  # Legendre is default poly_type

    def forward(self, x):
        x = self.orthogpolykan1(x)
        x = self.orthogpolykan2(x)
        x = self.orthogpolykan3(x)
        return x

valuefunction_op = OrthogPolyKAN()

I get the following plot:

[figure: OrthogPolyKAN training results]

Maybe including layers of different polynomials helps?
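
For context, an added sketch (assuming the same recurrence-based construction as ChebyKANLayer): Legendre polynomials, the default basis above, follow their own three-term recurrence, so swapping bases only changes this loop relative to the Chebyshev version.

import torch

# (n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x)
def legendre_basis(x, degree):
    P = [torch.ones_like(x), x]
    for n in range(1, degree):
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return torch.stack(P[:degree + 1], dim=-1)

print(legendre_basis(torch.linspace(-1, 1, 5), 3).shape)  # torch.Size([5, 4])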

SynodicMonth commented 4 months ago

> I used this alternative: [...] Maybe including layers of different polynomials helps?

That's a significant improvement. I've indeed found that Chebyshev polynomials have limitations in the properties of the functions they can model. I have another idea as well: it might be because ChebyKAN does not support continual learning (see #8). I'm not very familiar with these economic models, so I'm not sure whether the input range changes as the model's predictions change (I see many iterations from the same value function). I've been sick recently; I'll run more tests once I recover.