AthanasiosDelis / faster-kan

Benchmarking and Testing FastKAN

fast reverse mode with RSWAF basis #2

Open vpuri3 opened 6 months ago

vpuri3 commented 6 months ago

You can define a custom gradient rule for the RSWAF basis by noting that the derivative of the key operation

rswaf_core(x) = 1 - tanh(x)^2

has a lot of computation in common with the forward pass. Specifically,

rswaf_core_deriv(x) = -2 * tanh(x) * tanh_deriv(x)
tanh_deriv(x) = 1 - tanh(x)^2 # = rswaf_core(x)

A custom gradient can share work between the forward and backward passes, improving both efficiency and memory utilization. You can check my Julia implementation for reference.
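
For illustration, a minimal PyTorch sketch of that shared-work idea (the name RSWAFCore and the exact structure are illustrative, not an implementation from this repo): the backward reuses the tanh(x) and 1 - tanh(x)^2 tensors already computed in the forward instead of re-evaluating tanh.

import torch

class RSWAFCore(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        tanh_x = torch.tanh(x)
        core = 1 - tanh_x * tanh_x           # rswaf_core(x) = 1 - tanh(x)^2
        ctx.save_for_backward(tanh_x, core)  # reuse both tensors in the backward
        return core

    @staticmethod
    def backward(ctx, grad_output):
        tanh_x, core = ctx.saved_tensors
        # rswaf_core_deriv(x) = -2 * tanh(x) * (1 - tanh(x)^2), no extra tanh evaluation
        return grad_output * (-2.0 * tanh_x * core)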

AthanasiosDelis commented 6 months ago

It is a great idea indeed. I was thinking about it two days ago, but I wanted to combine it with making the grid and the inverse denominator trainable parameters. This is my first time trying to write a custom forward and backward, so any comments would be appreciated a lot.

My first experiment so far:

import torch
import torch.nn as nn
from torch.autograd import Function

class RSWAFFunction(Function):
    @staticmethod
    def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
        # Forward pass: 1 - tanh((input - grid) * inv_denominator)^2
        diff = input[..., None] - grid
        diff_mul = diff.mul(inv_denominator)
        tanh_diff = torch.tanh(diff_mul)  # use the scaled difference, not the raw diff
        tanh_diff_derivative = -tanh_diff.mul(tanh_diff) + 1  # sech^2(x) = 1 - tanh^2(x)

        # Save tensors for the backward pass
        ctx.save_for_backward(input, tanh_diff, tanh_diff_derivative, diff, inv_denominator)
        ctx.train_grid = train_grid
        ctx.train_inv_denominator = train_inv_denominator

        return tanh_diff_derivative

##### SOS NOT SURE HOW grad_inv_denominator, grad_grid ARE CALCULATED CORRECTLY YET
##### MUST CHECK https://github.com/pytorch/pytorch/issues/74802
##### MUST CHECK https://www.changjiangcai.com/studynotes/2020-10-18-Custom-Function-Extending-PyTorch/
##### MUST CHECK https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
##### MUST CHECK https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
##### MUST CHECK https://gist.github.com/Hanrui-Wang/bf225dc0ccb91cdce160539c0acc853a

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, tanh_diff, tanh_diff_derivative, diff, inv_denominator = ctx.saved_tensors
        grad_grid = None
        grad_inv_denominator = None

        # Backward pass for the input:
        # d/dx [1 - tanh(x)^2] = -2 * tanh(x) * (1 - tanh(x)^2); chain through the
        # inv_denominator scaling and sum out the grid dimension.
        grad_input = -2 * tanh_diff * tanh_diff_derivative * grad_output
        grad_input = grad_input.sum(dim=-1).mul(inv_denominator)

        # Backward pass for grid (formula not verified yet, see links above)
        if ctx.train_grid:
            grad_grid = -inv_denominator * grad_output.sum(dim=0).sum(dim=0)
            # alternative tried: -(inv_denominator * grad_output * tanh_diff_derivative).sum(dim=0)

        # Backward pass for inv_denominator (formula not verified yet)
        if ctx.train_inv_denominator:
            grad_inv_denominator = (grad_output * diff).sum()
            # alternative tried: (grad_output * diff * tanh_diff_derivative).sum()

        return grad_input, grad_grid, grad_inv_denominator, None, None  # one gradient per forward() argument

class ReflectionalSwitchFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,        
        train_inv_denominator: bool = False,
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
        self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool) 
        self.grid = torch.nn.Parameter(grid, requires_grad=train_grid)
        #print(f"grid initial shape: {self.grid.shape }")
        self.inv_denominator = torch.nn.Parameter(torch.tensor(inv_denominator, dtype=torch.float32), requires_grad=train_inv_denominator)  # Cache the inverse of the denominator

    def forward(self, x):
        return RSWAFFunction.apply(x, self.grid, self.inv_denominator, self.train_grid, self.train_inv_denominator)
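
One way to sanity-check at least the hand-written grad_input path (a rough sketch, not from the repo; values and shapes are illustrative) is torch.autograd.gradcheck, which compares the custom backward against finite differences in double precision, with the trainable flags off:

# Hedged sanity check: only grad_input is verified here (train flags are False).
x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
grid = torch.linspace(-1.2, 0.2, 8, dtype=torch.float64)
inv_denominator = torch.tensor(0.5, dtype=torch.float64)

torch.autograd.gradcheck(
    lambda inp: RSWAFFunction.apply(inp, grid, inv_denominator, False, False),
    (x,),
)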

I am not sure yet how to handle grad_inv_denominator and grad_grid with respect to grad_output. Any idea or explanation of what I'm missing would be invaluable. I probably will not have a lot of time for the project in the next 10 days due to PhD obligations, but I will try to keep in touch with any updates in the KAN ecosystem.

How has your experience been playing with KANs in Julia, @vpuri3?

vpuri3 commented 6 months ago

Let torch handle the gradients WRT grid and denominator; from what I understand, there is no speedup to be gained there. I would recommend only writing a custom gradient for rswaf_core from my earlier comment.

So I would break it down as follows:

class RSWAF(nn.Module): # let torch handle gradients WRT grid, denominator
    def __init__(self, grid, denominator, ...):
        ...
        self.grid = grid               # trainable
        self.denominator = denominator # trainable
    def forward(self, x):
        y = (x - self.grid) / self.denominator # don't pass grid, denominator to RSWAF_core
        return RSWAF_core.apply(y)

class RSWAF_core(autograd.Function): # write a custom gradient for this guy
    @staticmethod
    def forward(ctx, x, ...):
        ...
        return 1 - tanh(x)**2
    ...

I hope this helps.

My experience with Julia has been pretty smooth. It took me ~1 hr to come up with my implementation, and it is only about 2x slower than an MLP.