patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.06k stars, 136 forks

Kolmogorov Arnold Networks with equinox [kanx] #730

Open: stergiosba opened this issue 4 months ago

stergiosba commented 4 months ago

There seems to be a race to find the best KAN implementation on the PyTorch side of the ML universe (currently 4 of them), but there is a dearth of efforts on the JAX side (only 1, using Flax). The Flax implementation seems, well... slow, to be frank, and untested.

I had a go at it and made a pure Equinox-based implementation that just works, in 125 lines of code: KANX. I hope others find it interesting and use it.

It includes:

Just for reference, the fastest PyTorch implementation takes about 180 seconds to reach the same validation accuracy (I still have to perform more head-to-head comparisons, though).

Let me know what you think. This effort could also serve as a documentation example for Equinox.

patrick-kidger commented 4 months ago

Awesome! Let's see what the interest is like in KANs long-term -- I'd prefer not to add an example for every model du jour.

I can probably throw together a tweet about this, though ;) (If you have a nice speed-vs-PyTorch section in the README, that'd be something I could clip out as a graphic?)

stergiosba commented 4 months ago

I like the term model du jour :)

I will run some tests against PyTorch after work.
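Since JAX compiles a jitted function on its first call and dispatches work asynchronously, a fair wall-clock comparison usually reports compilation time separately and waits for results with `jax.block_until_ready` before stopping the clock. A minimal sketch of such a harness; `time_jitted` and the toy `step` function are hypothetical stand-ins for a real training step:

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_runs=10):
    """Hypothetical helper: returns (first_call_seconds, mean_seconds_per_later_call)."""
    jitted = jax.jit(fn)

    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # first call includes tracing and XLA compilation
    first_call = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n_runs):
        out = jitted(*args)
    jax.block_until_ready(out)  # dispatch is asynchronous, so wait before stopping the clock
    per_call = (time.perf_counter() - start) / n_runs
    return first_call, per_call


def step(w, x):
    # Toy stand-in for a training step
    return jnp.tanh(x @ w).sum()


first_call, per_call = time_jitted(step, jnp.ones((512, 512)), jnp.ones((128, 512)))
print(f"first call: {first_call:.4f}s, steady state: {per_call:.6f}s/call")
```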