GistNoesis / FourierKAN

MIT License
712 stars 59 forks

Suggestion #6

Open Michael-Rempel opened 6 months ago

Michael-Rempel commented 6 months ago

One of the issues with Fourier is oversampling in the lower frequency domains. An alternative might be wavelets, which would handle oversampling in the low-frequency area and would give you the chance to ignore higher-frequency noise. In fact, since the KAN benefits from simplification, you could start with a constrained sampling band and, once that works well, experiment with widening the band for further training.

The original KAN spline solution is equally adaptable to wavelets. We could consider that to be long-term memory. In addition, I have often wondered if 'to the limit' excitation of small high-frequency wavelets of a simple form might be a trigger to implement a larger, higher-level wavelet as a learning step, with an activation hierarchy to do full learning. Thus short-term memory is much wider and more expensive than long-term memory, but the form can be sent back and forth between the two memory forms as new information in context is suspected to be available.

Finally, from probability theory we know that discrete probabilities, which have no overlap, are usable in the probability calculus, whereas fuzzy logic is not. Yet functionally, in a single layer, they act in a very similar way. Therefore my intuition suggests that for layered, multi-factor inference we want the discrete variety of knowledge representation as much as possible, in order to make the inferences cumulative without math issues.

My final suggestion is to do pattern recognition against inference data to identify whether we have a Pearlian causal inference. To those who know causal inference there is no need to explain; if not, suffice it to say we can do a lot of optimizations when we know we are dealing with causal effects. For one thing, it makes the formulation and explanation of KAN more robust and automatically validatable.

Returning to wavelets for a moment: the choice of wavelet waveform could be made by a simple pre-processor biased toward the form of media being input. There is a lot of good literature on both audio and image processing. For the discrete logic I advocate above, the wavelet should be a single near-square wave, with the option to be either positive or negative depending on the local need.

I lack the mathematical skills to implement this idea, but my extensive reading in the field gave me the intuition for these suggestions.

Regards, Michael Rempel

unrealwill commented 6 months ago

Thanks for your suggestions.

Fourier is indeed one common choice of basis for 1d function representation among a vast literature of representations.

Wavelet transforms can be another. Some other users have suggested Radial Basis Functions.

The best choice will probably be problem dependent.

Fourier bases, by virtue of being continuous, are probably not the right choice for representing step functions.
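
For instance, a truncated Fourier series of a step function keeps overshooting near the jump (Gibbs phenomenon), no matter how many harmonics are kept. A quick sketch:

import math
import torch as th

#Partial Fourier sums of the unit step sign(t) on [-1,1], odd harmonics only
t = th.linspace(-1, 1, 2001)
approx = sum(4/(math.pi*k)*th.sin(math.pi*k*t) for k in range(1, 100, 2))
print(approx.max())  #~1.18: a persistent ~9% overshoot of the jump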

Some choices of basis function will be easier to fit (with existing deep learning machinery) than others. In particular, sparse bases like splines, or bases of functions that switch on at some t0, like wavelets, may be prone to vanishing gradients because the basis functions are piecewise flat.
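
To illustrate the vanishing-gradient point: a piecewise-flat basis function passes zero gradient back to its input almost everywhere, so nothing flows to earlier layers through the basis itself. A quick sketch (haar_psi here is written with th.floor so that autograd keeps the graph):

import torch as th

#The Haar mother wavelet via th.floor; floor carries a zero derivative,
#which is exactly the problem.
def haar_psi(x):
    inside = ((x >= 0) & (x < 1)).float()
    return (1 - 2*th.floor(2*x))*inside

x = th.linspace(0.05, 0.95, 10, requires_grad=True)
haar_psi(x).sum().backward()
print(x.grad)  #all zeros: no gradient signal reaches the input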

One important thing to note with KAN is that maintaining the full basis means there is hope (a guarantee?) of not getting stuck. The inconvenience is that you may have plenty of zero coefficients, which cost computation but may allow you to escape local minima.
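
Your suggestion of starting with a constrained band and widening it later can be emulated on top of the full basis by masking the gradients of the higher frequencies, e.g. (a sketch, untested, assuming the layer keeps its coefficients in a tensor, here called fouriercoeffs, whose last axis indexes frequency):

import torch as th

def limit_band(layer, max_freq):
    #Zero the gradient of every coefficient above max_freq, so only the
    #low-frequency band trains while the full basis stays allocated.
    def hook(grad):
        mask = th.zeros_like(grad)
        mask[..., :max_freq] = 1.0
        return grad*mask
    return layer.fouriercoeffs.register_hook(hook)

#handle = limit_band(layer, 4)  #train a narrow band first
#handle.remove()                #later, remove (or re-register a wider band)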

You mention Judea Pearl, father of do-calculus, and implicitly ask whether this kind of model can make use of causality or discover it from data.

FourierKAN natively doesn't. But you can stack it with other layers that do (fronting with Linear? CausalMasking?), or, as you suggest, use a basis like wavelets that has an ordering baked in. One thing that is often surprising with MLPs is that they can learn to read digits from MNIST data with shuffled pixels; similarly, with an image with shuffled patches, an MLP won't infer the intermediate pixel representation (and wavelets won't do the trick for shuffled data).

One bitter lesson of deep learning is that complex baked-in priors usually don't pay off, and simplest is usually best, although when we have a lot of simple things the whole becomes complex and we lose interpretability.

These are all open research questions. I'll leave your suggestion open so that interested people can find it and help explore the speed-performance trade-offs associated with specific choices of basis.

Michael-Rempel commented 6 months ago

I don’t know if I agree with your wavelet assessment. For images, JPEG 2000 does a very fine job of sorting out fuzzy images to make them crisp.

The whole point of choosing a wavelet function is that success depends on its shape for particular applications. There is no need to change the algorithm, just the wavelet in use.

This talk explains some of these ideas better than I can: Michael Unser, "Wavelets and stochastic processes: how the Gaussian world became sparse", https://www.youtube.com/watch?v=H_MJmR6IPNo

unrealwill commented 6 months ago

Here are the modifications to transform FourierKAN into HaarKAN. You probably want to do custom fused operations and/or use einsum for lower memory usage.

The transformation is rather straightforward for other bases too. The code runs but is untested. Normalization constants should be checked. Inputs must be in the [-1,1] range.

Have fun experimenting with it :)

import torch as th
import numpy as np

# MIT LICENSED 

#Runs but Untested

#This is inspired by Kolmogorov-Arnold Networks, but using Haar wavelet coefficients instead of spline coefficients
#Formulas from https://en.wikipedia.org/wiki/Haar_wavelet
#The wavelet basis is defined on [0,1], but HaarKANLayer accepts inputs in the [-1,1] range (rescaled by adding 1 and multiplying by 0.5 in the forward function)
#You can use l2 normalization to make sure that the inputs are in the correct range
class HaarKANLayer(th.nn.Module):
    def __init__( self, inputdim, outdim, nbwavelet, addbias=True):
        super(HaarKANLayer,self).__init__()
        self.nbwavelet= nbwavelet
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        Ns = []
        ks = []
        for n in range(nbwavelet):
          #k in [0, 2^n - 1] keeps the support of psi(2^n x - k) inside [0,1]
          for k in range(pow(2,n)):
            Ns.append(n)
            ks.append(k)

        #Registered as buffers so they follow the module to the right device
        self.register_buffer("N", th.tensor(Ns))
        self.register_buffer("K", th.tensor(ks))

        #TODO: find the correct normalization factor
        self.haarcoeffs = th.nn.Parameter( th.randn(outdim,inputdim,len(Ns)) / np.sqrt(inputdim)  )
        if( self.addbias ):
            self.bias  = th.nn.Parameter( th.zeros(1,outdim))

    def psi(self,x):
        #Haar mother wavelet: +1 on [0,0.5), -1 on [0.5,1), 0 elsewhere
        return 1.0*((x>=0)&(x<0.5)) - 1.0*((x>=0.5)&(x<1))

    def psink(self,x,n,k):
        #Scaled and shifted wavelet 2^(n/2) * psi(2^n x - k)
        #TODO : check the value of the normalization term
        return th.pow(2,n/2)*self.psi( th.pow(2,n)*x - k)

    #x.shape ( ... , indim ) 
    #out.shape ( ..., outdim)
    def forward(self,x):
        xshp = x.shape
        outshape = xshp[0:-1]+(self.outdim,)
        x = 0.5*( th.reshape(x,(-1,self.inputdim)) + 1)
        #We broadcast along the last dimension
        xrshp = th.reshape(x,(x.shape[0],1,x.shape[1],1) ) 
        #This should be fused to avoid materializing memory
        c = self.psink( xrshp,self.N,self.K)
        #We compute the interpolation of the various functions defined by their haar-wavelet coefficient for each input coordinates and we sum them 
        y =  th.sum( c*self.haarcoeffs,(-2,-1)) 

        if( self.addbias):
            y += self.bias
        #End fuse
        '''
        #You can use einsum instead to reduce memory usage
        #Left to the reader
        '''
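        #A sketch (untested) of the einsum variant: the same contraction
        #without materializing the (batch,outdim,inputdim,nbasis) product
        #y = th.einsum("bif,oif->bo", c.squeeze(1), self.haarcoeffs)
        #if self.addbias: y = y + self.bias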
        y = th.reshape( y, outshape)
        return y

def demo():
    bs = 10
    L = 3 #Extra dimension, to show that additional dimensions are batched like Linear
    inputdim = 400
    hidden = 200
    outdim = 100
    nbwavelet = 5

    device = "cpu" #"cuda"

    fkan1 = HaarKANLayer(inputdim, hidden, nbwavelet).to(device)
    fkan2 = HaarKANLayer(hidden, outdim, nbwavelet).to(device)

    x0 = th.randn(bs,L,inputdim).to(device)
    #l2 normalization (along the feature dimension) keeps inputs in [-1,1]
    x0n = th.nn.functional.normalize(x0,dim=-1)
    print( th.max(x0n) )
    print( th.min(x0n) )
    h = fkan1(x0n)
    print(h.shape)
    #l2 normalization
    hn = th.nn.functional.normalize(h,dim=-1)
    print( th.max(hn) )
    print( th.min(hn) )
    y = fkan2(hn)
    print( y.shape)

if __name__ == "__main__":
    demo()
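
To try a different mother wavelet, overriding psi is all that's needed, e.g. a Ricker ("Mexican hat") wavelet. A sketch, untested; the Ricker wavelet is not compactly supported, so the dyadic tiles overlap and the normalization should be rechecked:

class RickerKANLayer(HaarKANLayer):
    #Same machinery, smoother mother wavelet: a Ricker ("Mexican hat")
    #bump recentered so its main lobe sits on the [0,1] tile.
    def psi(self, x):
        t = 4.0*(x - 0.5)
        return (1 - t*t)*th.exp(-0.5*t*t)

The dyadic grid in psink is kept as-is; only the shape changes, which is exactly your point about swapping the wavelet without changing the algorithm.
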
Michael-Rempel commented 6 months ago

Hey, thanks. I wasn’t certain whether I’d be hacking irresponsibly if I tried it myself.
