Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License
3.49k stars · 306 forks

Difference in memory usage of mlp and kan #23

Open c-pupil opened 1 month ago

c-pupil commented 1 month ago

Hello: if the input tensor size is [64, 28x28] and the hidden layers are [256, 256, 256, 256], the memory usage of the MLP and the KAN is similar, 382 MB and 500 MB respectively, which is consistent with my experimental results. However, if the input tensor size is [36864, 28x28], the memory usage of the two differs hugely: 844 MB versus 14468 MB. What is the reason for this? The initialization of the KAN is consistent with the one given in the example, and I am using a GPU.

Blealtan commented 1 month ago

The parameters are relatively few in this [batch_size=36864, input_size] case, but the intermediate storage (taped variables, i.e. the autograd graph) is huge. I think it's because the B-spline computation creates too many intermediate tensors. A big fused kernel might help, but I don't have time to work on this; it needs some manual math, differentiating over the B-spline basis functions by hand.
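
A rough back-of-the-envelope count illustrates why activation memory, not parameter memory, dominates here. The sketch below is my own estimate, not the library's code; it assumes efficient-kan's defaults of `grid_size=5` and `spline_order=3` and counts only the most obvious taped tensors per layer: a KAN layer must keep a `[batch, in_features, grid_size + spline_order]` tensor of B-spline basis values for the backward pass, while an MLP layer only tapes its `[batch, out_features]` activation. The real graph holds additional intermediates, so this is a lower bound.

```python
# Back-of-the-envelope count of autograd-taped floats per forward pass.
# An illustrative sketch, not efficient-kan's actual bookkeeping.

def mlp_intermediates(batch, widths):
    # One activation tensor of shape [batch, width] per layer output.
    return sum(batch * w for w in widths[1:])

def kan_intermediates(batch, widths, grid_size=5, spline_order=3):
    # Assumed defaults: grid_size=5, spline_order=3.
    total = 0
    for fan_in, fan_out in zip(widths, widths[1:]):
        # B-spline basis values: [batch, fan_in, grid_size + spline_order]
        total += batch * fan_in * (grid_size + spline_order)
        # Layer output: [batch, fan_out]
        total += batch * fan_out
    return total

widths = [28 * 28, 256, 256, 256, 256]
for batch in (64, 36864):
    mlp = mlp_intermediates(batch, widths)
    kan = kan_intermediates(batch, widths)
    print(f"batch={batch}: mlp={mlp}, kan={kan}, ratio={kan / mlp:.1f}")
```

Even this crude count gives a KAN/MLP ratio above 13x at every batch size, which is in the same ballpark as the 14468 MB vs 844 MB gap reported above; at batch 64 both absolute numbers are small enough that fixed overheads (parameters, CUDA context) mask the difference.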