Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License
3.49k stars · 306 forks

Support inputs with more than 2 dimensions for efficient-kan #19

Closed WhatMelonGua closed 1 month ago

WhatMelonGua commented 1 month ago

I don't quite understand KAN's code. Is it possible for KANLinear to behave like torch.nn.Linear, where only the last dimension is transformed, so that inputs with more than 2 dimensions are allowed? For example, in multi-head attention our input is shaped like [batch, nhead, dim].

However, this is not allowed by the current KAN (`assert x.dim() == 2 and x.size(1) == self.in_features`). I am very interested in exploring the application of KAN in attention!
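For reference, this is the torch.nn.Linear behavior being asked about: every dimension except the last is treated as a batch dimension (a minimal illustration with made-up shapes):

```python
import torch
import torch.nn as nn

# torch.nn.Linear applies its weight to the last dimension only, so any
# number of leading (batch-like) dimensions is accepted.
linear = nn.Linear(64, 32)
x = torch.randn(8, 4, 64)  # e.g. [batch, nhead, dim]
y = linear(x)
print(y.shape)             # torch.Size([8, 4, 32])
```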

minh-nguyenhoang commented 1 month ago

You could get around this by flattening everything up to the final dimension, which in your case gives [batch*nhead, dim].
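A minimal sketch of that workaround (assuming this repo's KANLinear layer; the sizes here are illustrative):

```python
import torch
from efficient_kan import KANLinear  # assuming this repo's KANLinear layer

batch, nhead, dim, out_dim = 8, 4, 64, 32  # illustrative sizes
layer = KANLinear(dim, out_dim)

x = torch.randn(batch, nhead, dim)
flat = x.reshape(-1, dim)             # [batch*nhead, dim]
out = layer(flat)                     # [batch*nhead, out_dim]
out = out.reshape(batch, nhead, -1)   # restore to [batch, nhead, out_dim]
```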

Liusir765832 commented 1 month ago

> You could get around this by flattening everything up to the final dimension, which in your case gives [batch*nhead, dim].

If my input size is (B, C, W, H), for example (8, 128, 64, 64), how do I flatten my input, and how do I set the parameters of the KAN model? I look forward to your reply to my question. Thank you

WhatMelonGua commented 1 month ago

> You could get around this by flattening everything up to the final dimension, which in your case gives [batch*nhead, dim].

Thank you for your answer! I must admit that I am still not proficient at reshaping dimensions in torch, because I always worry about whether elements will still be in their original positions after being restored. Thank you very much, I think this is a very helpful answer!

This means that I can directly use `view(batch, -1)` for the QKV dot product and then switch back with `view(batch, nhead, -1)`, right? That operation has the same effect as `q.flatten(start_dim=1)` and `q.contiguous().view(batch, nhead, -1)`, and there isn't any mixing-up of elements among them, is there?
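As a quick sanity check of that equivalence, here is a minimal sketch (assuming the tensor is contiguous, as freshly created tensors are):

```python
import torch

batch, nhead, dim = 2, 4, 2
q = torch.arange(batch * nhead * dim).reshape(batch, nhead, dim)

flat_view = q.view(batch, -1)
flat_flatten = q.flatten(start_dim=1)
print(torch.equal(flat_view, flat_flatten))               # True: same result
print(torch.equal(flat_view.view(batch, nhead, -1), q))   # True: round trip restores the layout
```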

Thank you very much. If that's the case, I think the issue should be closed immediately. 😳

WhatMelonGua commented 1 month ago

> You could get around this by flattening everything up to the final dimension, which in your case gives [batch*nhead, dim].

> If my input size is (B, C, W, H), for example (8, 128, 64, 64), how do I flatten my input, and how do I set the parameters of the KAN model? I look forward to your reply to my question. Thank you

It's correct! I tried it in Python with the code below:

    import torch

    x = torch.arange(0, 16, 1).reshape([2, 4, 2])  # you'll get it as [batch, nhead, dim]
    a = x.view(-1, 2)                              # flatten to [batch*nhead, dim]
    b = a.view(-1, 4, 2)                           # reshape back
    print(torch.equal(b, x))                       # True: matches the original, so the operation is correct!

So try it with your shape (B, C, W, H); the best way to flatten may depend on your network design.

minh-nguyenhoang commented 1 month ago

> You could get around this by flattening everything up to the final dimension, which in your case gives [batch*nhead, dim].

> If my input size is (B, C, W, H), for example (8, 128, 64, 64), how do I flatten my input, and how do I set the parameters of the KAN model? I look forward to your reply to my question. Thank you

If your data has spatial structure, I would recommend using a convolutional version of KAN; I believe some are already implemented, for example convKAN or LeKAN (you can search for them on GitHub). If you have already extracted the features, flatten the H and W dimensions and permute them with the C dimension so that you have an input of size [B, H*W, C], then flatten the first 2 dims to [B*H*W, C]. But if H and W are 64, I would suggest downsampling them further, to something like 16 or 8.
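A sketch of that reshaping pipeline (assuming this repo's KANLinear; the pooled size and output width are illustrative):

```python
import torch
import torch.nn.functional as F
from efficient_kan import KANLinear  # assuming this repo's KANLinear layer

B, C, H, W = 8, 128, 64, 64
x = torch.randn(B, C, H, W)

# Downsample 64x64 -> 16x16 as suggested, then treat each spatial
# position as a token with C features.
x = F.adaptive_avg_pool2d(x, (16, 16))      # [B, C, 16, 16]
tokens = x.flatten(2).permute(0, 2, 1)      # [B, H*W, C] = [B, 256, 128]
flat = tokens.reshape(-1, C)                # [B*H*W, C], the 2-D input KANLinear expects

layer = KANLinear(C, 64)                    # illustrative output width
out = layer(flat).reshape(B, 16 * 16, -1)   # back to [B, H*W, out_features]
```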