Kolmogorov–Arnold Transformer: Yes, I kan!
🎉 This is a PyTorch/GPU implementation of the paper Kolmogorov–Arnold Transformer (KAT), which replaces the MLP layers in transformers with KAN layers.
Kolmogorov–Arnold Transformer
📝[Paper] </>[code] </>[CUDA kernel]
Xingyi Yang, Xinchao Wang
National University of Singapore
Simply plugging KAN layers into a vanilla ViT does not scale well. We introduce KAT, which integrates Group-Rational KANs (GR-KANs) into transformers and achieves significant performance improvements in large-scale training scenarios such as ImageNet.
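As a rough illustration of the GR-KAN idea, the plain-Python sketch below applies a "safe" rational function P(x) / (1 + |Q(x)|) to each channel, with all channels in a group sharing one set of coefficients, and then applies a linear map. This is an assumption-laden toy: the function names are hypothetical, and the degrees, group count and coefficients are arbitrary. The repository implements the rational part as a CUDA kernel with trainable coefficients, which nothing below reproduces.

```python
from typing import List, Sequence, Tuple

# (numerator a0..am, denominator b1..bn); hypothetical representation
Coeffs = Tuple[Sequence[float], Sequence[float]]


def rational(x: float, a: Sequence[float], b: Sequence[float]) -> float:
    """Safe rational function (a0 + a1*x + ... + am*x^m) / (1 + |b1*x + ... + bn*x^n|)."""
    num = sum(coef * x ** i for i, coef in enumerate(a))
    den = 1.0 + abs(sum(coef * x ** (i + 1) for i, coef in enumerate(b)))
    return num / den


def group_rational(x: Sequence[float], coeffs: Sequence[Coeffs], groups: int) -> List[float]:
    """Apply one rational function per group of contiguous channels.

    All channels in group g share coeffs[g], so only `groups` coefficient
    sets are stored instead of one per channel.
    """
    if len(x) % groups != 0 or len(coeffs) != groups:
        raise ValueError("need len(x) divisible by groups and one coefficient set per group")
    size = len(x) // groups
    return [rational(v, *coeffs[i // size]) for i, v in enumerate(x)]


def gr_kan_layer(x, coeffs, groups, weight, bias) -> List[float]:
    """Hypothetical GR-KAN layer: group-wise rational activation, then W @ h + bias."""
    h = group_rational(x, coeffs, groups)
    return [sum(w * hv for w, hv in zip(row, h)) + b for row, b in zip(weight, bias)]


if __name__ == "__main__":
    coeffs = [([0.0, 1.0], [0.0]),        # group 0: identity, x / 1
              ([1.0, 0.0, 1.0], [1.0])]   # group 1: (1 + x^2) / (1 + |x|)
    print(group_rational([1.0, -2.0, 0.5, 3.0], coeffs, groups=2))
```

Sharing coefficients per group keeps the parameter count of the activations independent of the channel width, while the linear map still mixes all channels.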
Our CUDA implementation is available at https://github.com/Adamdad/rational_kat_cu.git.
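```bash
# Install PyTorch (with CUDA support) first, then the remaining dependencies
pip install timm==1.0.3
pip install wandb  # optional: used for visualizing training results

# Build and install the CUDA kernel for the rational functions
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .
```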
📦 Data preparation: organize ImageNet in the following folder structure (you can extract ImageNet with this script):
```
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
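Both splits are organized by class folder. The sketch below is a simplified stand-in, not the loader timm actually uses, showing how such a layout is typically turned into labeled samples: class folders are sorted by name and numbered (the convention of torchvision's ImageFolder), so train/ and val/ must contain exactly the same class folders for the labels to agree. The function names and the extension list are illustrative assumptions.

```python
import os
from typing import Dict, List, Tuple

IMG_EXTENSIONS = (".jpeg", ".jpg", ".png")  # assumed set of accepted extensions


def index_split(split_dir: str) -> Tuple[Dict[str, int], List[Tuple[str, int]]]:
    """Map each class folder (e.g. n01440764) to an index and list (path, label) pairs.

    Indices follow the sorted order of the folder names; files with other
    extensions are skipped. Only one level of nesting is scanned.
    """
    with os.scandir(split_dir) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(split_dir, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples


def check_layout(root: str) -> int:
    """Verify that root/train and root/val share the same classes; return the class count."""
    train_classes, _ = index_split(os.path.join(root, "train"))
    val_classes, _ = index_split(os.path.join(root, "val"))
    if train_classes != val_classes:
        raise ValueError("train/ and val/ must contain the same class folders")
    return len(train_classes)
```

For the full ImageNet-1k layout, check_layout should report 1000 classes.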
Refer to example.py for a detailed use case demonstrating how to use KAT with timm to classify an image.
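example.py itself is not reproduced here. To make the evaluation-time preprocessing concrete, here is a framework-free sketch of the geometry and normalization implied by the data configuration reported for kat_tiny_patch16_224 (input 224, crop_pct 0.875, center crop, ImageNet mean/std): the shorter side is resized to floor(224 / 0.875) = 256, then a 224 x 224 center crop is taken. The function names are hypothetical, bicubic resampling itself is not modeled, and timm/torchvision may round crop offsets slightly differently. In real code, timm.data.resolve_data_config and timm.data.create_transform build the actual pipeline.

```python
import math
from typing import List, Tuple

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def eval_geometry(width: int, height: int, img_size: int = 224,
                  crop_pct: float = 0.875) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return the resized (w, h) and the center-crop box (left, top, right, bottom).

    The shorter side is scaled to floor(img_size / crop_pct), keeping the
    aspect ratio; then an img_size x img_size crop is cut from the center.
    """
    scale = math.floor(img_size / crop_pct)
    if width <= height:
        new_w, new_h = scale, int(scale * height / width)
    else:
        new_w, new_h = int(scale * width / height), scale
    left = (new_w - img_size) // 2
    top = (new_h - img_size) // 2
    return (new_w, new_h), (left, top, left + img_size, top + img_size)


def normalize(pixel_rgb: Tuple[int, int, int]) -> List[float]:
    """Scale an 8-bit RGB pixel to [0, 1], then apply per-channel (v - mean) / std."""
    return [(v / 255.0 - m) / s for v, m, s in zip(pixel_rgb, IMAGENET_MEAN, IMAGENET_STD)]


if __name__ == "__main__":
    print(eval_geometry(500, 375))  # → ((341, 256), (58, 16, 282, 240))
```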
Download pre-trained models or access training checkpoints:
🏷️ Model | ⚙️ Setup | 📦 Params | 📈 Top-1 (%) | 🔗 Link |
---|---|---|---|---|
KAT-T | From Scratch | 5.7M | 74.6 | link/huggingface |
KAT-T | From ViT | 5.7M | 75.7 | link/huggingface |
KAT-S | From Scratch | 22.1M | 81.2 | link/huggingface |
KAT-S | From ViT | 22.1M | 82.0 | link/huggingface |
KAT-B | From Scratch | 86.6M | 82.3 | link/huggingface |
KAT-B | From ViT | 86.6M | 82.8 | link/huggingface |
All training scripts are under scripts/, for example:
```bash
bash scripts/train_kat_tiny_8x128.sh
```
To change the hyperparameters, edit the corresponding script, e.g.:
```bash
#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# kat_tiny_swish_*: the rational functions are initialized to be Swish functions.
# 8 processes (one per GPU); -b 128 is the per-GPU batch size.
# Note: a comment cannot follow a trailing "\", or the line continuation breaks.
bash ./dist_train.sh 8 "$DATA_PATH" \
    --model kat_tiny_swish_patch16_224 \
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb
```
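Two of the flags above, --sched cosine and --model-ema-decay 0.9999, boil down to simple update rules. The sketch below shows plain cosine decay from --lr 1e-3 over --epochs 300, and a single EMA step on a list of weights. It is an approximation, not timm's implementation: it omits any warmup that timm's scheduler adds (controlled by --warmup-epochs), assumes a minimum learning rate of 0, and the function names are hypothetical. Also note that in timm -b is the per-process batch size, so dist_train.sh 8 with -b 128 gives a global batch of 8 x 128 = 1024.

```python
import math
from typing import List, Sequence


def cosine_lr(epoch: float, base_lr: float = 1e-3, total_epochs: int = 300,
              min_lr: float = 0.0) -> float:
    """Cosine decay from base_lr (epoch 0) to min_lr (epoch total_epochs), no warmup."""
    t = min(max(epoch / total_epochs, 0.0), 1.0)  # training progress, clamped to [0, 1]
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))


def ema_update(ema: Sequence[float], params: Sequence[float],
               decay: float = 0.9999) -> List[float]:
    """One EMA step per weight: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


if __name__ == "__main__":
    for epoch in (0, 75, 150, 225, 300):
        print(epoch, f"{cosine_lr(epoch):.2e}")
```

With decay 0.9999, the EMA weights effectively average over roughly the last 1 / (1 - 0.9999) = 10,000 updates, which is why EMA is typically paired with long schedules like this one.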
To evaluate our kat_tiny_patch16_224 model, run:
```bash
DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
python validate.py "$DATA_PATH" --model kat_tiny_patch16_224 \
    --checkpoint "$CHECKPOINT_PATH" -b 512
```
Example output:
```
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.875
	crop_mode: center
Test: [ 0/98] Time: 3.453s (3.453s, 148.28/s) Loss: 0.6989 (0.6989) Acc@1: 84.375 ( 84.375) Acc@5: 96.875 ( 96.875)
.......
Test: [ 90/98] Time: 0.212s (0.592s, 864.23/s) Loss: 1.1640 (1.1143) Acc@1: 71.875 ( 74.270) Acc@5: 93.750 ( 92.220)
 * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
--result
{
    "model": "kat_tiny_patch16_224",
    "top1": 74.558,
    "top1_err": 25.442,
    "top5": 92.39,
    "top5_err": 7.61,
    "param_count": 5.72,
    "img_size": 224,
    "crop_pct": 0.875,
    "interpolation": "bicubic"
}
```
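In the log above, each Acc@1/Acc@5 entry shows the value for the current batch followed by the running average in parentheses, and the final top1_err/top5_err fields are simply 100 minus the accuracies (100 - 74.558 = 25.442). For reference, here is a plain-Python sketch of top-k accuracy; timm computes the same quantity with tensor operations, and the function name and toy scores are illustrative only.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if not labels:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # indices of the k largest scores for this sample
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)


if __name__ == "__main__":
    scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]
    labels = [1, 1, 2, 2]
    top1 = topk_accuracy(scores, labels, k=1)
    print("top1", top1, "top1_err", 100.0 - top1)   # top1 50.0 top1_err 50.0
    print("top2", topk_accuracy(scores, labels, k=2))  # top2 100.0
```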
We thank the authors of rational_activations, whose CUDA implementations of rational functions inspired parts of this work. We also thank @yuweihao, @florinshen, @Huage001 and @yu-rp for valuable discussions.
If you use this repository, please cite:
```bibtex
@misc{yang2024kat,
  title={Kolmogorov–Arnold Transformer},
  author={Xingyi Yang and Xinchao Wang},
  year={2024},
  eprint={2409.10594},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```