KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Implementation seems to be inefficient #485

Open Hope1337 opened 1 month ago

Hope1337 commented 1 month ago

After reading the paper, I believe the KAN model should have $O(N^2L(G + k))$ parameters, which is larger than an MLP's $O(N^2L)$ by only a factor of $(G + k)$ and should still be manageable in terms of size (in certain specific settings).

However, when I attempted to use model = KAN(width=[8,16,32,16,8,1], grid=15, k=3, seed=1, device=device) with an input batch size of 90, my 24 GB GPU unexpectedly ran out of VRAM. Yes, the KAN model requires more parameters than an MLP by more than just the $(G + k)$ factor, but with the settings above it shouldn't come anywhere near 24 GB, right? I suspect this is due to the implementation, but the code is too complex to analyze in a short period of time. Has anyone found which part of the implementation causes such high VRAM usage? I would appreciate it if anyone could share their insights.
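To put a number on "shouldn't exceed 24 GB": here is a rough back-of-the-envelope estimate, assuming each edge between consecutive layers carries roughly $G + k$ spline coefficients (the dominant term in the paper's $O(N^2L(G+k))$ count; the exact pykan count also includes small per-edge base weights, which don't change the order of magnitude):

```python
def kan_param_estimate(width, grid, k):
    """Rough KAN parameter count: (grid + k) spline
    coefficients per edge, summed over adjacent layers.
    This is an estimate, not pykan's exact count."""
    edges = sum(n_in * n_out for n_in, n_out in zip(width, width[1:]))
    return edges * (grid + k)

params = kan_param_estimate([8, 16, 32, 16, 8, 1], grid=15, k=3)
print(params)                    # → 23184 coefficients
print(params * 4 / 1024)         # float32 storage in KiB → ~90.6 KiB
```

So the parameters themselves occupy well under a megabyte; even with optimizer state and a batch of 90, parameter storage cannot explain a 24 GB OOM, which points at intermediate activation tensors allocated during the forward pass rather than at the model size.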