Training small GPT-2-style models in JAX using KANs instead of MLPs
This repository compares transformers whose feed-forward layers are either multilayer perceptrons (MLPs) or Kolmogorov-Arnold networks (KANs).
Key points:
- The KAN layers use Chebyshev polynomials as their basis functions instead of the B-splines of the original KAN formulation (inspired by this repo).
- Inputs to each KAN layer are passed through tanh to keep them within [-1, 1], the domain of the Chebyshev polynomials, instead of using spline grids that are updated during training.
- Both models are trained on 134M tokens of TinyStories.
- Both use the standard GPT-2 architecture apart from the KAN layers.
- The MLP model has 3.3M non-embedding weights and the KAN model has 2.5M (about 25% fewer).
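The points above can be illustrated with a minimal JAX sketch of a Chebyshev KAN layer: each input is squashed with tanh into [-1, 1], expanded into Chebyshev polynomials T_0 to T_D, and every input-output pair gets its own learned combination of those polynomials. The function names, the coefficient initialization, and the degree of 4 in the usage example are illustrative assumptions, not this repository's actual code; the sketch also omits any bias or extra base activation a real layer might include.

```python
import jax
import jax.numpy as jnp


def chebyshev_basis(x: jax.Array, degree: int) -> jax.Array:
    """Stack T_0(x), ..., T_degree(x) along a new trailing axis.

    Uses the recurrence T_0 = 1, T_1 = x, T_{n+1} = 2x T_n - T_{n-1}.
    """
    polys = [jnp.ones_like(x), x]
    for _ in range(2, degree + 1):
        polys.append(2.0 * x * polys[-1] - polys[-2])
    return jnp.stack(polys[: degree + 1], axis=-1)


def init_cheby_kan(key: jax.Array, d_in: int, d_out: int, degree: int) -> dict:
    # Hypothetical init: normal coefficients scaled by fan-in times basis size.
    scale = (d_in * (degree + 1)) ** -0.5
    return {"coeffs": scale * jax.random.normal(key, (d_in, d_out, degree + 1))}


def cheby_kan_apply(params: dict, x: jax.Array) -> jax.Array:
    # tanh maps activations into [-1, 1], the Chebyshev domain,
    # in place of an adaptive spline grid.
    x = jnp.tanh(x)
    degree = params["coeffs"].shape[-1] - 1
    basis = chebyshev_basis(x, degree)  # shape (..., d_in, degree + 1)
    # out_o = sum_i phi_io(x_i), where phi_io(t) = sum_k c[i, o, k] * T_k(t)
    return jnp.einsum("...ik,iok->...o", basis, params["coeffs"])


# Usage: a 128 -> 128 layer applied to a (batch, seq, d_model) activation.
params = init_cheby_kan(jax.random.PRNGKey(0), d_in=128, d_out=128, degree=4)
x = jax.random.normal(jax.random.PRNGKey(1), (16, 64, 128))
y = cheby_kan_apply(params, x)
print(y.shape)  # (16, 64, 128)
```

Because tanh bounds every input to the Chebyshev domain, no grid has to be tracked or resized during training; the trade-off is that inputs far from zero saturate near T_k(±1).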
Results:
Both models reach a final loss of about 2.46, even though the KAN model has roughly 25% fewer non-embedding parameters.
Hyperparameters:
- d_model: 128
- d_mlp: 768 (when applicable)
- n_heads: 8
- n_layers: 16
- learning_rate: 1e-5
- batch_size: 16
- weight_decay: 0.001
- optimizer: AdamW
- seq_len: 64
Hardware: a single NVIDIA GTX 1080 Ti GPU
Wandb: link.