stergiosba opened 4 months ago
Awesome! Let's see what the long-term interest in KANs is like; I'd prefer not to add an example for every model du jour.
I can probably throw together a tweet on this, though ;) (If you have a nice speed-vs-PyTorch section in the README, that'd be something I could clip out as a graphic.)
I like the term model du jour :)
I will run some tests against PyTorch after work.
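One thing to watch in JAX-vs-PyTorch timing tests: the first call to a jitted JAX function includes tracing and XLA compilation, and JAX dispatches work asynchronously. The clock should therefore only stop after `jax.block_until_ready`. Below is a minimal sketch, assuming a single device. `time_jitted` is a hypothetical helper that is not part of JAX or KANX. It reports the first-call time and the steady-state time separately.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_runs=10):
    """Return (first_call_seconds, mean_steady_state_seconds) for jax.jit(fn)."""
    jitted = jax.jit(fn)

    # First call: includes tracing + XLA compilation, so report it separately.
    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))
    first = time.perf_counter() - start

    # Steady state: block on every result so async dispatch can't hide work.
    start = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(jitted(*args))
    steady = (time.perf_counter() - start) / n_runs
    return first, steady


first, steady = time_jitted(lambda a: a @ a, jnp.ones((512, 512)))
print(f"first call: {first:.3f} s, steady state: {steady * 1e3:.2f} ms")
```

On the PyTorch side, GPU timings need the analogous synchronization before reading the clock.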
There seems to be a race to find the best implementation of KANs on the PyTorch side of the ML universe (currently 4 implementations), but there is a dearth of efforts on the JAX side (only 1, using Flax). The Flax implementation seems, well... slow, to be frank, and untested.
I had a go at it and made a pure `equinox`-based implementation that just works, in 125 lines of code: KANX. I hope others find it interesting and use it. It includes:
- A `KANLayer` that can be stacked inside an `eqx.nn.Sequential` like any other module. I tested it with an MLP and it works (further testing is needed to iron this out, but it should work).
- … `equinox` website to test, and report 96.7% accuracy on the validation set with minimal tuning; it took 13 seconds.

Just for reference, the fastest PyTorch implementation takes about 180 seconds to reach the same validation accuracy (I still have to perform more head-to-head comparisons, though).
Let me know what you think. This effort could also serve as a documentation example for `equinox`.