This repository contains a very fast implementation of the Kolmogorov-Arnold Network (KAN), obtained by replacing the third-order B-spline basis of the original KAN with Radial Basis Functions (RBFs).
The forward pass of FastKAN is 3.33x faster than that of efficient_kan, and the implementation is much simpler.
The original implementation of KAN is pykan.
To install fast-kan, clone the repository and install it with pip:
```bash
git clone https://github.com/ZiyaoLi/fast-kan
cd fast-kan
pip install .
```
To run an example training of FastKAN on MNIST:
```bash
python examples/train_mnist.py
```
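FastKAN replaces each B-spline basis function with a Gaussian RBF centered at a grid point $u_i$, where $h$ controls the width:
$$b_{i}(u)=\exp\left(-\left(\frac{u-u_i}{h}\right)^2\right)$$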
The rationale is that these RBFs closely approximate the B-spline basis (up to a linear transformation) and are very cheap to compute, provided the grid is uniform. Results are shown in the figure below (code in the notebook).
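As an illustration of why the basis is cheap to compute, here is a minimal NumPy sketch that evaluates $b_i(u)$ on a uniform grid with a single broadcasted expression. The function name `rbf_basis`, the grid range, the number of grid points, and the choice of $h$ as the grid spacing are assumptions for illustration, not the repository's API.

```python
import numpy as np

def rbf_basis(u, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Hypothetical helper: Gaussian RBFs b_i(u) = exp(-((u - u_i) / h)^2).

    Assumption: the centers u_i are uniform and h equals the grid spacing.
    """
    grid = np.linspace(grid_min, grid_max, num_grids)  # uniform centers u_i
    h = (grid_max - grid_min) / (num_grids - 1)        # spacing between centers
    u = np.asarray(u, dtype=float)
    # Broadcast (..., 1) against (num_grids,) -> (..., num_grids)
    return np.exp(-((u[..., None] - grid) / h) ** 2)

basis = rbf_basis(np.array([-2.0, 0.0, 2.0]))
print(basis.shape)  # (3, 8): one row of basis values per input
```

Unlike B-splines, no recursion over spline order or knot intervals is needed; every basis value is an independent exponential, which is what makes a uniform grid convenient.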
FastKAN uses LayerNorm to scale inputs into the range of the spline grid, so the grid never needs to be adjusted.
FastKAN's forward pass is 3.33x faster than efficient_kan's (742 µs vs. 223 µs on a V100; see the notebook).
Accuracy on MNIST is equivalent or slightly improved.
FastKANLayer lets users plot the learned curves dimension by dimension; see the notebook for an example of usage.
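The pieces described above (LayerNorm, then Gaussian RBFs on a fixed grid, then a learned linear combination) can be sketched as a single forward pass. This is a hypothetical NumPy illustration, not FastKANLayer itself: the function names, the weight layout, the default grid, and the LayerNorm without learnable scale and shift are all assumptions, and the actual layer may include additional terms.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each sample across its features (no learnable scale/shift here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def fastkan_like_forward(x, weight, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Hypothetical FastKAN-style layer: LayerNorm -> Gaussian RBFs -> linear map.

    x: (batch, in_dim); weight: (in_dim * num_grids, out_dim).
    """
    x = layer_norm(x)  # keeps inputs roughly within the fixed grid range
    grid = np.linspace(grid_min, grid_max, num_grids)
    h = (grid_max - grid_min) / (num_grids - 1)
    basis = np.exp(-((x[..., None] - grid) / h) ** 2)  # (batch, in_dim, num_grids)
    # Flatten the per-dimension bases and mix them into the output features.
    return basis.reshape(x.shape[0], -1) @ weight

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3 * 8, 5)) * 0.1
y = fastkan_like_forward(x, w)
print(y.shape)  # (4, 5)
```

Because the grid is fixed, normalizing the inputs is what keeps them inside the region where the RBFs are non-negligible; the learned univariate curve for input dimension `d` and output `o` is the weighted sum of that dimension's basis values, which is what makes per-dimension plotting straightforward.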
Copyright 2024 Li, Ziyao. Licensed under the Apache License, Version 2.0.
```bibtex
@article{li2024kolmogorovarnold,
  title={Kolmogorov-Arnold Networks are Radial Basis Function Networks},
  author={Ziyao Li},
  year={2024},
  eprint={2405.06721},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```