SynodicMonth / ChebyKAN

Kolmogorov-Arnold Networks (KAN) using Chebyshev polynomials instead of B-splines.

ChebyKAN layer is equivalent to custom activation + nn.Linear #3

Open · JanRocketMan opened this issue 6 months ago

JanRocketMan commented 6 months ago

Hi, very interesting idea, kudos!

I believe the proposed layer is equivalent to the following combination (I fix the degree to 4 for simplicity):


import torch
import torch.nn as nn

from ChebyKANLayer import ChebyKANLayer

class ChebyActivation(nn.Module):
    def __init__(self, degree):
        assert degree == 4
        super().__init__()

    def forward(self, x):
        # Squash the input into [-1, 1], the domain of the Chebyshev polynomials.
        x = torch.tanh(x)

        # Chebyshev polynomials T_0..T_4 evaluated elementwise,
        # concatenated along the feature dimension.
        x = torch.cat(
            [
                torch.ones_like(x),
                x,
                2 * x**2 - 1,
                4 * x**3 - 3 * x,
                8 * x**4 - 8 * x**2 + 1,
            ],
            dim=1,
        )
        return x

input_dim = 128
output_dim = 256
for _ in range(100):
    variant_0 = ChebyKANLayer(input_dim, output_dim, 4)
    variant_1 = nn.Sequential(
        ChebyActivation(4),
        nn.Linear(input_dim * 5, output_dim, bias=False),
    )
    # Ensure both variants share the same weights: copy ChebyKAN's per-edge
    # Chebyshev coefficients into the fused Linear (the permute + flatten
    # matches the [T_0 | T_1 | ... | T_4] concatenation order above).
    variant_1[1].weight.data.copy_(variant_0.cheby_coeffs.permute(1, 2, 0).flatten(1))
    for _ in range(100):
        x = torch.randn(1234, input_dim)
        res1 = variant_0(x)
        res2 = variant_1(x)

        assert (
            res1 - res2
        ).abs().max() < 1e-6, "Found inconsistency between implementations!"

print("Two implementations are equivalent!")

This makes it a variant of the LAN network (see App. B2 in the KAN paper), which is nice, but it's a double-edged sword.

On one hand, with this rewrite you can train it pretty efficiently (by checkpointing the ChebyActivation function and using an optimized CUDA kernel for the nn.Linear).
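
For reference, here's a minimal sketch of what that training setup could look like, reusing the ChebyActivation defined above (CheckpointedChebyBlock is just an illustrative name, not part of the repo):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedChebyBlock(nn.Module):
    # Illustrative fused block: the (degree + 1)-times wider Chebyshev features
    # are recomputed during backward instead of being stored, and the single
    # nn.Linear runs as one large (cuBLAS-backed) matmul.
    def __init__(self, input_dim, output_dim, degree=4):
        super().__init__()
        self.act = ChebyActivation(degree)  # the activation defined above
        self.linear = nn.Linear(input_dim * (degree + 1), output_dim, bias=False)

    def forward(self, x):
        feats = checkpoint(self.act, x, use_reentrant=False)
        return self.linear(feats)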

On the other hand, modern networks like LLAMA3 already use Gated Linear Unit activations, which should give roughly equivalent representational power (though I'm not 100% sure on this point).
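
For comparison, a minimal sketch of the gated (SwiGLU-style) feed-forward block such models use (the names and dimensions are illustrative, not taken from any particular codebase):

import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    # Gated Linear Unit feed-forward: silu(W_gate x) gates (W_up x) elementwise,
    # then W_down projects back to the model dimension.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))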

Do you think this reasoning is correct, or am I missing something?

Thanks in advance!

SynodicMonth commented 6 months ago

You're right.

  1. Yes. A fixed degree can be optimized into explicit polynomials. I kept the recurrence in case the degree needs tuning.
  2. You nailed it. It's identical to LAN if we add an activation before the linear and expand the input (degree + 1) times. So actually KAN = LAN, and the learned functions (on edges for KAN) can be represented as a sum of functions (on nodes for LAN) if we expand the input several times (spelled out below)? The original KAN uses $$w(\mathrm{silu}(x)+\sum_i c_i B_i(x))$$, which is also expandable, right? (I might be missing a theorem, though.)
  3. I forgot about GLU, my bad. GLU probably has the same effect. I'll test it soon.
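
Spelling out the expansion in 2 for the Cheby case (this just restates what the test code above verifies): a degree-$d$ ChebyKAN layer computes

$$y_j = \sum_{i=1}^{d_{\text{in}}} \sum_{k=0}^{d} c_{i,j,k}\, T_k(\tanh x_i),$$

which is a fixed feature map $\phi(x) = [T_k(\tanh x_i)]_{i,k}$ of width $d_{\text{in}}(d+1)$ followed by one linear layer whose weights are exactly the $c_{i,j,k}$. So the learned per-edge functions collapse into fixed per-node basis functions plus nn.Linear.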

Really appreciate your suggestion.

JanRocketMan commented 6 months ago

2 is a great question; as I understand it:

We have two separate choices: which nonlinearity to use, and whether we apply it on edges and then sum into nodes, or apply it directly on nodes.

If the nonlinearity is Chebyshev, then it doesn't matter whether we apply it on nodes or edges: we can always fuse the rest of the operations into a single nn.Linear.

If the nonlinearity is grid-based (like the B-splines in KAN, or smoothing splines), then with activations on edges we can't easily fuse the computation into an nn.Linear, because each edge has a different basis. In principle we can expand the input to this larger basis set, but that's going to be very expensive (naively (degree + 1) * out_channels times). Maybe it's possible to fuse this expansion with the subsequent Linear in a single op, but that would require writing custom CUDA kernels à la FlashAttention.

On a more positive note, maybe we can share the grid across different output channels (I believe efficient-kan is doing this) and that would be enough. But if we don't use grids at all, I feel like it wouldn't make any difference compared to GLUs. Maybe I'm wrong.
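
A minimal sketch of that shared-grid idea, using Gaussian bumps on a shared grid as a simple stand-in for B-splines (SharedGridActivation and its arguments are made-up illustrative names, not efficient-kan's actual API):

import torch
import torch.nn as nn

class SharedGridActivation(nn.Module):
    # Stand-in for a shared spline basis: the SAME grid is used for every
    # (input, output) pair, so the basis expansion depends only on the input
    # and the per-edge coefficients can live in one nn.Linear.
    def __init__(self, grid_size=8, grid_range=(-1.0, 1.0)):
        super().__init__()
        self.register_buffer("grid", torch.linspace(grid_range[0], grid_range[1], grid_size))
        self.width = (grid_range[1] - grid_range[0]) / (grid_size - 1)

    def forward(self, x):
        # x: (batch, in_dim) -> (batch, in_dim * grid_size)
        d = x.unsqueeze(-1) - self.grid            # (batch, in_dim, grid_size)
        basis = torch.exp(-(d / self.width) ** 2)  # Gaussian bumps on the shared grid
        return basis.flatten(1)

# With the basis shared across output channels, the layer again collapses to
# "fixed activation + one nn.Linear":
in_dim, out_dim, grid_size = 128, 256, 8
layer = nn.Sequential(
    SharedGridActivation(grid_size),
    nn.Linear(in_dim * grid_size, out_dim, bias=False),
)
print(layer(torch.randn(4, in_dim)).shape)  # torch.Size([4, 256])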

SynodicMonth commented 6 months ago

Indeed. Chebyshev/Fourier/Legendre/Hermite/Laguerre are all the same here: they're all equivalent to a custom activation + nn.Linear. So that's why KAN uses a grid-based spline basis. Really impressive.
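
For example, a quick sketch of the Fourier case under the same reasoning (FourierActivation is an illustrative name, not a reference to any particular implementation):

import torch
import torch.nn as nn

class FourierActivation(nn.Module):
    # Fixed Fourier feature map: x -> [cos(k x), sin(k x)] for k = 1..K.
    # As in the Chebyshev case, all learnable parameters sit in the Linear that follows.
    def __init__(self, num_frequencies=4):
        super().__init__()
        self.register_buffer("k", torch.arange(1, num_frequencies + 1).float())

    def forward(self, x):
        # x: (batch, in_dim) -> (batch, in_dim * 2 * K)
        phase = x.unsqueeze(-1) * self.k           # (batch, in_dim, K)
        return torch.cat([phase.cos(), phase.sin()], dim=-1).flatten(1)

in_dim, out_dim, K = 128, 256, 4
fourier_variant = nn.Sequential(
    FourierActivation(K),
    nn.Linear(in_dim * 2 * K, out_dim, bias=False),
)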